|
import asyncio |
|
import json |
|
import logging |
|
import os |
|
import pickle |
|
|
|
import chromadb |
|
import logfire |
|
from custom_retriever import CustomRetriever |
|
from dotenv import load_dotenv |
|
from llama_index.core import Document, SimpleKeywordTableIndex, VectorStoreIndex |
|
from llama_index.core.ingestion import IngestionPipeline |
|
from llama_index.core.node_parser import SentenceSplitter |
|
from llama_index.core.retrievers import ( |
|
KeywordTableSimpleRetriever, |
|
VectorIndexRetriever, |
|
) |
|
from llama_index.core.schema import NodeWithScore, QueryBundle |
|
from llama_index.embeddings.cohere import CohereEmbedding |
|
from llama_index.embeddings.openai import OpenAIEmbedding |
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
from utils import init_mongo_db |
|
|
|
# Populate os.environ from a local .env file before any setting is read below.
load_dotenv()

# Root logging at INFO; httpx is turned down to WARNING to hide its
# INFO-level output.
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

logfire.configure()
|
|
|
|
|
# Directory of the persisted Chroma database opened further down by
# setup_database("chroma-db-all_sources", ...).
_VECTOR_DB_PATH = "data/chroma-db-all_sources"

# Bootstrap: on a fresh checkout, fetch the prebuilt vector database from the
# Hugging Face Hub into data/ so the retriever below can be built.
if not os.path.exists(_VECTOR_DB_PATH):
    # logfire renders "{path}" from the keyword argument, producing the same
    # message text as before while also recording `path` as an attribute.
    logfire.warn(
        "Vector database does not exist at '{path}', downloading from Hugging Face Hub",
        path=_VECTOR_DB_PATH,
    )
    # Imported lazily: huggingface_hub is only needed on this download path.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-buster/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to '{path}'", path=_VECTOR_DB_PATH)
|
|
|
|
|
def create_docs(input_file: str) -> list[Document]:
    """Load a JSON Lines file into llama_index ``Document`` objects.

    Each non-blank line of *input_file* must be a JSON object with the keys
    ``doc_id``, ``content``, ``url``, ``name``, ``tokens``, ``retrieve_doc``
    and ``source``. Blank (whitespace-only) lines are skipped.

    Args:
        input_file: Path to the ``.jsonl`` file; it is decoded as UTF-8.

    Returns:
        One ``Document`` per record, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    documents: list[Document] = []
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. stray empty lines at the end of the
            # file) instead of failing inside json.loads.
            if not line.strip():
                continue
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={
                        "url": data["url"],
                        "title": data["name"],
                        "tokens": data["tokens"],
                        "retrieve_doc": data["retrieve_doc"],
                        "source": data["source"],
                    },
                    # Of the metadata above, only "url" is left visible to the LLM...
                    excluded_llm_metadata_keys=[
                        "title",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                    # ...and only "title" is left visible to the embedding model.
                    excluded_embed_metadata_keys=[
                        "url",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                )
            )
    return documents
|
|
|
|
|
def setup_database(db_collection: str, dict_file_name: str) -> CustomRetriever:
    """Open a persisted Chroma collection and wrap it in a hybrid retriever.

    Args:
        db_collection: Name of the Chroma collection. It is also the directory
            under ``data/`` that holds the persistent Chroma client's files and
            the pickled document dictionary.
        dict_file_name: File name, inside ``data/<db_collection>/``, of the
            pickled document dictionary passed to ``CustomRetriever``.

    Returns:
        A ``CustomRetriever`` combining a Cohere-embedded vector retriever
        (top 15 hits) with the pickled keyword retriever in ``"OR"`` mode.

    Raises:
        KeyError: If ``COHERE_API_KEY`` is not set in the environment.
        FileNotFoundError: If either pickle file is missing.
    """
    db_dir = os.path.join("data", db_collection)

    db = chromadb.PersistentClient(path=db_dir)
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    # input_type="search_query": this model is used here to embed retrieval
    # queries, not documents.
    embed_model = CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",
    )

    # Give the index the same embed model the retriever uses. Without it the
    # index resolves the global Settings.embed_model (an OpenAI model by
    # default, which needs its own API key) merely to be constructed.
    # NOTE(review): documents inserted through this index later would be
    # embedded with input_type="search_query"; use a document-side model then.
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        show_progress=True,
        use_async=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
        use_async=True,
    )

    # SECURITY: pickle.load can execute arbitrary code. These files may have
    # been downloaded from the Hugging Face Hub at import time; only load
    # pickles from a source you trust.
    with open(os.path.join(db_dir, dict_file_name), "rb") as f:
        document_dict = pickle.load(f)

    with open("data/keyword_retriever_sync.pkl", "rb") as f:
        keyword_retriever: KeywordTableSimpleRetriever = pickle.load(f)

    return CustomRetriever(vector_retriever, document_dict, keyword_retriever, "OR")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hybrid retriever over the "all sources" collection, built once at import
# time and re-exported via __all__.
custom_retriever_all_sources: CustomRetriever = setup_database(
    db_collection="chroma-db-all_sources",
    dict_file_name="document_dict_all_sources.pkl",
)
|
|
|
|
|
# Concurrency limit from the CONCURRENCY_COUNT env var (default 64).
# NOTE(review): presumably consumed by the app's request queue — confirm at
# the call site. A non-integer value raises ValueError at import time.
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", "64"))

# MongoDB connection string; None when the variable is unset.
MONGODB_URI = os.getenv("MONGODB_URI")

# Source identifier -> human-readable label. This is the single source of
# truth for the two lists below, which pair up by position, so they cannot
# drift out of alignment when a source is added or removed.
_SOURCE_LABELS = {
    "transformers": "Transformers Docs",
    "peft": "PEFT Docs",
    "trl": "TRL Docs",
    "llama_index": "LlamaIndex Docs",
    "langchain": "LangChain Docs",
    "openai_cookbooks": "OpenAI Cookbooks",
    "tai_blog": "Towards AI Blog",
}

# Labels shown in the UI, in the same order as AVAILABLE_SOURCES.
AVAILABLE_SOURCES_UI = list(_SOURCE_LABELS.values())

# Source identifiers, in the same order as AVAILABLE_SOURCES_UI.
AVAILABLE_SOURCES = list(_SOURCE_LABELS)
|
|
|
# MongoDB database handle, or None when MONGODB_URI is not configured (in
# which case data cannot be saved). An explicit if/else replaces the previous
# conditional expression, which assigned logfire.warn's return value as a
# side effect.
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
else:
    logfire.warn("No mongodb uri found, you will not be able to save data.")
    mongo_db = None
|
|
|
# Public API of this module: what `from <module> import *` re-exports.
__all__ = [
    # Retrieval
    "custom_retriever_all_sources",
    # Persistence
    "mongo_db",
    # Configuration
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]
|
|