Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Embedding generation CLI script for the RAG pipeline. | |
| This script generates embeddings from text chunks and builds search indexes | |
| for the pythermalcomfort RAG chatbot. It supports: | |
| - Reading chunks from JSONL files | |
| - Generating embeddings using BGE encoder with GPU acceleration | |
| - Building FAISS indexes for dense retrieval | |
| - Building BM25 indexes for sparse retrieval | |
| - Publishing artifacts to HuggingFace (optional) | |
| Usage: | |
| # Basic embedding generation | |
| poetry run python scripts/embed.py data/chunks/chunks.jsonl data/embeddings/ | |
| # With HuggingFace publishing | |
| poetry run python scripts/embed.py data/chunks/chunks.jsonl data/embeddings/ \ | |
| --publish | |
| # Custom batch size and model | |
| poetry run python scripts/embed.py data/chunks/chunks.jsonl data/embeddings/ \ | |
| --batch-size 64 --model BAAI/bge-base-en-v1.5 | |
| Output Files: | |
| {output_dir}/ | |
| ├── embeddings.parquet # Embeddings with chunk_id mapping | |
| ├── metadata.json # Model metadata | |
| ├── faiss_index.bin # FAISS index for dense retrieval | |
| ├── faiss_index.bin.ids.json # Chunk ID mapping for FAISS | |
| ├── bm25_index.pkl # BM25 index for sparse retrieval | |
| └── chunks.parquet # Chunks in parquet format (if --publish) | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| # ============================================================================= | |
| # Environment Variable Loading | |
| # ============================================================================= | |
| # Load environment variables from .env file at script startup. | |
| # This ensures HF_TOKEN and other secrets are available before they're needed. | |
| # The .env file should be in the project root directory. | |
| # ============================================================================= | |
| from dotenv import load_dotenv | |
# The project root is the directory above scripts/. If it contains a .env
# file, its variables (e.g. HF_TOKEN) are loaded into the process environment.
_PROJECT_ROOT: Path = Path(__file__).parent.parent
_ENV_FILE: Path = _PROJECT_ROOT.joinpath(".env")
if _ENV_FILE.exists():
    # A missing .env is not an error; secrets may come from the real env.
    load_dotenv(_ENV_FILE)
| # Rich is lightweight - import at module level for progress display | |
| from rich.console import Console | |
| from rich.progress import ( | |
| BarColumn, | |
| MofNCompleteColumn, | |
| Progress, | |
| SpinnerColumn, | |
| TaskID, | |
| TaskProgressColumn, | |
| TextColumn, | |
| TimeElapsedColumn, | |
| TimeRemainingColumn, | |
| ) | |
| from rich.table import Table | |
| if TYPE_CHECKING: | |
| from rag_chatbot.chunking.models import Chunk | |
| from rag_chatbot.embeddings import ( | |
| BGEEncoder, | |
| EmbeddingRecord, | |
| ) | |
| # ============================================================================= | |
| # Module Constants | |
| # ============================================================================= | |
# Model used when --model is not supplied: the small English variant of
# BAAI General Embedding (BGE).
DEFAULT_MODEL: str = "BAAI/bge-small-en-v1.5"
# Texts encoded per batch when --batch-size is not supplied.
DEFAULT_BATCH_SIZE: int = 32
# Vector width produced by bge-small-en-v1.5.
BGE_SMALL_DIM: int = 384
| # ============================================================================= | |
| # Logging Configuration | |
| # ============================================================================= | |
# Diagnostic messages are written to stderr at INFO level, each prefixed
# with a timestamp and level name.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stderr)],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Single Rich console shared by every styled message in this script.
console = Console()
| # ============================================================================= | |
| # Helper Functions | |
| # ============================================================================= | |
def get_device_info() -> tuple[str, str]:
    """Describe the compute device torch would use by default.

    ``torch`` is imported inside the function so that the heavy import is
    only paid when device information is actually requested.

    Returns:
    -------
        A ``(device_type, device_name)`` pair:
        - ``("cuda", <name of CUDA device 0>)`` when CUDA is available;
        - ``("cpu", "CPU")`` otherwise.

    Example:
    -------
        >>> device_type, device_name = get_device_info()
        >>> device_type in {"cuda", "cpu"}
        True

    """
    import torch  # type: ignore[import-not-found]

    if not torch.cuda.is_available():
        return "cpu", "CPU"
    # Only the first visible GPU (index 0) is reported.
    return "cuda", torch.cuda.get_device_name(0)
