Spaces:
Running
Running
"""Embed all chunks into ChromaDB in memory-safe mega-batches.

Processes chunks in mega-batches (default 10K) to avoid OOM:
    encode batch → store in ChromaDB → free memory → next batch

All chunk rows are still loaded from SQLite up front; only encoding and
ChromaDB insertion are batched, so peak memory is bounded by the chunk rows
plus one mega-batch of embeddings.

Usage:
    python scripts/embed_chunks.py                      # defaults
    python scripts/embed_chunks.py --mega-batch 5000    # smaller mega-batches
    python scripts/embed_chunks.py --encode-batch 256   # bigger GPU batches
"""
| import argparse | |
| import gc | |
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.config import get_config | |
| from src.storage.chroma_store import ChromaStore | |
| from src.storage.sqlite_db import SQLiteDB | |
# Script-wide logging: INFO and above, one timestamped line per record that
# also names the emitting logger and its level.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Module logger used by main() for progress reporting.
logger = logging.getLogger(name=__name__)
def _parse_args():
    """Parse the command-line options and reject non-positive batch sizes.

    Returns:
        argparse.Namespace with ``mega_batch``, ``encode_batch``,
        ``chroma_batch`` and ``no_reset`` attributes.
    """
    parser = argparse.ArgumentParser(description="Embed chunks into ChromaDB")
    parser.add_argument("--mega-batch", type=int, default=10000,
                        help="Chunks per mega-batch (encode+store cycle)")
    parser.add_argument("--encode-batch", type=int, default=128,
                        help="GPU encoding batch size")
    parser.add_argument("--chroma-batch", type=int, default=500,
                        help="ChromaDB insertion batch size")
    parser.add_argument("--no-reset", action="store_true",
                        help="Don't reset ChromaDB (resume mode)")
    args = parser.parse_args()
    # range() raises on a zero step and silently yields nothing for a negative
    # one (which would skip every chunk), so fail fast with a usage message.
    for flag, value in (("--mega-batch", args.mega_batch),
                        ("--encode-batch", args.encode_batch),
                        ("--chroma-batch", args.chroma_batch)):
        if value < 1:
            parser.error(f"{flag} must be a positive integer (got {value})")
    return args


def _value_or_default(chunk, key, default):
    """Return ``chunk[key]``, or *default* when the key is missing or ``None``.

    Unlike ``dict.get(key, default)``, this also replaces values that are
    present but ``None`` (e.g. NULL columns coming back from SQLite).
    """
    value = chunk.get(key)
    return default if value is None else value


def _chunk_id(chunk):
    """Return the ChromaDB ID for *chunk* as a string.

    Uses the row's ``id`` when it is present and not ``None``; otherwise falls
    back to ``"<paper_id>_chunk_<chunk_index>"``. The fallback is only built
    when needed, so rows that carry an ``id`` never require ``chunk_index``.

    Raises:
        KeyError: if ``id`` is absent/None and ``paper_id`` or ``chunk_index``
            is missing (a fabricated default index could collide).
    """
    chunk_id = chunk.get("id")
    if chunk_id is None:
        chunk_id = f"{chunk['paper_id']}_chunk_{chunk['chunk_index']}"
    return str(chunk_id)


def _chunk_metadata(chunk):
    """Build the ChromaDB metadata dict stored alongside *chunk*'s embedding.

    Optional fields fall back to a typed default when missing or ``None``,
    because ChromaDB metadata values are expected to be str/int/float/bool.
    """
    return {
        "paper_id": chunk["paper_id"],
        "chunk_type": _value_or_default(chunk, "chunk_type", "unknown"),
        "chunk_index": _value_or_default(chunk, "chunk_index", 0),
        "year": _value_or_default(chunk, "year", 0),
        "venue": _value_or_default(chunk, "venue", ""),
        "title": _value_or_default(chunk, "title", ""),
    }


def _rate(count, seconds):
    """Return ``count / seconds``, or 0.0 when no measurable time elapsed."""
    return count / seconds if seconds > 0 else 0.0


