import json
import os
import re
import time
from typing import Optional, Tuple

import httpx
import pandas as pd

# Load API key — fail fast at import time so misconfiguration is caught early.
OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
if not OPENROUTER_KEY:
    raise ValueError("❌ Set OPENROUTER_API_KEY or OPENAI_API_KEY")

OPENROUTER_MODEL = "mistralai/mistral-small-3.2-24b-instruct:free"


class Model:
    """Natural-language-to-SQL assistant backed by an OpenRouter-hosted LLM.

    Wraps a SQLite connection (``conn``) whose data lives in a single table
    named ``data``. Questions are converted to SQL via the LLM, validated,
    executed, and the results summarized back into natural language.
    """

    def __init__(self, conn):
        # conn: an open sqlite3 connection (or None when no data is loaded).
        self.conn = conn
        # Maps normalized question text -> previously validated SQL string.
        self.query_cache = {}

    # --- Enhanced LLM call with better error handling ---
    @staticmethod
    def call_llm(prompt: str, timeout: int = 60, max_tokens: int = 800) -> Tuple[int, str]:
        """Call the LLM endpoint with retries and exponential backoff.

        Args:
            prompt: User prompt sent as a single chat message.
            timeout: Per-request timeout in seconds.
            max_tokens: Completion token cap.

        Returns:
            (status_code, text) — status 200 with the model's reply on
            success; otherwise an HTTP-style status and an error message.
        """
        url = "https://openrouter.ai/api/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {OPENROUTER_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": OPENROUTER_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.1,  # Lower temperature for more consistent SQL generation
            "max_tokens": max_tokens,
            "top_p": 0.9,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }

        max_retries = 3
        for attempt in range(max_retries):
            try:
                resp = httpx.post(url, headers=headers, json=payload, timeout=timeout)
                if resp.status_code == 200:
                    j = resp.json()
                    choices = j.get("choices", [])
                    if choices:
                        first = choices[0]
                        # Chat responses carry "message.content"; some
                        # completion-style responses use a bare "text" field.
                        text = first.get("message", {}).get("content") or first.get("text") or ""
                        return 200, text.strip()
                    else:
                        return resp.status_code, "No choices returned"
                elif resp.status_code == 429:  # Rate limit
                    wait_time = 2 ** attempt  # Exponential backoff
                    time.sleep(wait_time)
                    continue
                else:
                    try:
                        error_data = resp.json()
                        return resp.status_code, f"API Error: {error_data.get('error', {}).get('message', 'Unknown error')}"
                    except Exception:
                        # Body was not JSON; fall back to the bare status code.
                        return resp.status_code, f"HTTP Error: {resp.status_code}"
            except httpx.TimeoutException:
                if attempt < max_retries - 1:
                    continue
                return 408, "Request timeout - please try again"
            except Exception as e:
                if attempt < max_retries - 1:
                    continue
                return 500, f"Request failed: {str(e)}"

        # Reached only when every attempt was rate-limited (429).
        return 500, "Max retries exceeded"

    # --- Enhanced SQL generation ---
    def generate_sql(self, question: str) -> str:
        """Generate a validated SQLite query for ``question``.

        Returns the SQL string, or "" when no connection is available,
        the LLM call fails, or validation rejects the output. Successful
        queries are cached keyed on the normalized question text.
        """
        if self.conn is None:
            return ""

        # Check cache first
        cache_key = question.lower().strip()
        if cache_key in self.query_cache:
            return self.query_cache[cache_key]

        try:
            # Get comprehensive schema information
            schema_info = self._get_schema_info()
            sample_data = self._get_sample_data()
            column_stats = self._get_column_statistics()

            # Enhanced prompt with better context
            prompt = f"""You are an expert SQL analyst. Convert the natural language question into a precise SQLite query.

DATABASE SCHEMA:
{schema_info['schema_text']}

COLUMN INFORMATION:
{schema_info['columns_info']}

SAMPLE DATA (first 3 rows):
{sample_data}

DATA STATISTICS:
{column_stats}

USER QUESTION: "{question}"

INSTRUCTIONS:
- Write ONLY the SQL query, no explanations
- Use SQLite syntax
- The table name is "data"
- Be precise with column names (they are case-sensitive)
- For text searches, use LIKE with % wildcards
- For numeric comparisons, ensure proper data types
- Include appropriate LIMIT clause if asking for "top" results
- Use GROUP BY for aggregations
- Order results logically (DESC for highest, ASC for lowest)

Examples of good SQL patterns:
- SELECT column FROM data WHERE condition LIMIT 10
- SELECT category, COUNT(*) FROM data GROUP BY category ORDER BY COUNT(*) DESC
- SELECT * FROM data WHERE text_column LIKE '%search_term%'
- SELECT AVG(numeric_column) FROM data WHERE condition

SQL Query:"""

            status, llm_output = self.call_llm(prompt, max_tokens=500)
            if status != 200:
                return ""

            # Extract and validate SQL
            sql = self._extract_sql(llm_output)
            validated_sql = self._validate_sql(sql)

            # Cache successful query
            if validated_sql:
                self.query_cache[cache_key] = validated_sql

            return validated_sql

        except Exception as e:
            print(f"Error generating SQL: {e}")
            return ""

    def _get_schema_info(self) -> dict:
        """Get detailed schema information.

        Returns a dict with 'schema_text' (CREATE statements),
        'columns_info' (human-readable column list), and 'columns_list'
        (raw column names). All values are empty on any failure.
        """
        try:
            # Get table schema
            schema_rows = pd.read_sql(
                "SELECT name, sql FROM sqlite_master WHERE type='table'", self.conn
            )
            schema_text = "\n".join(schema_rows["sql"].dropna().tolist())

            # Get column information
            cursor = self.conn.cursor()
            cursor.execute("PRAGMA table_info(data)")
            columns_info = cursor.fetchall()

            columns_detail = []
            for col in columns_info:
                # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
                col_name = col[1]
                col_type = col[2]
                columns_detail.append(f"- {col_name} ({col_type})")

            return {
                'schema_text': schema_text,
                'columns_info': "\n".join(columns_detail),
                'columns_list': [col[1] for col in columns_info],
            }
        except Exception:
            return {'schema_text': '', 'columns_info': '', 'columns_list': []}

    def _get_sample_data(self) -> str:
        """Get formatted sample data (first 3 rows) for the LLM prompt."""
        try:
            sample_df = pd.read_sql("SELECT * FROM data LIMIT 3", self.conn)
            if not sample_df.empty:
                return sample_df.to_string(index=False, max_cols=10)
            return "No sample data available"
        except Exception:
            return "Error retrieving sample data"

    def _get_column_statistics(self) -> str:
        """Get basic statistics about columns (row/column counts, cardinality)."""
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM data")
            total_rows = cursor.fetchone()[0]

            # Get column info
            cursor.execute("PRAGMA table_info(data)")
            columns = cursor.fetchall()

            stats = [f"Total rows: {total_rows:,}"]
            stats.append(f"Total columns: {len(columns)}")

            # Sample some numeric/text columns for better context
            for col in columns[:5]:  # First 5 columns
                col_name = col[1]
                try:
                    # Backticks guard odd column names; the name comes from
                    # PRAGMA table_info, not from user input.
                    cursor.execute(f"SELECT COUNT(DISTINCT `{col_name}`) FROM data")
                    unique_count = cursor.fetchone()[0]
                    stats.append(f"'{col_name}': {unique_count} unique values")
                except Exception:
                    continue

            return "\n".join(stats)
        except Exception:
            return "Statistics unavailable"

    def _extract_sql(self, llm_output: str) -> str:
        """Extract SQL from LLM output with multiple strategies.

        Tries fenced code blocks first, then line-by-line keyword
        scanning, then a loose SELECT search. Returns "" when nothing
        resembling SQL is found.
        """
        if not llm_output:
            return ""

        # Strategy 1: Look for SQL code blocks
        sql_block_patterns = [
            r"```sql\s*(.*?)```",
            r"```\s*(SELECT.*?)```",
            r"`(SELECT.*?)`",
        ]
        for pattern in sql_block_patterns:
            matches = re.findall(pattern, llm_output, re.DOTALL | re.IGNORECASE)
            if matches:
                return matches[0].strip()

        # Strategy 2: Find SQL statements by keywords
        lines = llm_output.split('\n')
        sql_lines = []
        in_sql = False
        for line in lines:
            line = line.strip()
            if not line:
                continue
            # Start of SQL
            if line.upper().startswith(('SELECT', 'WITH', 'UPDATE', 'INSERT', 'DELETE')):
                in_sql = True
                sql_lines.append(line)
            elif in_sql:
                # Continue SQL until we hit a non-SQL line
                if any(keyword in line.upper() for keyword in
                       ['FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'LIMIT', 'JOIN', 'AND', 'OR']):
                    sql_lines.append(line)
                elif line.endswith(';'):
                    sql_lines.append(line)
                    break
                elif not any(char.isalpha() for char in line):
                    # Only contains numbers, symbols
                    sql_lines.append(line)
                else:
                    break

        if sql_lines:
            return ' '.join(sql_lines).strip()

        # Strategy 3: Look for any SELECT statement
        select_match = re.search(r'(SELECT.*?)(?:\n\n|\Z)', llm_output, re.DOTALL | re.IGNORECASE)
        if select_match:
            return select_match.group(1).strip()

        return ""

    def _validate_sql(self, sql: str) -> str:
        """Validate SQL query without executing it.

        Accepts only read-only statements (SELECT/WITH), rejects
        data-modifying keywords, and dry-runs the query through
        EXPLAIN QUERY PLAN where possible. Returns "" on rejection.
        """
        if not sql:
            return ""

        # Clean up SQL
        sql = sql.strip().rstrip(';')

        # Basic validation
        if not sql.upper().startswith(('SELECT', 'WITH')):
            return ""

        # Check for dangerous operations. Use word boundaries so column
        # names such as `last_updated` or `created_at` are not rejected
        # by a naive substring match on UPDATE/CREATE.
        dangerous_keywords = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE', 'TRUNCATE']
        for keyword in dangerous_keywords:
            if re.search(rf"\b{keyword}\b", sql, re.IGNORECASE):
                return ""

        # Try to parse with SQLite to validate syntax
        try:
            cursor = self.conn.cursor()
            cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
            return sql
        except Exception:
            # If validation fails, return original SQL and let execute_sql handle the error
            return sql

    # --- Enhanced SQL execution ---
    def execute_sql(self, sql: str) -> pd.DataFrame:
        """Execute SQL with better error handling and limits.

        Appends a LIMIT 1000 to unbounded, non-aggregate queries and caps
        result size at 10,000 rows. Returns an empty DataFrame on error.
        """
        if self.conn is None or not sql:
            return pd.DataFrame()

        try:
            # Add safety limits if not present
            sql_upper = sql.upper()
            if 'LIMIT' not in sql_upper and 'COUNT(' not in sql_upper:
                sql += " LIMIT 1000"

            df = pd.read_sql(sql, self.conn)

            # Additional safety check: cap at the same 10k threshold we test.
            if len(df) > 10000:
                df = df.head(10000)

            return df
        except Exception as e:
            print(f"SQL execution error: {e}")
            return pd.DataFrame()

    # --- Enhanced natural language generation ---
    def results_to_natural_language(self, question: str, df: pd.DataFrame) -> str:
        """Convert results to natural language with better formatting.

        Handles trivial shapes (empty, scalar, single aggregate row)
        directly; delegates richer result sets to the LLM, falling back
        to a generic count-based summary when the LLM is unavailable.
        """
        if df.empty:
            return "The query executed successfully but returned no results. Try refining your question or check if the data contains what you're looking for."

        # For single value results
        if len(df) == 1 and len(df.columns) == 1:
            value = df.iloc[0, 0]
            return f"The answer is {value:,}" if isinstance(value, (int, float)) else f"The answer is {value}"

        # For aggregation results
        if len(df) == 1 and any(keyword in question.lower() for keyword in
                                ['average', 'avg', 'sum', 'total', 'count', 'max', 'min']):
            result_parts = []
            for col, val in df.iloc[0].items():
                if isinstance(val, (int, float)):
                    result_parts.append(f"{col}: {val:,}")
                else:
                    result_parts.append(f"{col}: {val}")
            return "The result is " + ", ".join(result_parts)

        # Use LLM for complex results
        try:
            # Limit data sent to LLM
            sample_size = min(10, len(df))
            sample_df = df.head(sample_size)

            # Convert to a concise format
            if len(df.columns) <= 3:
                results_text = sample_df.to_string(index=False, max_rows=10)
            else:
                # For wide tables, just show key columns
                results_text = str(sample_df.to_dict('records')[:5])

            prompt = f"""Convert this SQL query result into a natural, conversational answer.

USER QUESTION: "{question}"

RESULTS ({len(df)} total rows, showing first {sample_size}):
{results_text}

Write a clear, concise answer that directly addresses the user's question. Include specific numbers/values when relevant. Keep it under 100 words.

Answer:"""

            status, llm_output = self.call_llm(prompt, max_tokens=200)
            if status == 200 and llm_output:
                answer = llm_output.strip()
                # Add result count if there are many results
                if len(df) > 10:
                    answer += f" (showing results from {len(df):,} total records)"
                return answer
        except Exception as e:
            print(f"Error in NL generation: {e}")

        # Fallback response
        if len(df) <= 5:
            return f"Found {len(df)} result(s). Here's what your data shows based on your question."
        else:
            return f"Found {len(df):,} results that match your question. The data has been retrieved successfully."