"""
SQL generator using RAG-enhanced prompts.

Generates SQL with the best available LLM, augmenting each prompt with
retrieved examples (retrieval-augmented generation).
"""

import os
import time
from typing import List, Dict, Any, Optional

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from loguru import logger

from .retriever import SQLRetriever
from .prompt_engine import PromptEngine


class SQLGenerator:
    """High-accuracy SQL generator using RAG and the best available LLMs."""

    def __init__(self,
                 retriever: SQLRetriever,
                 prompt_engine: PromptEngine,
                 model_config: Optional[Dict[str, Any]] = None):
        """
        Initialize the SQL generator.

        Args:
            retriever: Initialized SQL retriever
            prompt_engine: Initialized prompt engine
            model_config: Configuration for model selection and usage
        """
        self.retriever = retriever
        self.prompt_engine = prompt_engine
        self.model_config = model_config or self._get_default_model_config()

        # Registry of successfully initialized backends, keyed by name.
        self.models = {}
        self._initialize_models()

        logger.info("SQL Generator initialized successfully")

    def _get_default_model_config(self) -> Dict[str, Any]:
        """Get the default model configuration, prioritizing CodeLlama for cost efficiency."""
        return {
            "primary_model": "codellama",
            "fallback_models": ["openai", "codet5", "local"],
            "openai_config": {
                "model": "gpt-3.5-turbo",
                "temperature": 0.1,
                "max_tokens": 500,
                "api_key_env": "OPENAI_API_KEY"
            },
            "local_config": {
                "codellama_model": "TheBloke/CodeLlama-7B-Python-GGUF",
                "codet5_model": "Salesforce/codet5-base",
                "max_length": 512,
                "temperature": 0.1
            },
            "retrieval_config": {
                "top_k": 5,
                "similarity_threshold": 0.7,
                "use_schema_filtering": True
            }
        }

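    # Usage note (illustrative): a caller may pass a custom model_config to
    # __init__, but that dict replaces these defaults wholesale rather than
    # being merged, so an override must supply every key the generator reads
    # ("primary_model", "openai_config", "local_config", "retrieval_config").
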
    def _initialize_models(self) -> None:
        """Initialize available models based on configuration."""
        try:
            if self._initialize_codellama():
                self.models["codellama"] = "codellama"
                logger.info("CodeLlama model initialized successfully")

            if self._initialize_openai():
                self.models["openai"] = "openai"
                logger.info("OpenAI GPT initialized successfully")

            if self._initialize_codet5():
                self.models["codet5"] = "codet5"
                logger.info("CodeT5 model initialized successfully")

            if self._initialize_local_models():
                self.models["local"] = "local"
                logger.info("Local models initialized successfully")

            if not self.models:
                raise RuntimeError("No models could be initialized")

        except Exception as e:
            logger.error(f"Error initializing models: {e}")
            raise

    def _initialize_openai(self) -> bool:
        """Initialize the OpenAI client and verify the API key with a minimal test call."""
        try:
            config = self.model_config["openai_config"]
            api_key = os.getenv(config["api_key_env"])
            if not api_key:
                logger.warning("OpenAI API key not found in environment variables")
                return False

            # Issue a tiny completion to confirm the key and configured model work.
            from openai import OpenAI
            client = OpenAI(api_key=api_key)
            client.chat.completions.create(
                model=config["model"],
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=10
            )
            return True

        except Exception as e:
            logger.warning(f"OpenAI initialization failed: {e}")
            return False

    def _initialize_codellama(self) -> bool:
        """Initialize a CodeLlama model via ctransformers, trying several GGUF variants."""
        try:
            from ctransformers import AutoModelForCausalLM

            model_options = [
                "TheBloke/CodeLlama-7B-Python-GGUF",
                "TheBloke/CodeLlama-7B-GGUF",
                "TheBloke/CodeLlama-13B-Python-GGUF",
                "TheBloke/CodeLlama-13B-GGUF"
            ]

            for model_name in model_options:
                try:
                    logger.info(f"Trying to load CodeLlama model: {model_name}")
                    # CPU-only load (gpu_layers=0) using the AVX2 build of the backend.
                    self.codellama_model = AutoModelForCausalLM.from_pretrained(
                        model_name,
                        model_type="llama",
                        gpu_layers=0,
                        lib="avx2",
                        context_length=2048,
                        batch_size=1
                    )
                    logger.info(f"CodeLlama model loaded successfully: {model_name}")
                    return True

                except Exception as e:
                    logger.warning(f"Failed to load {model_name}: {e}")
                    continue

            logger.warning("All CodeLlama models failed to load")
            return False

        except Exception as e:
            logger.warning(f"CodeLlama initialization failed: {e}")
            return False

    def _initialize_codet5(self) -> bool:
        """Initialize the CodeT5 tokenizer and model."""
        try:
            model_name = self.model_config["local_config"]["codet5_model"]
            self.codet5_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.codet5_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            return True

        except Exception as e:
            logger.warning(f"CodeT5 initialization failed: {e}")
            return False

    def _initialize_local_models(self) -> bool:
        """Report whether local model execution is possible (CPU always works)."""
        try:
            # Log the device situation for diagnostics; CPU execution is always supported.
            if not torch.cuda.is_available():
                logger.info("CUDA not available; local models will run on CPU")
            return True

        except Exception as e:
            logger.warning(f"Local models initialization failed: {e}")
            return False

    def generate_sql(self,
                     question: str,
                     table_headers: List[str],
                     use_model: Optional[str] = None) -> Dict[str, Any]:
        """
        Generate a SQL query using RAG-enhanced generation.

        Args:
            question: Natural language question
            table_headers: List of table column names
            use_model: Specific model to use (if None, auto-selects the best available)

        Returns:
            Dictionary containing the SQL query and metadata
        """
        start_time = time.time()

        try:
            # Step 1: retrieve similar question/SQL examples for few-shot context.
            retrieved_examples = self.retriever.retrieve_examples(
                question=question,
                table_headers=table_headers,
                top_k=self.model_config["retrieval_config"]["top_k"],
                use_schema_filtering=self.model_config["retrieval_config"]["use_schema_filtering"]
            )

            # Step 2: build the RAG-enhanced prompt.
            prompt = self.prompt_engine.construct_enhanced_prompt(
                question=question,
                table_headers=table_headers,
                retrieved_examples=retrieved_examples
            )

            # Step 3: generate with the requested or best available model.
            model_name = use_model or self._select_best_model()
            sql_result = self._generate_with_model(model_name, prompt, question, table_headers)

            # Step 4: clean up and validate the generated SQL.
            processed_sql = self._post_process_sql(sql_result, question, table_headers)

            processing_time = time.time() - start_time

            return {
                "question": question,
                "table_headers": table_headers,
                "sql_query": processed_sql,
                "model_used": model_name,
                "retrieved_examples": retrieved_examples,
                "processing_time": processing_time,
                "prompt_length": len(prompt),
                "status": "success"
            }

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"SQL generation failed: {e}")

            return {
                "question": question,
                "table_headers": table_headers,
                "sql_query": "",
                "model_used": "none",
                "retrieved_examples": [],
                "processing_time": processing_time,
                "error": str(e),
                "status": "error"
            }

    def _select_best_model(self) -> str:
        """Select the best available model for generation."""
        # CodeT5 is deliberately excluded from the priority list because its
        # raw output is unreliable for SQL; it is handled separately below.
        priority_order = ["codellama", "openai", "local"]
        for model in priority_order:
            if model in self.models:
                return model

        if "codet5" in self.models:
            logger.warning("Only CodeT5 available, using intelligent fallback for better accuracy")
            return "fallback"

        return list(self.models.keys())[0] if self.models else "none"

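    # Example (illustrative): with models = {"openai": ..., "codet5": ...},
    # _select_best_model() returns "openai"; with only {"codet5": ...} it
    # returns "fallback", so generation runs through the heuristic path.
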
    def _generate_with_model(self,
                             model_name: str,
                             prompt: str,
                             question: str,
                             table_headers: List[str]) -> str:
        """Generate SQL using the specified model."""
        try:
            if model_name == "openai":
                return self._generate_with_openai(prompt)
            elif model_name == "codellama":
                return self._generate_with_codellama(prompt)
            elif model_name == "codet5":
                # CodeT5 output is unreliable for SQL; route through the fallback.
                logger.info("CodeT5 selected but unreliable, using intelligent fallback")
                return self._generate_with_fallback(prompt)
            elif model_name == "local":
                return self._generate_with_local(prompt)
            elif model_name == "fallback":
                return self._generate_with_fallback(prompt)
            else:
                raise ValueError(f"Unknown model: {model_name}")

        except Exception as e:
            logger.error(f"Generation failed with {model_name}: {e}")
            # Any backend failure degrades gracefully to the heuristic fallback.
            return self._generate_with_fallback(prompt)

    def _generate_with_openai(self, prompt: str) -> str:
        """Generate SQL using the configured OpenAI chat model."""
        try:
            config = self.model_config["openai_config"]
            api_key = os.getenv(config["api_key_env"])

            from openai import OpenAI
            client = OpenAI(api_key=api_key)

            response = client.chat.completions.create(
                model=config["model"],
                messages=[
                    {"role": "system", "content": "You are an expert SQL developer."},
                    {"role": "user", "content": prompt}
                ],
                temperature=config["temperature"],
                max_tokens=config["max_tokens"]
            )

            sql_query = response.choices[0].message.content.strip()
            return self._extract_sql_from_response(sql_query)

        except Exception as e:
            logger.error(f"OpenAI generation failed: {e}")
            raise

    def is_codellama_available(self) -> bool:
        """Check whether the CodeLlama model is loaded and ready for use."""
        return hasattr(self, 'codellama_model') and self.codellama_model is not None

    def get_available_models(self) -> List[str]:
        """Get the list of available model names."""
        return list(self.models.keys())

    def _generate_with_codellama(self, prompt: str) -> str:
        """Generate SQL using CodeLlama."""
        try:
            if not self.is_codellama_available():
                logger.warning("CodeLlama model not properly initialized, using fallback")
                return self._generate_with_fallback(prompt)

            system_prompt = ("You are an expert SQL developer. Generate only the SQL query "
                             "without any explanation or additional text. The query should be "
                             "valid SQL syntax.")
            full_prompt = f"{system_prompt}\n\n{prompt}\n\nSQL Query:"

            # Low temperature and explicit stop sequences keep the output to a single query.
            response = self.codellama_model(
                full_prompt,
                max_new_tokens=256,
                temperature=0.1,
                top_p=0.95,
                repetition_penalty=1.1,
                stop=["\n\n", "```", "Explanation:", "Note:"]
            )

            sql_query = response.strip()

            # Drop anything the model echoed before the "SQL Query:" marker.
            if "SQL Query:" in sql_query:
                sql_query = sql_query.split("SQL Query:")[-1].strip()

            # Truncate at the first statement terminator.
            if ";" in sql_query:
                sql_query = sql_query.split(";")[0] + ";"

            logger.info(f"CodeLlama generated SQL: {sql_query}")
            return sql_query

        except Exception as e:
            logger.error(f"CodeLlama generation failed: {e}")
            return self._generate_with_fallback(prompt)

    def _generate_with_codet5(self, prompt: str) -> str:
        """Generate SQL using CodeT5 (currently delegated to the fallback)."""
        try:
            if not hasattr(self, 'codet5_tokenizer') or not hasattr(self, 'codet5_model'):
                logger.warning("CodeT5 model not properly initialized, using fallback")
                return self._generate_with_fallback(prompt)

            # Direct CodeT5 decoding has proven unreliable for SQL, so delegate
            # to the intelligent fallback instead.
            logger.info("CodeT5 SQL generation not reliable, using intelligent fallback")
            return self._generate_with_fallback(prompt)

        except Exception as e:
            logger.error(f"CodeT5 generation failed: {e}")
            return self._generate_with_fallback(prompt)

    def _simplify_prompt_for_codet5(self, prompt: str) -> str:
        """Simplify the prompt for better CodeT5 generation."""
        lines = prompt.split('\n')
        simplified_lines = []

        for line in lines:
            if line.startswith('Question:') or line.startswith('Table columns:'):
                simplified_lines.append(line)
            elif 'SELECT' in line and 'FROM' in line:
                # Keep example SQL lines; they anchor the expected output format.
                simplified_lines.append(line)

        if simplified_lines:
            return '\n'.join(simplified_lines)
        else:
            # Nothing recognizable to keep, so return the full prompt unchanged.
            return prompt

    def _clean_codet5_output(self, output: str) -> str:
        """Clean up CodeT5-generated output."""
        # Strip prompt-template fragments that the model tends to echo back.
        output = output.replace('{table_schema}', '')
        output = output.replace('Example(', '')
        output = output.replace('Relevance:', '')

        if 'SELECT' in output.upper():
            # Keep only the text from the first SELECT onward.
            start = output.upper().find('SELECT')
            sql_part = output[start:]

            lines = sql_part.split('\n')
            clean_lines = []
            for line in lines:
                line = line.strip()
                if line and not line.startswith(('Example', 'Question', 'Table', 'Relevance')):
                    clean_lines.append(line)
                if line.endswith(';'):
                    break

            return '\n'.join(clean_lines)

        return output

    def _generate_with_local(self, prompt: str) -> str:
        """Generate SQL using local models."""
        try:
            if "codellama" in self.models:
                return self._generate_with_codellama(prompt)
            elif "codet5" in self.models:
                return self._generate_with_codet5(prompt)
            else:
                raise RuntimeError("No local models available")

        except Exception as e:
            logger.error(f"Local generation failed: {e}")
            return self._generate_with_fallback(prompt)

    def _generate_with_fallback(self, prompt: str) -> str:
        """Generate SQL using rule-based heuristics (assumes an `employees` table)."""
        try:
            prompt_lower = prompt.lower()

            # Salary-threshold questions: extract the numeric threshold.
            # "more that" catches a common typo of "more than".
            if "salary" in prompt_lower and any(word in prompt_lower for word in ["more than", "more that", "greater than", "above", "over"]):
                import re

                exact_patterns = [
                    r'more than (\d+)',
                    r'more that (\d+)',
                    r'greater than (\d+)',
                    r'above (\d+)',
                    r'over (\d+)',
                    r'(\d+) or more',
                    r'(\d+) and above'
                ]

                salary_amount = None
                for pattern in exact_patterns:
                    match = re.search(pattern, prompt_lower)
                    if match:
                        salary_amount = int(match.group(1))
                        break

                # No explicit pattern matched: fall back to any plausible salary figure.
                if salary_amount is None:
                    salary_matches = re.findall(r'(\d+)', prompt)
                    if salary_matches:
                        salary_amounts = [int(match) for match in salary_matches if match.isdigit()]
                        reasonable_salaries = [amt for amt in salary_amounts if 1000 <= amt <= 1000000]
                        if reasonable_salaries:
                            salary_amount = reasonable_salaries[0]
                        else:
                            salary_amount = max(salary_amounts) if salary_amounts else 50000
                    else:
                        salary_amount = 50000

                return f"SELECT * FROM employees WHERE salary > {salary_amount}"

            # Aggregation questions.
            elif "count" in prompt_lower or "how many" in prompt_lower:
                return "SELECT COUNT(*) FROM employees"

            elif "average" in prompt_lower or "mean" in prompt_lower:
                return "SELECT AVG(salary) FROM employees"

            elif "sum" in prompt_lower or "total" in prompt_lower:
                return "SELECT SUM(salary) FROM employees"

            # Plain selection questions.
            elif "employees" in prompt_lower and "select" in prompt_lower:
                return "SELECT * FROM employees"

            # Last resort: return a syntactically valid default query.
            else:
                return "SELECT * FROM employees"

        except Exception as e:
            logger.error(f"Fallback generation failed: {e}")
            return "SELECT * FROM employees"

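    # Illustrative fallback behavior (not exhaustive):
    #   "... salary more than 60000 ..." -> SELECT * FROM employees WHERE salary > 60000
    #   "How many employees are there?"  -> SELECT COUNT(*) FROM employees
    #   "What is the average salary?"    -> SELECT AVG(salary) FROM employees
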
    def _extract_sql_from_response(self, response: str) -> str:
        """Extract the SQL query from a model response."""
        # Prefer an explicit ```sql fenced block when present.
        if "```sql" in response:
            start = response.find("```sql") + 6
            end = response.find("```", start)
            if end != -1:
                return response[start:end].strip()

        # Otherwise look for a recognizable prefix and keep lines up to the first ';'.
        sql_prefixes = ["SQL:", "Query:", "SELECT"]
        for prefix in sql_prefixes:
            if prefix in response:
                start = response.find(prefix)
                sql_part = response[start:].strip()
                lines = sql_part.split('\n')
                sql_lines = []
                for line in lines:
                    if line.strip() and not line.strip().startswith(('Note:', 'Explanation:', '#')):
                        sql_lines.append(line)
                    if line.strip().endswith(';'):
                        break
                return '\n'.join(sql_lines).strip()

        return response.strip()

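    # Illustrative extraction (hypothetical model output):
    #   _extract_sql_from_response("Here it is:\n```sql\nSELECT name FROM employees;\n```")
    #   -> "SELECT name FROM employees;"
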
    def _post_process_sql(self,
                          sql_query: str,
                          question: str,
                          table_headers: List[str]) -> str:
        """Post-process and validate the generated SQL."""
        if not sql_query:
            return sql_query

        sql_query = sql_query.strip()

        # Guard against non-SELECT output by substituting a safe default.
        if not sql_query.upper().startswith('SELECT'):
            sql_query = "SELECT * FROM employees WHERE 1=1"

        # Ensure the statement is terminated.
        if not sql_query.endswith(';'):
            sql_query += ';'

        # Sanity check: the query should reference at least one known column.
        used_columns = []
        for header in table_headers:
            if header.lower() in sql_query.lower():
                used_columns.append(header)

        if not used_columns and len(table_headers) > 0:
            # No known column appears; select the first column as a safe default.
            sql_query = f"SELECT {table_headers[0]} FROM employees;"

        return sql_query

    def get_generation_stats(self) -> Dict[str, Any]:
        """Get statistics about the SQL generator."""
        return {
            "available_models": list(self.models.keys()),
            "model_config": self.model_config,
            "retriever_stats": self.retriever.get_retrieval_stats(),
            "prompt_stats": self.prompt_engine.get_prompt_statistics()
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Get detailed information about available models."""
        model_info = {
            "available_models": list(self.models.keys()),
            "primary_model": self.model_config.get("primary_model", "codellama"),
            "codellama_status": "available" if self.is_codellama_available() else "unavailable",
            "openai_status": "available" if "openai" in self.models else "unavailable",
            "model_config": self.model_config
        }

        if self.is_codellama_available():
            try:
                model_info["codellama_details"] = {
                    "model_type": "CodeLlama",
                    "context_length": 2048,
                    "temperature": 0.1
                }
            except Exception as e:
                model_info["codellama_details"] = {"error": str(e)}

        return model_info
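

# Minimal usage sketch (assumptions: SQLRetriever and PromptEngine can be
# constructed with no arguments; their real signatures live in .retriever and
# .prompt_engine). Because this module uses relative imports, run it as part
# of its package, e.g. `python -m <package>.generator`.
if __name__ == "__main__":
    retriever = SQLRetriever()
    prompt_engine = PromptEngine()
    generator = SQLGenerator(retriever, prompt_engine)

    result = generator.generate_sql(
        question="How many employees earn more than 50000?",
        table_headers=["name", "salary", "department"],
    )
    print(f"{result['sql_query']} (model: {result['model_used']})")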