# researchradar/scripts/embed_chunks.py
# Author: thearkforyou
# Commit 0c39e68 (verified): "Update chroma_store, vector_retriever,
# embed_chunks for rebuilt ChromaDB"
"""Embed all chunks into ChromaDB in memory-safe mega-batches.
Processes chunks in mega-batches (default 10K) to avoid OOM:
encode batch β†’ store in ChromaDB β†’ free memory β†’ next batch
Usage:
python scripts/embed_chunks.py # defaults
python scripts/embed_chunks.py --mega-batch 5000 # smaller mega-batches
python scripts/embed_chunks.py --encode-batch 256 # bigger GPU batches
"""
import argparse
import gc
import logging
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.config import get_config
from src.storage.chroma_store import ChromaStore
from src.storage.sqlite_db import SQLiteDB
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
def main():
    """Embed every chunk from SQLite into ChromaDB in memory-safe mega-batches.

    Each mega-batch cycle: slice chunks -> GPU-encode texts -> insert
    embeddings + metadata into ChromaDB -> free the batch's memory,
    keeping peak RSS near the size of a single mega-batch.
    """
    args = _parse_args()

    config = get_config()
    db = SQLiteDB(config.sqlite_db_path)
    chroma = ChromaStore(config.chroma_db_path)

    # Reset ChromaDB unless resuming an interrupted run.
    if not args.no_reset:
        logger.info("Resetting ChromaDB collection...")
        chroma.reset()

    # Load all chunks from DB (chunk text lives in SQLite only; ChromaDB
    # stores just ids, vectors and metadata).
    logger.info("Loading chunks from SQLite...")
    all_chunks = db.get_all_chunks()
    total = len(all_chunks)
    logger.info("Total chunks: %d", total)

    # Imported lazily: sentence_transformers pulls in torch, which is slow
    # to import and unnecessary until we actually encode.
    logger.info("Loading embedding model...")
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(config.embedding_model)
    dim = model.get_sentence_embedding_dimension()
    logger.info("Model loaded on %s (dim=%d)", model.device, dim)

    # Process in mega-batches: encode -> store -> free.
    start_time = time.time()
    total_stored = 0
    for mega_start in range(0, total, args.mega_batch):
        mega_end = min(mega_start + args.mega_batch, total)
        batch_chunks = all_chunks[mega_start:mega_end]
        batch_size = len(batch_chunks)
        logger.info("=== Mega-batch %d-%d / %d (%d chunks) ===",
                    mega_start, mega_end, total, batch_size)

        texts = [c["chunk_text"] for c in batch_chunks]

        # Encode on GPU (or whatever device the model loaded on).
        t0 = time.time()
        embeddings = model.encode(
            texts,
            batch_size=args.encode_batch,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        encode_time = time.time() - t0
        # max(..., 1e-9) guards a ZeroDivisionError on a zero-length interval
        # (tiny batches / coarse clocks).
        logger.info("Encoded %d chunks in %.1fs (%.0f chunks/sec)",
                    batch_size, encode_time,
                    batch_size / max(encode_time, 1e-9))

        ids, emb_list, metadatas = _build_chroma_payload(batch_chunks, embeddings)

        # Store in ChromaDB in sub-batches.
        t0 = time.time()
        chroma.add_embeddings(
            ids=ids,
            embeddings=emb_list,
            metadatas=metadatas,
            batch_size=args.chroma_batch,
        )
        store_time = time.time() - t0
        total_stored += batch_size
        logger.info("Stored in ChromaDB in %.1fs. Total stored: %d/%d",
                    store_time, total_stored, total)

        # Drop every per-batch reference before the next large allocation
        # so the previous batch can be collected.
        del texts, embeddings, ids, metadatas, emb_list, batch_chunks
        gc.collect()

    elapsed = time.time() - start_time
    final_count = chroma.count()
    logger.info("=== EMBEDDING COMPLETE ===")
    logger.info("Total: %d embeddings in %.1fs (%.1f chunks/sec)",
                final_count, elapsed, total / max(elapsed, 1e-9))
    logger.info("ChromaDB count: %d", final_count)


def _parse_args():
    """Parse CLI flags controlling batch sizes and reset behavior."""
    parser = argparse.ArgumentParser(description="Embed chunks into ChromaDB")
    parser.add_argument("--mega-batch", type=int, default=10000,
                        help="Chunks per mega-batch (encode+store cycle)")
    parser.add_argument("--encode-batch", type=int, default=128,
                        help="GPU encoding batch size")
    parser.add_argument("--chroma-batch", type=int, default=500,
                        help="ChromaDB insertion batch size")
    parser.add_argument("--no-reset", action="store_true",
                        help="Don't reset ChromaDB (resume mode)")
    return parser.parse_args()


def _build_chroma_payload(batch_chunks, embeddings):
    """Build parallel (ids, embeddings, metadatas) lists for ChromaDB.

    ChromaDB rejects None metadata values, and `dict.get(key, default)`
    still returns None when the key exists with a SQL NULL, so NULLs are
    coerced to type-appropriate defaults here.
    """
    # One C-level conversion for the whole batch instead of a per-row
    # `embeddings[i].tolist()` call.
    emb_list = embeddings.tolist()
    ids = []
    metadatas = []
    for chunk in batch_chunks:
        # Synthesize an id when the row has none (or a NULL one) — the
        # plain `.get(key, fallback)` form would produce the id "None".
        chunk_id = chunk.get("id")
        if chunk_id is None:
            chunk_id = f"{chunk['paper_id']}_chunk_{chunk['chunk_index']}"
        ids.append(str(chunk_id))
        metadatas.append({
            "paper_id": chunk["paper_id"],
            # `or` covers both a missing key and an explicit None/NULL.
            "chunk_type": chunk.get("chunk_type") or "unknown",
            "chunk_index": chunk.get("chunk_index") or 0,
            "year": chunk.get("year") or 0,
            "venue": chunk.get("venue") or "",
            "title": chunk.get("title") or "",
        })
    return ids, emb_list, metadatas
if __name__ == "__main__":
main()