def load_chunks_from_jsonl(input_path: Path) -> list[Chunk]:
    """Load and text-normalize chunks from a JSONL file.

    Each non-blank line is parsed as JSON. If the record has a ``"text"``
    field, that text is passed through ``TextNormalizer.normalize`` (with
    ``is_heading=False``) to fix common PDF extraction artifacts. The record
    is then turned into a ``Chunk``. A line that is not valid JSON, or that
    fails to become a ``Chunk``, is skipped with a warning, so a few bad
    records do not abort the whole run.

    Args:
    ----
        input_path: Path to the chunks.jsonl file.

    Returns:
    -------
        List of Chunk objects parsed from the file, in file order.

    Raises:
    ------
        FileNotFoundError: If ``input_path`` does not exist.
        ValueError: If no valid chunk could be read (the file is empty,
            blank, or every line was skipped).

    Example:
    -------
        >>> chunks = load_chunks_from_jsonl(Path("data/chunks/chunks.jsonl"))
        >>> len(chunks)
        1500

    """
    # Check the path before the heavy project import, so a mistyped path
    # fails fast with the documented FileNotFoundError.
    if not input_path.exists():
        msg = f"Input file not found: {input_path}"
        raise FileNotFoundError(msg)

    # Lazy import to avoid loading heavy dependencies at module level
    from rag_chatbot.chunking.models import Chunk, TextNormalizer

    chunks: list[Chunk] = []
    skipped = 0
    normalizer = TextNormalizer()

    # "utf-8-sig" decodes plain UTF-8 identically, but also strips a leading
    # byte-order mark. json.loads rejects a BOM, which would otherwise
    # silently drop the first record.
    with open(input_path, encoding="utf-8-sig") as f:
        for line_num, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # blank lines between records are allowed
            try:
                data = json.loads(line)
                # Fix OCR/PDF extraction artifacts before validation
                if "text" in data:
                    data["text"] = normalizer.normalize(data["text"], is_heading=False)
                chunks.append(Chunk(**data))
            except json.JSONDecodeError as exc:
                skipped += 1
                logger.warning("Invalid JSON at line %d: %s", line_num, exc)
            except Exception as exc:  # noqa: BLE001 - best-effort per line
                skipped += 1
                logger.warning("Error parsing chunk at line %d: %s", line_num, exc)

    # One summary line so skipped records are not lost among per-line logs
    if skipped:
        logger.warning("Skipped %d invalid line(s) in %s", skipped, input_path)

    if not chunks:
        msg = f"No valid chunks found in {input_path}"
        raise ValueError(msg)
    return chunks
