|
import os |
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.document_loaders import PyPDFLoader |
|
|
|
from .embeddings import EMBEDDING_MODEL_NAME |
|
from .vectorstore import PERSIST_DIRECTORY, get_vectorstore |
|
|
|
|
|
def load_data():
    """Parse all PDFs under ``data/`` and index them into the vectorstore.

    Returns:
        The populated vectorstore instance (a ``Chroma`` store).

    Raises:
        TypeError: if ``get_vectorstore`` returns something other than Chroma.
    """
    docs = parse_data()
    embedding_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    vectorstore = get_vectorstore(embedding_function)

    # Explicit raise instead of `assert`: asserts are stripped under `-O`,
    # and this is a runtime contract, not a debug check.
    if not isinstance(vectorstore, Chroma):
        raise TypeError(f"Expected a Chroma vectorstore, got {type(vectorstore)!r}")

    # BUG FIX: `from_documents` is a classmethod that builds and RETURNS a
    # brand-new store; the original called it on the instance and discarded
    # the result, so the returned `vectorstore` never received the docs.
    # `add_documents` indexes into the store we actually return.
    vectorstore.add_documents(docs)
    return vectorstore
|
|
|
|
|
def parse_data():
    """Walk the ``data/`` tree, load every PDF, and split it into chunks.

    Each chunk's metadata is enriched with:
      - ``name``: filename without extension (via :func:`parse_name`)
      - ``domain``: top-level subfolder of ``data/`` (via :func:`parse_domain`)
      - ``page_number``: copy of the loader-provided ``page``
      - ``short_name``: currently identical to ``name``

    Returns:
        list: all chunks across all PDFs found.
    """
    # The splitter is configuration-only and loop-invariant, so build it
    # once instead of once per PDF (as the original did).
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=0
    )

    docs = []
    for root, dirs, files in os.walk("data"):
        for file in files:
            if file.endswith(".pdf"):
                file_path = os.path.join(root, file)
                loader = PyPDFLoader(file_path)
                pages = loader.load_and_split()

                doc_chunks = text_splitter.split_documents(pages)

                for chunk in doc_chunks:
                    chunk.metadata["name"] = parse_name(chunk.metadata["source"])
                    chunk.metadata["domain"] = parse_domain(chunk.metadata["source"])
                    chunk.metadata["page_number"] = chunk.metadata["page"]
                    chunk.metadata["short_name"] = chunk.metadata["name"]
                    docs.append(chunk)

    return docs
|
|
|
|
|
def parse_name(source: str) -> str:
    """Return the base filename of *source* without its extension.

    Fixes two defects in the original `split("/")` approach:
    - it broke on Windows paths (``os.path.join`` uses ``\\`` there);
    - ``split(".")[0]`` truncated dotted names ("v1.2-notes.pdf" -> "v1").
    ``os.path.splitext`` strips only the final extension.
    """
    return os.path.splitext(os.path.basename(source))[0]
|
|
|
|
|
def parse_domain(source: str) -> str:
    """Return the domain of *source*: the first folder under ``data/``.

    Paths come from ``os.walk("data")`` joined with ``os.path.join``, e.g.
    ``data/<domain>/<file>.pdf``. Splitting on ``os.sep`` after normalizing
    keeps the original POSIX behavior while also working on Windows, where
    the original hard-coded ``"/"`` split failed.
    """
    return os.path.normpath(source).split(os.sep)[1]
|
|
|
|
|
def clear_index():
    """Delete all files and symlinks in the vectorstore persistence directory.

    Subdirectories are left untouched. Failures on individual entries are
    reported but do not abort the sweep. A missing directory is a no-op
    (the original crashed on first run, before anything was persisted).
    """
    folder = PERSIST_DIRECTORY
    # Robustness fix: os.listdir raises FileNotFoundError if the persist
    # directory has never been created; treat that as "nothing to clear".
    if not os.path.isdir(folder):
        return
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
        # Narrowed from `Exception`: filesystem operations raise OSError;
        # anything else would be a programming error worth surfacing.
        except OSError as e:
            print("Failed to delete %s. Reason: %s" % (file_path, e))
|
|
|
|
|
if __name__ == "__main__":
    # Rebuild the index from scratch, then run a smoke-test similarity query.
    clear_index()
    db = load_data()

    query = "He who can bear the misfortune of a nation is called the ruler of the world."
    print(db.similarity_search(query))
|
|