def main():
    """Embed every chunk from SQLite and store the vectors in ChromaDB.

    Chunks are processed in mega-batches of ``--mega-batch`` rows. Each batch
    is encoded with the configured SentenceTransformer model (normalized
    embeddings, ``--encode-batch`` rows per encode call), written to ChromaDB
    via ``ChromaStore.add_embeddings`` in sub-batches of ``--chroma-batch``,
    and then released before the next batch. Only IDs, embeddings and
    metadata go to ChromaDB; chunk text stays in SQLite.
    """
    args = _parse_args()
    config = get_config()
    db = SQLiteDB(config.sqlite_db_path)
    chroma = ChromaStore(config.chroma_db_path)

    # Reset ChromaDB unless resuming.
    # NOTE(review): --no-reset only skips this reset; every chunk is still
    # re-encoded and re-added below. Whether re-adding an existing ID errors,
    # is ignored, or overwrites depends on ChromaStore.add_embeddings.
    if not args.no_reset:
        logger.info("Resetting ChromaDB collection...")
        chroma.reset()

    # Load all chunks from DB
    logger.info("Loading chunks from SQLite...")
    all_chunks = db.get_all_chunks()
    total = len(all_chunks)
    logger.info("Total chunks: %d", total)
    if total == 0:
        # Nothing to do: skip loading the embedding model entirely.
        logger.warning("No chunks found in SQLite; nothing to embed.")
        return

    # Load model
    logger.info("Loading embedding model...")
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(config.embedding_model)
    # %s rather than %d: the dimension can be None for some models.
    dim = model.get_sentence_embedding_dimension()
    logger.info("Model loaded on %s (dim=%s)", model.device, dim)

    # Process in mega-batches; perf_counter is monotonic, so durations
    # cannot go negative if the wall clock is adjusted mid-run.
    start_time = time.perf_counter()
    total_stored = 0
    for mega_start in range(0, total, args.mega_batch):
        mega_end = min(mega_start + args.mega_batch, total)
        batch_chunks = all_chunks[mega_start:mega_end]
        batch_size = len(batch_chunks)
        logger.info("=== Mega-batch %d-%d / %d (%d chunks) ===",
                    mega_start, mega_end, total, batch_size)

        texts = [c["chunk_text"] for c in batch_chunks]

        # Encode (progress bar shown per mega-batch)
        t0 = time.perf_counter()
        embeddings = model.encode(
            texts,
            batch_size=args.encode_batch,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        encode_time = time.perf_counter() - t0
        logger.info("Encoded %d chunks in %.1fs (%.0f chunks/sec)",
                    batch_size, encode_time, _rate(batch_size, encode_time))

        # Prepare ChromaDB data (no documents — text lives in SQLite only)
        ids = [_chunk_id(c) for c in batch_chunks]
        metadatas = [_chunk_metadata(c) for c in batch_chunks]
        # One array -> nested-list conversion instead of one call per row.
        emb_list = embeddings.tolist()

        # Store in ChromaDB in sub-batches
        t0 = time.perf_counter()
        chroma.add_embeddings(
            ids=ids,
            embeddings=emb_list,
            metadatas=metadatas,
            batch_size=args.chroma_batch,
        )
        store_time = time.perf_counter() - t0
        total_stored += batch_size
        logger.info("Stored in ChromaDB in %.1fs. Total stored: %d/%d",
                    store_time, total_stored, total)

        # Drop this batch's large objects so they can be reclaimed before the
        # next batch is encoded (all_chunks itself stays resident).
        del texts, embeddings, ids, metadatas, emb_list, batch_chunks
        gc.collect()

    elapsed = time.perf_counter() - start_time
    final_count = chroma.count()
    logger.info("=== EMBEDDING COMPLETE ===")
    logger.info("Total: %d embeddings in %.1fs (%.1f chunks/sec)",
                final_count, elapsed, _rate(total_stored, elapsed))
    logger.info("ChromaDB count: %d", final_count)


if __name__ == "__main__":
    main()