|
import logging |
|
import os |
|
import pickle |
|
|
|
import chromadb |
|
import logfire |
|
from custom_retriever import CustomRetriever |
|
from dotenv import load_dotenv |
|
from llama_index.core import VectorStoreIndex |
|
from llama_index.core.node_parser import SentenceSplitter |
|
from llama_index.core.retrievers import VectorIndexRetriever |
|
from llama_index.embeddings.openai import OpenAIEmbedding |
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
from utils import init_mongo_db |
|
|
|
# Load settings from a .env file first, so everything below sees them in the
# process environment.
load_dotenv()

# Root logging runs at INFO; the "httpx" logger is raised to WARNING so only
# its warnings and errors are emitted.
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

# Configure Logfire last, after the environment and stdlib logging are set up.
logfire.configure()
|
|
|
|
|
# Bootstrap the local vector stores: when the Transformers Chroma directory is
# missing, fetch the dataset repository from the Hugging Face Hub into "data/".
# NOTE(review): only the transformers directory is checked; the remaining
# collections are presumably part of the same snapshot -- confirm.
if not os.path.exists("data/chroma-db-transformers"):
    logfire.warn(
        "Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
    )
    # Imported here so huggingface_hub is only loaded when a download is
    # actually required.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-buster/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to 'data/chroma-db-transformers'")
|
|
|
|
|
def setup_database(db_collection: str, dict_file_name: str) -> "CustomRetriever":
    """Build a ``CustomRetriever`` backed by one on-disk Chroma collection.

    Args:
        db_collection: Name of the Chroma collection; also the name of its
            directory under ``data/``.
        dict_file_name: File name of the pickled document dictionary stored
            inside that directory.

    Returns:
        A ``CustomRetriever`` wrapping a vector retriever
        (``similarity_top_k=10``) and the unpickled document dictionary.

    Raises:
        OSError: If the pickle file cannot be opened (for example, it is
            missing).
    """
    collection_dir = f"data/{db_collection}"

    client = chromadb.PersistentClient(path=collection_dir)
    chroma_collection = client.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    # Build the embedding model once and share it, so the index and the
    # retriever can never end up configured differently.
    embed_model = OpenAIEmbedding(model="text-embedding-3-large", mode="similarity")

    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
        show_progress=True,
        use_async=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=10,
        use_async=True,
        embed_model=embed_model,
    )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # The file lives under "data/", which this module may populate from a
    # remote download, so only point this at snapshots you trust.
    with open(f"{collection_dir}/{dict_file_name}", "rb") as pickle_file:
        document_dict = pickle.load(pickle_file)

    return CustomRetriever(vector_retriever, document_dict)
|
|
|
|
|
|
|
# One retriever per documentation source. Each call names the Chroma
# collection (also its directory under "data/") and the pickled document
# dictionary stored alongside it.
custom_retriever_transformers = setup_database(
    db_collection="chroma-db-transformers",
    dict_file_name="document_dict_transformers.pkl",
)
custom_retriever_peft = setup_database(
    db_collection="chroma-db-peft",
    dict_file_name="document_dict_peft.pkl",
)
custom_retriever_trl = setup_database(
    db_collection="chroma-db-trl",
    dict_file_name="document_dict_trl.pkl",
)
custom_retriever_llama_index = setup_database(
    db_collection="chroma-db-llama_index",
    dict_file_name="document_dict_llama_index.pkl",
)
custom_retriever_openai_cookbooks = setup_database(
    db_collection="chroma-db-openai_cookbooks",
    dict_file_name="document_dict_openai_cookbooks.pkl",
)
custom_retriever_langchain = setup_database(
    db_collection="chroma-db-langchain",
    dict_file_name="document_dict_langchain.pkl",
)
|
|
|
|
|
# Runtime settings read from the environment.
# CONCURRENCY_COUNT falls back to 64 when the variable is unset; MONGODB_URI
# is None when the variable is unset.
MONGODB_URI = os.environ.get("MONGODB_URI")
CONCURRENCY_COUNT = int(os.environ.get("CONCURRENCY_COUNT", 64))
|
|
|
# (internal key, UI label) for every documentation source, in display order.
# Both public lists are derived from this table so they stay index-aligned.
_SOURCE_TABLE = (
    ("transformers", "Transformers Docs"),
    ("peft", "PEFT Docs"),
    ("trl", "TRL Docs"),
    ("llama_index", "LlamaIndex Docs"),
    ("langchain", "LangChain Docs"),
    ("openai_cookbooks", "OpenAI Cookbooks"),
)

# Human-readable labels.
AVAILABLE_SOURCES_UI = [label for _key, label in _SOURCE_TABLE]

# Internal source identifiers, in the same order as the labels above.
AVAILABLE_SOURCES = [key for key, _label in _SOURCE_TABLE]
|
|
|
# Connect to MongoDB only when a URI is configured. Without one, log a warning
# and set mongo_db to None explicitly, rather than relying on the logging
# call's return value.
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
else:
    logfire.warn("No mongodb uri found, you will not be able to save data.")
    mongo_db = None
|
|
|
# Public names exported by this module.
__all__ = [
    # Retrievers, one per documentation source.
    "custom_retriever_transformers",
    "custom_retriever_peft",
    "custom_retriever_trl",
    "custom_retriever_llama_index",
    "custom_retriever_openai_cookbooks",
    "custom_retriever_langchain",
    # MongoDB handle (None when no URI is configured).
    "mongo_db",
    # Settings and source listings.
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]
|
|