#!/usr/bin/env python3
# /// script
# dependencies = [
#     "langchain_community",
#     "chromadb",
#     "huggingface_hub",
#     "sentence_transformers",
#     "pydantic"
# ]
# ///
"""
Query interface for Arista AVD documentation vector database.
Provides search and retrieval capabilities.
"""

import argparse
import json
from typing import List, Dict, Any, Optional
from pathlib import Path
import logging

from pydantic import BaseModel, Field
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.schema import Document

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class EmbeddingConfig(BaseModel):
    """Configuration for embeddings.

    Controls which HuggingFace sentence-transformer model is used and
    how it is run when embedding query text.
    """
    model_name: str = Field(default="all-MiniLM-L6-v2", description="The name of the HuggingFace model to use")
    device: str = Field(default="cpu", description="Device to use for embedding generation (cpu or cuda)")
    normalize_embeddings: bool = Field(default=True, description="Whether to normalize embeddings")


class AristaDocumentQuery(BaseModel):
    """Query interface for Arista AVD documentation.

    Wraps a persisted Chroma vector store and exposes similarity
    search, metadata-filtered search, result formatting, and JSON
    export. The store must already exist on disk (created by a
    separate ingestion step) at ``persist_directory``.
    """
    persist_directory: str = Field(default="./chroma_db", description="Directory containing the vector store")
    embedding_config: EmbeddingConfig = Field(default_factory=EmbeddingConfig, description="Configuration for embeddings")

    # Runtime-only objects, created in __init__ and excluded from
    # pydantic serialization (they are not part of the model's data).
    embeddings: Any = Field(default=None, exclude=True)
    vector_store: Any = Field(default=None, exclude=True)

    class Config:
        # Required so pydantic accepts the non-pydantic HuggingFaceEmbeddings
        # and Chroma instances stored on the model.
        arbitrary_types_allowed = True

    def __init__(self, **data):
        """Validate config, build the embedding function, and load the store.

        Raises:
            Exception: Propagated from Chroma if the persisted store
                cannot be loaded.
        """
        super().__init__(**data)
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_config.model_name,
            model_kwargs={'device': self.embedding_config.device},
            encode_kwargs={'normalize_embeddings': self.embedding_config.normalize_embeddings}
        )
        self.vector_store = self._load_vector_store()

    def _load_vector_store(self) -> Chroma:
        """Load the existing vector store from ``persist_directory``.

        Returns:
            The loaded Chroma vector store.

        Raises:
            Exception: Re-raised after logging if loading fails
                (e.g. the directory does not exist or is corrupt).
        """
        try:
            vector_store = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings
            )
            logger.info(f"Loaded vector store from {self.persist_directory}")
            return vector_store
        except Exception as e:
            logger.error(f"Error loading vector store: {e}")
            raise

    def similarity_search(self, query: str, k: int = 5,
                          filter_dict: Optional[Dict] = None) -> List[Document]:
        """Perform similarity search on the vector store.

        Args:
            query: Free-text search query.
            k: Maximum number of results to return.
            filter_dict: Optional Chroma metadata filter
                (e.g. ``{"category": "testing"}``).

        Returns:
            Matching documents, or an empty list if the search fails
            (errors are logged, not raised, so the CLI degrades gracefully).
        """
        try:
            if filter_dict:
                results = self.vector_store.similarity_search(
                    query=query,
                    k=k,
                    filter=filter_dict
                )
            else:
                results = self.vector_store.similarity_search(
                    query=query,
                    k=k
                )
            return results
        except Exception as e:
            logger.error(f"Error during similarity search: {e}")
            return []

    def search_by_category(self, query: str, category: str, k: int = 5) -> List[Document]:
        """Search documents within a specific category.

        Args:
            query: Free-text search query.
            category: Value matched against the ``category`` metadata field.
            k: Maximum number of results to return.
        """
        filter_dict = {"category": category}
        return self.similarity_search(query, k=k, filter_dict=filter_dict)

    def search_by_type(self, query: str, doc_type: str, k: int = 5) -> List[Document]:
        """Search documents of a specific type (markdown/csv).

        Args:
            query: Free-text search query.
            doc_type: Value matched against the ``type`` metadata field.
            k: Maximum number of results to return.
        """
        filter_dict = {"type": doc_type}
        return self.similarity_search(query, k=k, filter_dict=filter_dict)

    def get_categories(self) -> List[str]:
        """Get all available categories in the vector store.

        First attempts to read the distinct ``category`` metadata values
        directly from ChromaDB; if that fails (or yields nothing), falls
        back to the static list of categories used at ingestion time.

        Returns:
            Sorted distinct categories from the store, or the static
            fallback list.
        """
        try:
            collection_data = self.vector_store.get(include=["metadatas"])
            found = {
                meta.get("category")
                for meta in (collection_data.get("metadatas") or [])
                if meta and meta.get("category")
            }
            if found:
                return sorted(found)
        except Exception as e:
            # Best-effort: fall through to the static list rather than fail.
            logger.warning(f"Could not read categories from vector store: {e}")
        # Fallback: categories known to be used by the ingestion pipeline.
        return [
            'device_configuration',
            'fabric_documentation',
            'testing',
            'netbox_integration',
            'arista_cloud_test',
            'avd_design',
            'api_usage',
            'workflow',
            'infoblox_integration',
            'network_testing',
            'general_documentation',
            'project_documentation'
        ]

    def format_results(self, results: List[Document], verbose: bool = False) -> str:
        """Format search results for display.

        Args:
            results: Documents returned by a search method.
            verbose: When True, append each document's full content
                after the 500-character preview.

        Returns:
            A human-readable, multi-result report as a single string.
        """
        output = []
        for i, doc in enumerate(results, 1):
            output.append(f"\n{'='*80}")
            output.append(f"Result {i}:")
            output.append(f"Source: {doc.metadata.get('source', 'Unknown')}")
            output.append(f"Category: {doc.metadata.get('category', 'Unknown')}")
            output.append(f"Type: {doc.metadata.get('type', 'Unknown')}")
            # CSV documents carry extra structural metadata worth surfacing.
            if doc.metadata.get('type') == 'csv':
                output.append(f"Columns: {doc.metadata.get('columns', 'Unknown')}")
                output.append(f"Rows: {doc.metadata.get('rows', 'Unknown')}")
            output.append(f"\nContent Preview:")
            content_preview = doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content
            output.append(content_preview)
            if verbose:
                output.append(f"\nFull Content:")
                output.append(doc.page_content)
        return "\n".join(output)

    def export_results(self, results: List[Document], output_file: str) -> None:
        """Export search results to a JSON file.

        Args:
            results: Documents to export.
            output_file: Path of the JSON file to write (overwritten
                if it exists).
        """
        data = []
        for doc in results:
            data.append({
                'content': doc.page_content,
                'metadata': doc.metadata
            })
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=2)
        logger.info(f"Results exported to {output_file}")


