Spaces:
Running
on
Zero
Running
on
Zero
"""
RAG retrieval utilities for the gprMax documentation.

Provides search and retrieval functions over the persistent ChromaDB vector
database that generate_db.py builds from the documentation.
"""
| import logging | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import json | |
| import chromadb | |
| from dataclasses import dataclass | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class SearchResult:
    """Container for a single retrieval hit.

    Attributes:
        text: The retrieved document (chunk) text.
        score: Similarity-like score; higher means more relevant.
        metadata: Metadata stored with the document; its ``source`` key is
            used for display.
    """

    text: str
    score: float
    metadata: Dict[str, Any]

    def __str__(self) -> str:
        """Return a one-line summary: score, source, and the first 100 chars of text."""
        # Tolerate records stored without metadata (ChromaDB may return None).
        source = (self.metadata or {}).get("source", "Unknown")
        return f"[Score: {self.score:.3f}] {source}: {self.text[:100]}..."
| # Removed QwenEmbeddingModel class - using ChromaDB's default embedding | |
class GprMaxRAGRetriever:
    """Retriever for the gprMax documentation RAG database.

    Wraps a persistent ChromaDB collection (created by ``generate_db.py``) and
    exposes semantic search plus helpers that format or summarize the hits.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """Open the ChromaDB database and load the documentation collection.

        Args:
            db_path: Directory of the persistent ChromaDB database, as a
                ``Path`` or a path string. Defaults to ``chroma_db`` next to
                this module.

        Raises:
            ValueError: If ``db_path`` does not exist or the collection cannot
                be loaded.
        """
        # Normalize to Path so plain strings work too.
        db_path = Path(__file__).parent / "chroma_db" if db_path is None else Path(db_path)
        if not db_path.exists():
            raise ValueError(f"Database path {db_path} does not exist. Run generate_db.py first.")
        self.db_path = db_path

        # Optional metadata file in the database directory. Keys read by this
        # class: collection_name, embedding_model, created_at, chunk_size,
        # chunk_overlap.
        metadata_path = db_path / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)
        else:
            self.metadata = {}

        self.client = chromadb.PersistentClient(path=str(db_path))
        self.collection_name = self.metadata.get("collection_name", "gprmax_docs_v1")

        print(f"[RAG] Loading collection: {self.collection_name}")
        try:
            self.collection = self.client.get_collection(self.collection_name)
            doc_count = self.collection.count()
        except Exception as e:
            print(f"[RAG] ERROR loading collection: {e}")
            raise ValueError(f"Failed to load collection {self.collection_name}: {e}") from e
        print(f"[RAG] Loaded collection: {self.collection_name} with {doc_count} documents")
        logger.info("Loaded collection: %s with %d documents", self.collection_name, doc_count)

    def search(
        self,
        query: str,
        k: int = 10,
        threshold: float = 0.0,
        filter_metadata: Optional[Dict[str, Any]] = None,
    ) -> List[SearchResult]:
        """Search for documents relevant to ``query``.

        ChromaDB embeds ``query_texts`` itself, using the collection's
        embedding function.

        Args:
            query: Search query text.
            k: Maximum number of results to return; ``k <= 0`` returns ``[]``.
            threshold: Minimum score a result must reach to be kept.
            filter_metadata: Optional ChromaDB ``where`` filter on metadata.

        Returns:
            SearchResult objects in the order ChromaDB returned them. Missing
            metadata is replaced by an empty dict.

        Raises:
            Exception: Any error from ``collection.query`` is logged and re-raised.
        """
        if k <= 0:
            return []

        try:
            results = self.collection.query(
                query_texts=[query],  # let ChromaDB compute the query embedding
                n_results=k,
                where=filter_metadata or None,
                include=["documents", "metadatas", "distances"],
            )
        except Exception as e:
            logger.exception("ChromaDB query failed: %s", e)
            raise

        # Results are batched per query text; one query was sent, so use row 0.
        documents = (results.get("documents") or [[]])[0] or []
        metadatas = (results.get("metadatas") or [[]])[0] or []
        distances = (results.get("distances") or [[]])[0] or []
        logger.info("ChromaDB query returned: %d results", len(documents))

        search_results: List[SearchResult] = []
        for doc, meta, dist in zip(documents, metadatas, distances):
            # Convert a distance into a similarity-like score (higher = closer).
            # For unit-norm embeddings under squared-L2 distance (range [0, 4])
            # this equals cosine similarity; under cosine distance (range [0, 2])
            # it lies in [0, 1].
            # NOTE(review): the collection's distance space is set when the DB
            # is built - confirm which one generate_db.py uses.
            score = 1.0 - (dist / 2.0)
            if score >= threshold:
                search_results.append(SearchResult(text=doc, score=score, metadata=meta or {}))
        return search_results

    def get_context(
        self,
        query: str,
        k: int = 3,
        max_context_length: int = 2000,
        format_as_markdown: bool = True,
    ) -> str:
        """Build a context string from the top ``k`` documents for ``query``.

        Args:
            query: Search query.
            k: Number of documents to retrieve.
            max_context_length: Budget, in characters, for the combined
                document text (markdown headers/fences are not counted).
            format_as_markdown: Wrap each document in a markdown heading and a
                code fence; otherwise emit the raw text.

        Returns:
            The joined context, or ``"No relevant documentation found."`` when
            the search returns nothing.
        """
        results = self.search(query, k=k)
        if not results:
            return "No relevant documentation found."

        context_parts = []
        total_length = 0
        for i, result in enumerate(results, 1):
            remaining = max_context_length - total_length
            if remaining <= 0:
                break
            # Truncate the last document so the text budget is not exceeded.
            text = result.text[:remaining]
            if format_as_markdown:
                source = (result.metadata or {}).get("source", "Unknown")
                context_parts.append(
                    f"### Document {i} (Source: {source}, Score: {result.score:.3f})\n"
                    f"```\n{text}\n```\n"
                )
            else:
                context_parts.append(text)
            total_length += len(text)
        return "\n".join(context_parts)

    def get_relevant_files(self, query: str, k: int = 5) -> List[str]:
        """Return the sorted, de-duplicated ``source`` values of the top-``k`` hits."""
        results = self.search(query, k=k)
        sources = {(result.metadata or {}).get("source") for result in results}
        # Skip hits with no (or an empty) source.
        return sorted(source for source in sources if source)

    def search_by_file(
        self,
        file_pattern: str,
        k: int = 10,
        *,
        batch_size: int = 1000,
    ) -> List[SearchResult]:
        """Return up to ``k`` documents whose ``source`` contains ``file_pattern``.

        Matching is a case-insensitive substring test on the ``source``
        metadata value. The collection is scanned in pages of ``batch_size``
        records (rather than only the first page) until ``k`` matches are
        found or the collection is exhausted.

        Args:
            file_pattern: Substring to look for in each document's ``source``.
            k: Maximum number of results; ``k <= 0`` returns ``[]``.
            batch_size: Records fetched per ``collection.get`` call.

        Returns:
            Matching SearchResult objects, each with score ``1.0`` (no
            similarity is computed for direct retrieval).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if k <= 0:
            return []
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        needle = file_pattern.lower()
        matches: List[SearchResult] = []
        offset = 0
        while True:
            batch = self.collection.get(
                limit=batch_size,
                offset=offset,
                include=["documents", "metadatas"],
            )
            docs = batch.get("documents") or []
            if not docs:
                break
            metas = batch.get("metadatas") or [None] * len(docs)
            for doc, meta in zip(docs, metas):
                meta = meta or {}
                # str() guards against non-string metadata values.
                if needle in str(meta.get("source", "")).lower():
                    matches.append(SearchResult(text=doc, score=1.0, metadata=meta))
                    if len(matches) >= k:
                        return matches
            if len(docs) < batch_size:
                break  # short page: no more records
            offset += batch_size
        return matches

    def get_stats(self) -> Dict[str, Any]:
        """Return database statistics, with values from ``metadata.json`` where available."""
        return {
            "total_documents": self.collection.count(),
            "database_path": str(self.db_path),
            "collection_name": self.collection_name,
            "embedding_model": self.metadata.get("embedding_model", "Unknown"),
            "created_at": self.metadata.get("created_at", "Unknown"),
            "chunk_size": self.metadata.get("chunk_size", "Unknown"),
            "chunk_overlap": self.metadata.get("chunk_overlap", "Unknown"),
        }
def create_retriever(db_path: Optional[Path] = None) -> GprMaxRAGRetriever:
    """Construct and return a GprMaxRAGRetriever.

    Args:
        db_path: Database directory; ``None`` lets the retriever pick its
            default location.

    Returns:
        A newly constructed retriever instance.
    """
    retriever = GprMaxRAGRetriever(db_path=db_path)
    return retriever
if __name__ == "__main__":
    # Example usage: pass the query as command-line words, or use the default.
    import sys

    query = " ".join(sys.argv[1:]) if len(sys.argv) > 1 else "How to create a source in gprMax?"
    print(f"Testing retriever with query: '{query}'")
    print("-" * 80)
    try:
        retriever = create_retriever()

        # Database statistics
        stats = retriever.get_stats()
        print(f"Database stats: {stats}")
        print("-" * 80)

        # Raw search results
        results = retriever.search(query, k=3)
        print(f"Found {len(results)} results:")
        for i, result in enumerate(results, 1):
            print(f"\n{i}. {result}")

        # Formatted context, as it would be handed to an LLM prompt
        print("\n" + "=" * 80)
        print("Formatted context:")
        print(retriever.get_context(query, k=3))
    except Exception as e:
        # Diagnostics go to stderr so stdout carries only the demo output.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)