def create_embedding_records(
    chunks: list[Chunk],
    encoder: BGEEncoder,
    batch_size: int,
    progress: Progress,
    task_id: TaskID,
) -> list[EmbeddingRecord]:
    """Embed chunk texts and pair each vector with its chunk's identifiers.

    All chunk texts are passed to ``encoder.encode`` in one call. Its own
    progress output is disabled, and a callback moves the given Rich task
    instead. Each resulting vector becomes an ``EmbeddingRecord`` carrying
    the source chunk's ``chunk_id`` and ``chunk_hash``.

    Args:
    ----
        chunks: List of Chunk objects to embed.
        encoder: BGEEncoder instance for generating embeddings.
        batch_size: Number of chunks to process per batch.
        progress: Rich Progress instance for progress tracking.
        task_id: Task ID for the progress bar.

    Returns:
    -------
        List of EmbeddingRecord objects, one per chunk, in input order.

    Raises:
    ------
        RuntimeError: If the encoder returns a different number of vectors
            than there are chunks. Records are paired by position, so a
            mismatch would otherwise misalign or drop data.

    Example:
    -------
        >>> records = create_embedding_records(chunks, encoder, 32, progress, task_id)
        >>> len(records) == len(chunks)
        True

    """
    # Lazy import
    from rag_chatbot.embeddings import EmbeddingRecord

    texts = [chunk.text for chunk in chunks]

    def progress_callback(current_batch: int, _total_batches: int) -> None:
        """Move the progress bar to the batch count reported by the encoder."""
        progress.update(task_id, completed=current_batch)

    embeddings = encoder.encode(
        texts=texts,
        batch_size=batch_size,
        show_progress=False,  # We use our own progress bar
        progress_callback=progress_callback,
    )

    # Vectors are matched to chunks by position. Fail loudly instead of
    # raising IndexError (too few) or silently ignoring extras (too many).
    if len(embeddings) != len(chunks):
        msg = (
            f"Encoder returned {len(embeddings)} embeddings for "
            f"{len(chunks)} chunks"
        )
        raise RuntimeError(msg)

    return [
        EmbeddingRecord(
            chunk_id=chunk.chunk_id,
            chunk_hash=chunk.chunk_hash,
            embedding=vector.tolist(),
        )
        for chunk, vector in zip(chunks, embeddings)
    ]
def build_indexes(
    output_dir: Path,
    chunks: list[Chunk],
    progress: Progress,
) -> tuple[float, float]:
    """Build the dense (FAISS) and sparse (BM25) retrieval indexes.

    The FAISS index is built from ``output_dir / "embeddings.parquet"`` and
    saved as ``faiss_index.bin``. The BM25 index is built from the chunks and
    saved as ``bm25_index.pkl``. Both files are written inside ``output_dir``.
    Each phase gets its own single-step task on the shared progress bar.

    Args:
    ----
        output_dir: Directory holding embeddings.parquet; indexes are saved here.
        chunks: Chunk objects to index with BM25.
        progress: Rich Progress instance used to show both build phases.

    Returns:
    -------
        ``(faiss_seconds, bm25_seconds)``: wall-clock build time of each
        index, including saving it to disk.

    Example:
    -------
        >>> faiss_time, bm25_time = build_indexes(output_dir, chunks, progress)
        >>> print(f"FAISS: {faiss_time:.2f}s, BM25: {bm25_time:.2f}s")

    """
    from rag_chatbot.embeddings import BM25IndexBuilder, FAISSIndexBuilder

    # --- Dense index: vectors come from the saved embeddings parquet ---------
    faiss_task = progress.add_task("[cyan]Building FAISS index...", total=1)
    t0 = time.perf_counter()
    dense_builder = FAISSIndexBuilder()
    dense_index = dense_builder.build_from_parquet(output_dir / "embeddings.parquet")
    dense_builder.save_index(dense_index, output_dir / "faiss_index.bin")
    faiss_seconds = time.perf_counter() - t0
    progress.update(faiss_task, completed=1)

    # --- Sparse index: built directly from the in-memory chunks --------------
    bm25_task = progress.add_task("[cyan]Building BM25 index...", total=1)
    t0 = time.perf_counter()
    sparse_builder = BM25IndexBuilder()
    sparse_index, indexed_ids = sparse_builder.build_from_chunks(chunks)
    sparse_builder.save_index(
        sparse_index, indexed_ids, output_dir / "bm25_index.pkl"
    )
    bm25_seconds = time.perf_counter() - t0
    progress.update(bm25_task, completed=1)

    return faiss_seconds, bm25_seconds
