import json
import logging
import os
import pickle

import chromadb
import logfire
from custom_retriever import CustomRetriever
from dotenv import load_dotenv
from llama_index.core import Document, SimpleKeywordTableIndex, VectorStoreIndex
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import (
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from utils import init_mongo_db
load_dotenv()

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)

logfire.configure()

if not os.path.exists("data/chroma-db-all_sources"):
    # Download the vector database from the Hugging Face Hub if it doesn't exist locally
    # https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
    logfire.warn(
        "Vector database does not exist at 'data/chroma-db-all_sources', downloading from Hugging Face Hub"
    )
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-buster/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to 'data/chroma-db-all_sources'")


def create_docs(input_file: str) -> list[Document]:
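    """Build llama_index Documents from a JSONL file (one JSON object per line).

    Metadata keys that would add noise are hidden from the LLM prompt and from
    the embedded text via the excluded_*_metadata_keys lists.
    """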
    with open(input_file, "r") as f:
        documents = []
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={  # type: ignore
                        "url": data["url"],
                        "title": data["name"],
                        "tokens": data["tokens"],
                        "retrieve_doc": data["retrieve_doc"],
                        "source": data["source"],
                    },
                    excluded_llm_metadata_keys=[
                        "title",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                    excluded_embed_metadata_keys=[
                        "url",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                )
            )
    return documents


def setup_database(db_collection: str, dict_file_name: str) -> CustomRetriever:
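    """Build a hybrid (vector + keyword) retriever over a persisted Chroma collection.

    Loads the Chroma vector store, a pickled doc_id -> Document mapping, and a
    pre-built keyword retriever, then combines dense and keyword retrieval in a
    CustomRetriever.
    """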
    db = chromadb.PersistentClient(path=f"data/{db_collection}")
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
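    # Query-side embeddings: Cohere's "search_query" input type embeds incoming
    # queries; the stored vectors were presumably built with "search_document".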
    embed_model = CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",
    )
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        show_progress=True,
        use_async=True,
    )
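    # Dense retrieval: top 15 nodes from the Chroma store, with queries embedded
    # by the Cohere model above.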
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
        use_async=True,
    )
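    # Pickled doc_id -> Document mapping; presumably used by CustomRetriever to
    # return a full document when a node's "retrieve_doc" flag is set.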
with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
document_dict = pickle.load(f)
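    # Pre-built keyword retriever; the commented-out block below documents how
    # this pickle was generated.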
with open("data/keyword_retriever_sync.pkl", "rb") as f:
keyword_retriever: KeywordTableSimpleRetriever = pickle.load(f)
    # One-time build: creating the keyword index and retriever.
    # logfire.info("Creating nodes from documents")
    # documents = create_docs("data/all_sources_data.jsonl")
    # pipeline = IngestionPipeline(
    #     transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)]
    # )
    # all_nodes = pipeline.run(documents=documents, show_progress=True)
    # # with open("data/all_nodes.pkl", "wb") as f:
    # #     pickle.dump(all_nodes, f)
    # # all_nodes = pickle.load(open("data/all_nodes.pkl", "rb"))
    # logfire.info(f"Number of nodes: {len(all_nodes)}")
    # keyword_index = SimpleKeywordTableIndex(
    #     nodes=all_nodes, max_keywords_per_chunk=10, show_progress=True, use_async=False
    # )
    # # with open("data/keyword_index.pkl", "wb") as f:
    # #     pickle.dump(keyword_index, f)
    # # keyword_index = pickle.load(open("data/keyword_index.pkl", "rb"))
    # logfire.info("Creating keyword retriever")
    # keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
    # with open("data/keyword_retriever_sync.pkl", "wb") as f:
    #     pickle.dump(keyword_retriever, f)
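    # "OR" mode: presumably the union of vector and keyword results (hybrid retrieval).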
    return CustomRetriever(vector_retriever, document_dict, keyword_retriever, "OR")


# Setup retrievers
# custom_retriever_transformers: CustomRetriever = setup_database(
# "chroma-db-transformers",
# "document_dict_transformers.pkl",
# )
# custom_retriever_peft: CustomRetriever = setup_database(
# "chroma-db-peft", "document_dict_peft.pkl"
# )
# custom_retriever_trl: CustomRetriever = setup_database(
# "chroma-db-trl", "document_dict_trl.pkl"
# )
# custom_retriever_llama_index: CustomRetriever = setup_database(
# "chroma-db-llama_index",
# "document_dict_llama_index.pkl",
# )
# custom_retriever_openai_cookbooks: CustomRetriever = setup_database(
# "chroma-db-openai_cookbooks",
# "document_dict_openai_cookbooks.pkl",
# )
# custom_retriever_langchain: CustomRetriever = setup_database(
# "chroma-db-langchain",
# "document_dict_langchain.pkl",
# )
custom_retriever_all_sources: CustomRetriever = setup_database(
    "chroma-db-all_sources",
    "document_dict_all_sources.pkl",
)
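
# Example usage (a sketch, assuming CustomRetriever subclasses llama_index's
# BaseRetriever; method names may differ otherwise):
#
#   nodes = custom_retriever_all_sources.retrieve("What is LoRA fine-tuning?")
#   # or, inside an async context:
#   nodes = await custom_retriever_all_sources.aretrieve("What is LoRA fine-tuning?")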

# Constants
# Max concurrent requests; presumably consumed by the serving layer's queue.
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", "64"))
MONGODB_URI = os.getenv("MONGODB_URI")

AVAILABLE_SOURCES_UI = [
    "Transformers Docs",
    "PEFT Docs",
    "TRL Docs",
    "LlamaIndex Docs",
    "LangChain Docs",
    "OpenAI Cookbooks",
    "Towards AI Blog",
    # "All Sources",
    # "RAG Course",
]

AVAILABLE_SOURCES = [
    "transformers",
    "peft",
    "trl",
    "llama_index",
    "langchain",
    "openai_cookbooks",
    "tai_blog",
    # "all_sources",
    # "rag_course",
]

if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
else:
    mongo_db = None
    logfire.warn("No MONGODB_URI found; conversation data will not be saved.")

__all__ = [
    # "custom_retriever_transformers",
    # "custom_retriever_peft",
    # "custom_retriever_trl",
    # "custom_retriever_llama_index",
    # "custom_retriever_openai_cookbooks",
    # "custom_retriever_langchain",
    "custom_retriever_all_sources",
    "mongo_db",
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]