# Provenance (from the hosting page, not part of the module):
# author: Omar Solano; commit 129499e "add upload to hf script"; size 3.85 kB
import logging
import os
import pickle
import chromadb
import logfire
from custom_retriever import CustomRetriever
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from utils import init_mongo_db
# Load variables from a .env file first, so the os.getenv() lookups further
# down in this module can see them.
load_dotenv()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Root logging at INFO ...
logging.basicConfig(level=logging.INFO)
# ... but the "httpx" logger is raised to WARNING (presumably to silence its
# per-request INFO lines — confirm if this matters).
logging.getLogger("httpx").setLevel(logging.WARNING)
# Initialise logfire with no explicit arguments.
logfire.configure()
# Bootstrap step: when the local Chroma store is missing, download the dataset
# repository from the Hugging Face Hub into "data/".
# https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
# NOTE: only the transformers store is probed here; the other stores opened
# later in this module are assumed to come with the same snapshot.
if not os.path.exists("data/chroma-db-transformers"):
    logfire.warn(
        "Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
    )
    # Imported lazily: huggingface_hub is only needed when a download happens.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-buster/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to 'data/chroma-db-transformers'")
def setup_database(db_collection: str, dict_file_name: str) -> "CustomRetriever":
    """Build a ``CustomRetriever`` on top of a persisted Chroma collection.

    Args:
        db_collection: Used both as the directory name under ``data/`` for the
            persistent Chroma client and as the collection name inside it.
        dict_file_name: File name of the pickled document dictionary located
            in ``data/<db_collection>/``.

    Returns:
        A ``CustomRetriever`` wrapping a vector retriever (``similarity_top_k``
        of 10) and the unpickled document dictionary.

    Raises:
        FileNotFoundError: If the pickle file is missing (raised by ``open``).
    """
    store_dir = f"data/{db_collection}"

    # One embedding model instance is shared by the index and the retriever;
    # both previously built their own identically-configured copy.
    embed_model = OpenAIEmbedding(model="text-embedding-3-large", mode="similarity")

    db = chromadb.PersistentClient(path=store_dir)
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
        show_progress=True,
        use_async=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=10,
        use_async=True,
        embed_model=embed_model,
    )

    # SECURITY: pickle.load can execute arbitrary code contained in the file.
    # This file may originate from the Hub snapshot downloaded at import time,
    # so only point this module at a trusted repository.
    with open(f"{store_dir}/{dict_file_name}", "rb") as f:
        document_dict = pickle.load(f)

    return CustomRetriever(vector_retriever, document_dict)
# One retriever per documentation source. Each call opens the Chroma store in
# data/<db_collection>/ and loads the pickle stored alongside it.
custom_retriever_transformers = setup_database(
    db_collection="chroma-db-transformers",
    dict_file_name="document_dict_transformers.pkl",
)
custom_retriever_peft = setup_database(
    db_collection="chroma-db-peft",
    dict_file_name="document_dict_peft.pkl",
)
custom_retriever_trl = setup_database(
    db_collection="chroma-db-trl",
    dict_file_name="document_dict_trl.pkl",
)
custom_retriever_llama_index = setup_database(
    db_collection="chroma-db-llama_index",
    dict_file_name="document_dict_llama_index.pkl",
)
custom_retriever_openai_cookbooks = setup_database(
    db_collection="chroma-db-openai_cookbooks",
    dict_file_name="document_dict_openai_cookbooks.pkl",
)
custom_retriever_langchain = setup_database(
    db_collection="chroma-db-langchain",
    dict_file_name="document_dict_langchain.pkl",
)
# Runtime configuration, read from the environment.
# CONCURRENCY_COUNT falls back to 64 when the variable is not set.
CONCURRENCY_COUNT = int(os.environ.get("CONCURRENCY_COUNT", 64))
# MONGODB_URI is None when the variable is not set.
MONGODB_URI = os.environ.get("MONGODB_URI")

# Display labels for the selectable sources. Entries appear in the same order
# as the keys in AVAILABLE_SOURCES below.
AVAILABLE_SOURCES_UI = [
    "Transformers Docs",
    "PEFT Docs",
    "TRL Docs",
    "LlamaIndex Docs",
    "LangChain Docs",
    "OpenAI Cookbooks",
    # "Towards AI Blog",
    # "RAG Course",
]

# Internal source keys, listed in the same order as the labels above.
AVAILABLE_SOURCES = [
    "transformers",
    "peft",
    "trl",
    "llama_index",
    "langchain",
    "openai_cookbooks",
    # "towards_ai_blog",
    # "rag_course",
]
# Connect to MongoDB only when a URI is configured. Without one the module
# still loads, a warning is logged, and mongo_db is explicitly None rather
# than whatever the logging call happens to return.
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
else:
    logfire.warn("No mongodb uri found, you will not be able to save data.")
    mongo_db = None
# Names exported by `from <this module> import *`.
__all__ = [
    # Retrievers, one per documentation source
    "custom_retriever_transformers",
    "custom_retriever_peft",
    "custom_retriever_trl",
    "custom_retriever_llama_index",
    "custom_retriever_openai_cookbooks",
    "custom_retriever_langchain",
    # MongoDB handle (None when MONGODB_URI is not set)
    "mongo_db",
    # Configuration constants
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]