def publish_to_huggingface(
    output_dir: Path,
    chunks: list[Chunk],
    model_name: str,
    embedding_dim: int,
    progress: Progress,
) -> str:
    """Upload the generated artifacts to the HuggingFace dataset repository.

    Four steps are run in order, and a 4-step progress task advances after each:

    1. write the chunks to parquet inside ``output_dir``;
    2. build a source manifest (no source files are tracked here, and the
       embedding count is taken to equal the chunk count);
    3. authenticate with HuggingFace;
    4. publish the contents of ``output_dir``.

    Args:
    ----
        output_dir: Directory containing artifacts to publish.
        chunks: Chunk objects to write to chunks.parquet.
        model_name: Name of the embedding model used.
        embedding_dim: Dimension of the embeddings.
        progress: Rich Progress instance for progress tracking.

    Returns:
    -------
        URL of the published dataset, as returned by the publisher.

    Raises:
    ------
        Errors are raised by the publisher, not directly here. The intended
        contract is:
        ValueError: If HF_TOKEN is not set.
        RuntimeError: If publishing fails.

    Example:
    -------
        >>> publish_to_huggingface(
        ...     output_dir, chunks, "BAAI/bge-small-en-v1.5", 384, progress
        ... )
        'https://huggingface.co/datasets/sadickam/pytherm_index'

    """
    from rag_chatbot.embeddings import HuggingFacePublisher, PublisherConfig

    task = progress.add_task("[cyan]Publishing to HuggingFace...", total=4)

    def _step_done() -> None:
        """Advance the publishing task by one of its four steps."""
        progress.update(task, advance=1)

    publisher = HuggingFacePublisher(PublisherConfig())

    # 1. chunks.parquet
    publisher.save_chunks_parquet(chunks, output_dir)
    _step_done()

    # 2. manifest: every chunk has exactly one embedding
    chunk_count = len(chunks)
    manifest = publisher.generate_source_manifest(
        source_files=[],  # source documents are not known at this stage
        total_chunks=chunk_count,
        total_embeddings=chunk_count,
    )
    _step_done()

    # 3. credentials
    publisher.authenticate()
    _step_done()

    # 4. upload
    url = publisher.publish(
        artifacts_dir=output_dir,
        manifest=manifest,
        model_name=model_name,
        embedding_dimension=embedding_dim,
    )
    _step_done()
    return url
def print_statistics(  # noqa: PLR0913
    total_chunks: int,
    total_time: float,
    embedding_time: float,
    faiss_time: float,
    bm25_time: float,
    device_name: str,
    model_name: str,
    output_dir: Path,
    dataset_url: str | None = None,
) -> None:
    """Render a Rich table summarizing the embedding run.

    The table contains, in order:

    - chunk count, model and device;
    - embedding, FAISS, BM25 and total times;
    - throughput (chunks per second of embedding time; 0 if that time is 0);
    - sizes of whichever output files exist in ``output_dir``, computed as
      bytes / 1024**2 and labelled "MB";
    - the published dataset URL, if one is given.

    Args:
    ----
        total_chunks: Number of chunks embedded.
        total_time: Total elapsed time in seconds.
        embedding_time: Time spent on embedding generation.
        faiss_time: Time spent building the FAISS index.
        bm25_time: Time spent building the BM25 index.
        device_name: Name of the device used (GPU name or "CPU").
        model_name: Name of the embedding model.
        output_dir: Directory where outputs were saved.
        dataset_url: Optional URL of the published HuggingFace dataset.

    """
    chunks_per_sec = total_chunks / embedding_time if embedding_time > 0 else 0
    bytes_per_mb = 1024 * 1024
    spacer = ("", "")

    table = Table(
        title="Embedding Statistics",
        show_header=True,
        header_style="bold cyan",
    )
    table.add_column("Metric", style="dim", width=25)
    table.add_column("Value", justify="right")

    rows: list[tuple[str, str]] = [
        ("Total Chunks", f"{total_chunks:,}"),
        ("Model", model_name),
        ("Device", device_name),
        spacer,
        ("Embedding Time", f"{embedding_time:.2f}s"),
        ("FAISS Build Time", f"{faiss_time:.2f}s"),
        ("BM25 Build Time", f"{bm25_time:.2f}s"),
        ("Total Time", f"{total_time:.2f}s"),
        spacer,
        ("Throughput", f"{chunks_per_sec:.1f} chunks/sec"),
    ]

    # Only report sizes for artifacts that were actually written
    artifacts = (
        ("Embeddings Size", "embeddings.parquet"),
        ("FAISS Index Size", "faiss_index.bin"),
        ("BM25 Index Size", "bm25_index.pkl"),
    )
    for label, filename in artifacts:
        artifact = output_dir / filename
        if artifact.exists():
            rows.append((label, f"{artifact.stat().st_size / bytes_per_mb:.2f} MB"))

    if dataset_url:
        rows.extend([spacer, ("Published URL", dataset_url)])

    for metric, value in rows:
        table.add_row(metric, value)

    console.print()
    console.print(table)
    console.print()
