|
"""
Vector Store for SQL Examples

Handles storage and retrieval of SQL examples using ChromaDB and FAISS for high-performance similarity search.
"""
|
|
|
import os |
|
import json |
|
import pickle |
|
from typing import List, Dict, Any, Optional, Tuple |
|
from pathlib import Path |
|
|
|
import chromadb |
|
from chromadb.config import Settings |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from loguru import logger |
|
|
|
class VectorStore:
    """Persistent store of SQL examples backed by a ChromaDB collection.

    Each example is indexed under the text
    ``"Question: <question>\\nTable columns: <headers>"`` and carries its
    question, SQL, headers, difficulty and category as metadata. The
    collection is created in cosine space so that ``1 - distance`` can be
    used as a similarity score.
    """

    # Shared by every place that (re)creates the collection so they cannot drift.
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    # Upper bound on records sent in one ``collection.add`` call; large loads
    # are sliced to stay under Chroma's per-call batch limit.
    _ADD_BATCH_SIZE = 5000

    def __init__(self,
                 persist_directory: str = "./data/vector_store",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
                 collection_name: str = "sql_examples"):
        """
        Initialize the vector store.

        Args:
            persist_directory: Directory to persist the vector store
                (created if it does not exist)
            embedding_model: Sentence transformer model name to load
            collection_name: Name of the ChromaDB collection
        """
        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the model is loaded here, but every collection call in
        # this class passes raw text (documents= / query_texts=), so embeddings
        # appear to come from the collection's own embedding function rather
        # than from this model. Confirm whether that is intended.
        self.embedding_model = SentenceTransformer(embedding_model)
        self.collection_name = collection_name

        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            ),
        )

        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata=dict(self._COLLECTION_METADATA),
        )

        logger.info(f"Vector store initialized at {self.persist_directory}")

    @staticmethod
    def _format_headers(table_headers: Any) -> str:
        """Render table headers (a string or an iterable of names) as one string."""
        if table_headers is None:
            return ""
        if isinstance(table_headers, str):
            # Already formatted; joining would split it into single characters.
            return table_headers
        return ", ".join(str(header) for header in table_headers)

    @staticmethod
    def _metadata_to_example(metadata: Dict[str, Any],
                             similarity_score: Optional[float] = None) -> Dict[str, Any]:
        """Build the public example dict from a stored metadata record."""
        example = {
            "question": metadata["question"],
            "sql": metadata["sql"],
            "table_headers": metadata["table_headers"],
        }
        if similarity_score is not None:
            example["similarity_score"] = similarity_score
        example["difficulty"] = metadata.get("difficulty", "medium")
        example["category"] = metadata.get("category", "general")
        return example

    def add_examples(self, examples: List[Dict[str, Any]]) -> None:
        """
        Add SQL examples to the vector store.

        IDs have the form ``example_<n>``. Numbering continues from the number
        of records already stored, so repeated calls append new records instead
        of reusing the IDs of an earlier batch.

        Args:
            examples: List of dictionaries with required keys ``question``,
                ``sql`` and ``table_headers`` (list of names or a single
                string), and optional keys ``difficulty`` (default
                ``"medium"``) and ``category`` (default ``"general"``).

        Raises:
            KeyError: If an example lacks a required key.
        """
        if not examples:
            return

        # Assumes IDs are only ever created here and removed all at once via
        # clear_collection(), so the current count is the next free index.
        first_index = self.collection.count()

        ids = []
        documents = []
        metadatas = []

        for offset, example in enumerate(examples):
            example_index = first_index + offset
            question = example["question"]
            headers = self._format_headers(example["table_headers"])

            ids.append(f"example_{example_index}")
            documents.append(f"Question: {question}\nTable columns: {headers}")
            metadatas.append({
                "question": question,
                "sql": example["sql"],
                "table_headers": headers,
                "difficulty": example.get("difficulty", "medium"),
                "category": example.get("category", "general"),
                "example_id": example_index,
            })

        for begin in range(0, len(ids), self._ADD_BATCH_SIZE):
            end = begin + self._ADD_BATCH_SIZE
            self.collection.add(
                documents=documents[begin:end],
                metadatas=metadatas[begin:end],
                ids=ids[begin:end],
            )

        logger.info(f"Added {len(examples)} examples to vector store")

    def search_similar(self,
                       query: str,
                       table_headers: List[str],
                       top_k: int = 5,
                       similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
        """
        Search for similar SQL examples.

        Args:
            query: Natural language question
            table_headers: Table column names (a list, or an already
                comma-joined string, mirroring what add_examples accepts)
            top_k: Number of top results to return; values <= 0 yield ``[]``
            similarity_threshold: Minimum similarity score (``1 - distance``)

        Returns:
            Up to ``top_k`` examples, best first, each with the keys
            question, sql, table_headers, similarity_score, difficulty
            and category.
        """
        if top_k <= 0:
            return []

        stored = self.collection.count()
        if stored == 0:
            return []

        headers = self._format_headers(table_headers)
        search_text = f"Question: {query}\nTable columns: {headers}"

        results = self.collection.query(
            query_texts=[search_text],
            # Fetch twice top_k as before, but never more than is stored.
            n_results=min(top_k * 2, stored),
            include=["metadatas", "distances"],
        )

        # One query text was sent, so only the first result row is relevant.
        metadata_rows = (results.get("metadatas") or [[]])[0]
        distance_rows = (results.get("distances") or [[]])[0]

        similar_examples = []
        for metadata, distance in zip(metadata_rows, distance_rows):
            similarity_score = 1 - distance
            if similarity_score >= similarity_threshold:
                similar_examples.append(
                    self._metadata_to_example(metadata, similarity_score)
                )

        similar_examples.sort(key=lambda item: item["similarity_score"], reverse=True)
        return similar_examples[:top_k]

    def get_example_by_id(self, example_id: str) -> Optional[Dict[str, Any]]:
        """Get a specific example by ID.

        Returns:
            The example dict, or ``None`` when the ID is unknown or the lookup
            fails (the failure is logged, not raised).
        """
        try:
            result = self.collection.get(ids=[example_id])
            found = result["metadatas"]
            if found:
                return self._metadata_to_example(found[0])
        except Exception as e:
            logger.error(f"Error retrieving example {example_id}: {e}")

        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the vector store.

        Returns:
            A dict with ``total_examples``, ``collection_name`` and
            ``persist_directory``; on failure, ``{"error": <message>}``.
        """
        try:
            return {
                "total_examples": self.collection.count(),
                "collection_name": self.collection_name,
                "persist_directory": str(self.persist_directory),
            }
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            return {"error": str(e)}

    def clear_collection(self) -> None:
        """Clear all examples by dropping and recreating the collection.

        Failures are logged rather than raised.
        """
        try:
            self.client.delete_collection(self.collection_name)
            # Same call and settings as __init__, so both paths stay in step.
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata=dict(self._COLLECTION_METADATA),
            )
            logger.info("Collection cleared successfully")
        except Exception as e:
            logger.error(f"Error clearing collection: {e}")

    def export_examples(self, filepath: str) -> None:
        """Export all examples to a JSON file.

        The file holds a list of example dicts (question, sql, table_headers,
        difficulty, category). Missing parent directories are created.
        Failures are logged rather than raised.
        """
        try:
            results = self.collection.get()
            examples = [
                self._metadata_to_example(metadata)
                for metadata in (results["metadatas"] or [])
            ]

            target = Path(filepath)
            target.parent.mkdir(parents=True, exist_ok=True)
            with open(target, 'w', encoding='utf-8') as f:
                json.dump(examples, f, indent=2, ensure_ascii=False)

            logger.info(f"Exported {len(examples)} examples to {filepath}")

        except Exception as e:
            logger.error(f"Error exporting examples: {e}")
|
|