|
""" |
|
Prompt Engine for SQL Generation |
|
Constructs intelligent prompts for SQL generation using retrieved examples and best practices. |
|
""" |
|
|
|
import json |
|
from typing import List, Dict, Any, Optional |
|
from pathlib import Path |
|
from loguru import logger |
|
|
|
class PromptEngine:
    """Build LLM prompts for natural-language-to-SQL generation.

    Templates are plain-text files in ``prompts_dir`` using ``str.format``
    placeholders (``{question}``, ``{table_schema}``, ``{examples}``, ...).
    Missing template files are created from built-in defaults the first time
    the engine is constructed, so they can be edited on disk afterwards.
    """

    # Upper bounds on how many examples are embedded in a single prompt.
    MAX_PROMPT_EXAMPLES = 3
    MAX_FEW_SHOT_EXAMPLES = 5

    # (label, keywords) pairs: a header is listed under a label when its
    # lower-cased name contains any of the keywords. This is a substring
    # heuristic, so a header may appear under more than one label.
    _SCHEMA_CATEGORIES = (
        ("Primary Keys", ("id", "key")),
        ("Text Fields", ("name", "title", "description", "text")),
        ("Numeric Fields", ("age", "count", "price", "salary", "amount", "number")),
        ("Date Fields", ("date", "time", "created", "updated", "birth")),
        ("Boolean Fields", ("is_", "has_", "active", "enabled", "status")),
    )

    # Placeholders that mark where the reusable preamble of a system prompt ends.
    _TEMPLATE_PLACEHOLDERS = ("{table_schema}", "{examples}", "{question}")

    # Class-level default; instances may override it by assigning the attribute.
    default_system_prompt = (
        "You are an expert SQL developer. Your task is to convert natural "
        "language questions into accurate SQL queries.\n"
        "\n"
        "Key Guidelines:\n"
        "1. Always use the exact table column names provided\n"
        "2. Generate standard SQL syntax (compatible with most databases)\n"
        "3. Use appropriate JOINs when multiple tables are involved\n"
        "4. Apply proper WHERE clauses for filtering\n"
        "5. Use GROUP BY for aggregations when needed\n"
        "6. Ensure queries are efficient and readable\n"
        "7. Handle edge cases appropriately\n"
        "\n"
        "Table Schema: {table_schema}\n"
        "\n"
        "Retrieved Examples:\n"
        "{examples}\n"
        "\n"
        "Question: {question}\n"
        "\n"
        "Generate the SQL query:"
    )

    def __init__(self, prompts_dir: str = "./prompts"):
        """
        Initialize the prompt engine.

        Creates ``prompts_dir`` if needed and loads (or seeds) the templates.

        Args:
            prompts_dir: Directory containing prompt templates
        """
        self.prompts_dir = Path(prompts_dir)
        self.prompts_dir.mkdir(parents=True, exist_ok=True)
        self.templates = self._load_prompt_templates()

    def _load_prompt_templates(self) -> Dict[str, str]:
        """Load prompt templates from files, seeding missing ones with defaults.

        Returns:
            Mapping of template name (file name without extension) to its text.
        """
        default_templates = {
            "sql_generation.txt": self._get_default_sql_prompt(),
            "few_shot_examples.txt": self._get_default_few_shot_prompt(),
            "error_correction.txt": self._get_default_error_correction_prompt(),
        }

        templates = {}
        for filename, content in default_templates.items():
            template_path = self.prompts_dir / filename
            if not template_path.exists():
                template_path.write_text(content, encoding='utf-8')
                logger.info(f"Created default template: {filename}")

            # Always read back from disk so user edits to the files win.
            # ``stem`` drops only the extension, unlike ``replace('.txt', '')``.
            templates[Path(filename).stem] = template_path.read_text(encoding='utf-8')

        return templates

    def _get_default_sql_prompt(self) -> str:
        """Get default SQL generation prompt template."""
        return (
            "You are an expert SQL developer. Convert the natural language "
            "question to SQL.\n"
            "\n"
            "Table Schema: {table_schema}\n"
            "\n"
            "Examples:\n"
            "{examples}\n"
            "\n"
            "Question: {question}\n"
            "\n"
            "Generate SQL:"
        )

    def _get_default_few_shot_prompt(self) -> str:
        """Get default few-shot learning prompt template."""
        return (
            "Given these examples, generate SQL for the new question:\n"
            "\n"
            "Examples:\n"
            "{examples}\n"
            "\n"
            "New Question: {question}\n"
            "Table Schema: {table_schema}\n"
            "\n"
            "SQL Query:"
        )

    def _get_default_error_correction_prompt(self) -> str:
        """Get default error correction prompt template."""
        return (
            "The following SQL query has an error. Please correct it:\n"
            "\n"
            "Original Question: {question}\n"
            "Table Schema: {table_schema}\n"
            "Incorrect SQL: {incorrect_sql}\n"
            "Error: {error_message}\n"
            "\n"
            "Corrected SQL:"
        )

    def construct_sql_prompt(self,
                             question: str,
                             table_headers: List[str],
                             retrieved_examples: List[Dict[str, Any]],
                             prompt_type: str = "sql_generation") -> str:
        """
        Construct a prompt for SQL generation.

        Args:
            question: Natural language question
            table_headers: List of table column names
            retrieved_examples: List of retrieved relevant examples
            prompt_type: Name of the template to use; unknown names fall back
                to ``"sql_generation"``

        Returns:
            Constructed prompt string

        Raises:
            KeyError: If the selected template contains placeholders other
                than ``question``, ``table_schema`` and ``examples`` (for
                instance the ``error_correction`` template, which must be
                rendered with ``construct_error_correction_prompt``).
        """
        template = self.templates.get(prompt_type, self.templates["sql_generation"])
        return template.format(
            question=question,
            table_schema=self._format_table_schema(table_headers),
            examples=self._format_examples(retrieved_examples),
        )

    def construct_enhanced_prompt(self,
                                  question: str,
                                  table_headers: List[str],
                                  retrieved_examples: List[Dict[str, Any]],
                                  additional_context: Optional[Dict[str, Any]] = None) -> str:
        """
        Construct an enhanced prompt with additional context and examples.

        The prompt opens with the guidelines preamble of
        ``default_system_prompt``. Only the part before the first template
        placeholder is used, so no unfilled ``{table_schema}``/``{examples}``/
        ``{question}`` text (or a premature "Generate the SQL query:" line)
        is sent to the model.

        Args:
            question: Natural language question
            table_headers: List of table column names
            retrieved_examples: List of retrieved relevant examples
            additional_context: Additional context information

        Returns:
            Enhanced prompt string
        """
        prompt_parts = []

        preamble = self._system_preamble()
        if preamble:
            # The empty part yields a blank line between preamble and schema.
            prompt_parts.extend([preamble, ""])

        table_schema = self._format_table_schema(table_headers)
        prompt_parts.append(f"Table Schema: {table_schema}\n")

        if retrieved_examples:
            prompt_parts.append("Relevant Examples (ordered by relevance):")
            for i, example in enumerate(retrieved_examples[:self.MAX_PROMPT_EXAMPLES], 1):
                prompt_parts.append(f"\nExample {i} (Relevance: {self._relevance(example):.2f}):")
                prompt_parts.append(f"Question: {example['question']}")
                prompt_parts.append(f"SQL: {example['sql']}")
                prompt_parts.append(f"Table: {example['table_headers']}")

        if additional_context:
            prompt_parts.append("\nAdditional Context:")
            prompt_parts.extend(f"{key}: {value}" for key, value in additional_context.items())

        prompt_parts.append(f"\nCurrent Question: {question}")
        prompt_parts.append("\nGenerate the SQL query:")

        return "\n".join(prompt_parts)

    def construct_few_shot_prompt(self,
                                  question: str,
                                  table_headers: List[str],
                                  examples: List[Dict[str, Any]]) -> str:
        """
        Construct a few-shot learning prompt.

        Args:
            question: Natural language question
            table_headers: List of table column names
            examples: List of examples for few-shot learning; at most
                ``MAX_FEW_SHOT_EXAMPLES`` are used

        Returns:
            Few-shot prompt string
        """
        template = self.templates["few_shot_examples"]

        examples_text = "".join(
            f"\n--- Example {i} ---\n"
            f"Question: {example['question']}\n"
            f"Table: {example['table_headers']}\n"
            f"SQL: {example['sql']}\n"
            for i, example in enumerate(examples[:self.MAX_FEW_SHOT_EXAMPLES], 1)
        )

        return template.format(
            examples=examples_text,
            question=question,
            table_schema=self._format_table_schema(table_headers),
        )

    def construct_error_correction_prompt(self,
                                          question: str,
                                          table_headers: List[str],
                                          incorrect_sql: str,
                                          error_message: str) -> str:
        """
        Construct a prompt for error correction.

        Args:
            question: Natural language question
            table_headers: List of table column names
            incorrect_sql: The incorrect SQL query
            error_message: Error message or description

        Returns:
            Error correction prompt string
        """
        template = self.templates["error_correction"]
        return template.format(
            question=question,
            table_schema=self._format_table_schema(table_headers),
            incorrect_sql=incorrect_sql,
            error_message=error_message,
        )

    def _system_preamble(self) -> str:
        """Return ``default_system_prompt`` up to its first placeholder line.

        The line containing the first placeholder and everything after it is
        dropped; a prompt without placeholders is returned whole. Trailing
        whitespace is stripped.
        """
        prompt = self.default_system_prompt
        positions = [prompt.find(token) for token in self._TEMPLATE_PLACEHOLDERS]
        positions = [pos for pos in positions if pos != -1]
        if not positions:
            return prompt.rstrip()

        head = prompt[:min(positions)]
        # Cut back to the start of the line that holds the placeholder.
        return head[:head.rfind("\n") + 1].rstrip()

    @staticmethod
    def _relevance(example: Dict[str, Any]) -> float:
        """Return the example's relevance score as a float.

        Uses ``final_score`` when that key is present, otherwise
        ``similarity_score``; missing, ``None`` or non-numeric scores yield
        0.0 so that prompt formatting never fails on a bad score.
        """
        score = example.get("final_score", example.get("similarity_score", 0))
        try:
            return float(score)
        except (TypeError, ValueError):
            return 0.0

    def _format_table_schema(self, table_headers: List[str]) -> str:
        """Format table headers into a readable schema.

        Headers are grouped by name-based heuristics (see
        ``_SCHEMA_CATEGORIES``); headers matching no category are listed
        under "Other Fields". Returns one line per non-empty group.
        """
        if not table_headers:
            return "No table schema provided"

        schema_parts = []
        categorized = set()
        for label, keywords in self._SCHEMA_CATEGORIES:
            matches = [
                header for header in table_headers
                if any(word in header.lower() for word in keywords)
            ]
            if matches:
                schema_parts.append(f"{label}: {', '.join(matches)}")
                categorized.update(matches)

        other_headers = [header for header in table_headers if header not in categorized]
        if other_headers:
            schema_parts.append(f"Other Fields: {', '.join(other_headers)}")

        return "\n".join(schema_parts)

    def _format_examples(self, examples: List[Dict[str, Any]]) -> str:
        """Format retrieved examples for prompt inclusion.

        Only the first ``MAX_PROMPT_EXAMPLES`` examples are used. Each example
        must provide ``question``, ``sql`` and ``table_headers`` keys.
        """
        if not examples:
            return "No relevant examples found."

        formatted_examples = []
        for i, example in enumerate(examples[:self.MAX_PROMPT_EXAMPLES], 1):
            formatted_examples.append(f"Example {i} (Relevance: {self._relevance(example):.2f}):")
            formatted_examples.append(f" Question: {example['question']}")
            formatted_examples.append(f" SQL: {example['sql']}")
            formatted_examples.append(f" Table: {example['table_headers']}")

        return "\n".join(formatted_examples)

    def get_prompt_statistics(self) -> Dict[str, Any]:
        """Get statistics about the prompt engine.

        Returns:
            Dict with the loaded template names, the prompts directory path
            and the number of templates.
        """
        return {
            "available_templates": list(self.templates.keys()),
            "prompts_directory": str(self.prompts_dir),
            "template_count": len(self.templates),
        }
|
|