| def parse_args() -> argparse.Namespace: | |
| """Parse command line arguments. | |
| Returns | |
| ------- | |
| Parsed argument namespace with input_path, output_dir, and options. | |
| """ | |
| parser = argparse.ArgumentParser( | |
| description="Generate embeddings and build indexes for RAG retrieval.", | |
| formatter_class=argparse.RawDescriptionHelpFormatter, | |
| epilog=""" | |
| Examples: | |
| poetry run python scripts/embed.py chunks.jsonl output/ | |
| poetry run python scripts/embed.py chunks.jsonl output/ --publish | |
| poetry run python scripts/embed.py chunks.jsonl output/ --batch-size 64 | |
| """, | |
| ) | |
| parser.add_argument( | |
| "input_path", | |
| type=Path, | |
| help="Path to chunks.jsonl file containing text chunks", | |
| ) | |
| parser.add_argument( | |
| "output_dir", | |
| type=Path, | |
| help="Directory to save embeddings and indexes", | |
| ) | |
| parser.add_argument( | |
| "--publish", | |
| action="store_true", | |
| help="Publish artifacts to HuggingFace after embedding", | |
| ) | |
| parser.add_argument( | |
| "--batch-size", | |
| type=int, | |
| default=DEFAULT_BATCH_SIZE, | |
| help=f"Batch size for embedding generation (default: {DEFAULT_BATCH_SIZE})", | |
| ) | |
| parser.add_argument( | |
| "--model", | |
| type=str, | |
| default=DEFAULT_MODEL, | |
| help=f"Embedding model name (default: {DEFAULT_MODEL})", | |
| ) | |
| parser.add_argument( | |
| "--device", | |
| type=str, | |
| default="auto", | |
| choices=["cpu", "cuda", "auto"], | |
| help="Device to use for embedding (default: auto)", | |
| ) | |
| return parser.parse_args() | |
def main() -> int:
    """Run the embedding generation pipeline.

    This is the main entry point for the embed.py CLI script. It orchestrates:
    1. Reporting the compute device, model and batch size
    2. Loading (and text-normalizing) chunks from JSONL
    3. Initializing the encoder and embedding storage
    4. Generating embeddings with progress tracking
    5. Saving embeddings to parquet storage
    6. Building FAISS and BM25 indexes
    7. Optionally publishing to HuggingFace
    8. Printing a statistics table

    Returns
    -------
        Exit code (0 for success, 1 for error).

    """
    import math

    args = parse_args()
    # Total wall-clock time, including model loading and index building
    start_time = time.perf_counter()

    try:
        # =================================================================
        # Step 1: Display GPU/Device Information
        # =================================================================
        console.print("\n[bold cyan]Embedding Generation Pipeline[/bold cyan]\n")
        device_type, device_name = get_device_info()
        # "auto" leaves the device choice to the encoder (device=None)
        device = args.device if args.device != "auto" else None
        # get_device_info() reports torch's default device. Keep the name
        # shown to the user consistent with an explicit --device override.
        if args.device == "cpu":
            device_name = "CPU"
        elif args.device == "cuda" and device_type != "cuda":
            console.print(
                "[yellow]Warning: --device cuda was requested but CUDA is "
                "not available.[/yellow]"
            )

        console.print(f"[green]Device:[/green] {device_name}")
        console.print(f"[green]Model:[/green] {args.model}")
        console.print(f"[green]Batch Size:[/green] {args.batch_size}")
        console.print()

        # The batch size is used as a divisor below. Report a clear error
        # (exit 1) instead of an unexpected ZeroDivisionError.
        if args.batch_size < 1:
            msg = f"--batch-size must be a positive integer, got {args.batch_size}"
            raise ValueError(msg)

        # =================================================================
        # Step 2: Load Chunks from JSONL
        # =================================================================
        console.print(f"[cyan]Loading chunks from {args.input_path}...[/cyan]")
        # An empty or fully invalid file raises ValueError (handled below),
        # so `chunks` is non-empty from here on.
        chunks = load_chunks_from_jsonl(args.input_path)
        console.print(f"[green]Loaded {len(chunks):,} chunks[/green]\n")

        # =================================================================
        # Step 3: Initialize Encoder and Storage
        # =================================================================
        # Lazy imports for heavy dependencies
        from rag_chatbot.embeddings import (
            BGEEncoder,
            EmbeddingBatch,
            EmbeddingStorage,
        )

        args.output_dir.mkdir(parents=True, exist_ok=True)

        encoder = BGEEncoder(
            model_name=args.model,
            device=device,
            normalize_text=False,  # Already normalized during chunk loading
        )
        storage = EmbeddingStorage(args.output_dir)

        # =================================================================
        # Step 4: Generate Embeddings with Progress Tracking
        # =================================================================
        total_batches = math.ceil(len(chunks) / args.batch_size)

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            embed_task = progress.add_task(
                "[cyan]Embedding chunks...",
                total=total_batches,
            )

            embedding_start = time.perf_counter()
            records = create_embedding_records(
                chunks=chunks,
                encoder=encoder,
                batch_size=args.batch_size,
                progress=progress,
                task_id=embed_task,
            )
            embedding_time = time.perf_counter() - embedding_start

            # Ensure the bar ends full regardless of callback granularity
            progress.update(embed_task, completed=total_batches)

            # =============================================================
            # Step 5: Save Embeddings to Storage
            # =============================================================
            save_task = progress.add_task("[cyan]Saving embeddings...", total=1)
            batch = EmbeddingBatch(
                model_name=args.model,
                dimension=encoder.embedding_dim,
                dtype="float16",
                records=records,
            )
            storage.save(batch)
            progress.update(save_task, completed=1)

            # =============================================================
            # Step 6: Build FAISS and BM25 Indexes
            # =============================================================
            faiss_time, bm25_time = build_indexes(
                output_dir=args.output_dir,
                chunks=chunks,
                progress=progress,
            )

            # =============================================================
            # Step 7: Publish to HuggingFace (if requested)
            # =============================================================
            dataset_url: str | None = None
            if args.publish:
                dataset_url = publish_to_huggingface(
                    output_dir=args.output_dir,
                    chunks=chunks,
                    model_name=args.model,
                    embedding_dim=encoder.embedding_dim,
                    progress=progress,
                )

        # =================================================================
        # Step 8: Print Statistics (after the live progress display stops)
        # =================================================================
        total_time = time.perf_counter() - start_time
        print_statistics(
            total_chunks=len(chunks),
            total_time=total_time,
            embedding_time=embedding_time,
            faiss_time=faiss_time,
            bm25_time=bm25_time,
            device_name=device_name,
            model_name=args.model,
            output_dir=args.output_dir,
            dataset_url=dataset_url,
        )

    except (FileNotFoundError, ValueError) as exc:
        # Expected, user-facing failures: bad path, no valid chunks, bad input
        console.print(f"[bold red]Error:[/bold red] {exc}")
        return 1
    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted by user[/yellow]")
        return 1
    except Exception as exc:
        console.print(f"[bold red]Unexpected error:[/bold red] {exc}")
        logger.exception("Unexpected error during embedding generation")
        return 1
    else:
        console.print("[bold green]Embedding generation complete![/bold green]\n")
        return 0
if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())