def main():
    """Main function for command-line interface."""
    parser = argparse.ArgumentParser(description="Query Arista AVD documentation vector database")
    parser.add_argument("query", nargs="?", help="Search query")
    parser.add_argument("-k", "--top-k", type=int, default=5,
                        help="Number of results to return (default: 5)")
    parser.add_argument("-c", "--category", help="Filter by category")
    parser.add_argument("-t", "--type", choices=['markdown', 'csv'],
                        help="Filter by document type")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Show full content")
    parser.add_argument("-e", "--export", help="Export results to JSON file")
    parser.add_argument("--list-categories", action="store_true",
                        help="List available categories")
    args = parser.parse_args()

    # Initialize query interface (loads the embedding model and store).
    query_interface = AristaDocumentQuery()

    # List categories if requested
    if args.list_categories:
        categories = query_interface.get_categories()
        print("Available categories:")
        for cat in categories:
            print(f"  - {cat}")
        return

    # Ensure query is provided if not listing categories
    if not args.query:
        parser.error("Query is required unless using --list-categories")

    # Perform search — category filter takes precedence over type filter.
    if args.category:
        results = query_interface.search_by_category(args.query, args.category, k=args.top_k)
    elif args.type:
        results = query_interface.search_by_type(args.query, args.type, k=args.top_k)
    else:
        results = query_interface.similarity_search(args.query, k=args.top_k)

    # Display results
    if results:
        formatted_results = query_interface.format_results(results, verbose=args.verbose)
        print(formatted_results)

        # Export if requested
        if args.export:
            query_interface.export_results(results, args.export)
    else:
        print("No results found.")


if __name__ == "__main__":
    main()