Spaces:
Running
Running
| import json | |
| import os | |
| from functools import lru_cache | |
| from openai import OpenAI | |
| from datetime import datetime, date, timedelta | |
| import re | |
| # ========================= | |
| # CONFIG | |
| # ========================= | |
# Module-wide OpenAI client; the API key is taken from the OPENAI_API_KEY
# environment variable (None is passed through when the variable is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
| # ========================= | |
| # METADATA LOADING | |
| # ========================= | |
def load_metadata():
    """Load the four metadata JSON documents used by the query builder.

    Reads ``modules.json``, ``join_graph.json``, ``field_types.json`` and
    ``fields.json`` (in that order) from the current working directory on
    every call.

    Returns:
        dict: the parsed documents under the keys ``"modules"``, ``"joins"``,
        ``"field_types"`` and ``"fields"``.

    Raises:
        OSError: if one of the files cannot be opened.
        json.JSONDecodeError: if one of the files is not valid JSON.
    """
    sources = {
        "modules": "modules.json",
        "joins": "join_graph.json",
        "field_types": "field_types.json",
        "fields": "fields.json",
    }
    metadata = {}
    for key, filename in sources.items():
        # JSON text is UTF-8; decoding with the platform locale would corrupt
        # non-ASCII names on systems whose default encoding differs.
        with open(filename, encoding="utf-8") as handle:
            metadata[key] = json.load(handle)
    return metadata
| # ========================= | |
| # OPERATOR RESOLUTION (COMPLETE FIXED VERSION) | |
| # ========================= | |
# Alternative spellings accepted for operators. Keys are matched after the
# input has been lower-cased, stripped and had spaces turned into underscores.
_OPERATOR_ALIASES = {
    "=": "equals",
    "==": "equals",
    "eq": "equals",
    "!=": "not_equals",
    "<>": "not_equals",
    ">": "greater_than",
    "<": "less_than",
    ">=": "greater_or_equal",
    "<=": "less_or_equal",
    "greaterthan": "greater_than",
    "lessthan": "less_than",
    "greaterthanorequal": "greater_or_equal",
    "lessthanorequal": "less_or_equal",
    "does_not_contain": "not_contains",
    "is_blank": "is_empty",
    "is_not_blank": "is_not_empty",
    "on": "equals",
    "date_equals": "equals",
    "date_between": "between",
    "startswith": "starts_with",
    "endswith": "ends_with",
}

# Canonical operator name -> SQL operator, for operators that use `value`.
_SQL_OPERATORS = {
    "equals": "=",
    "not_equals": "!=",
    "greater_than": ">",
    "less_than": "<",
    "greater_or_equal": ">=",
    "less_or_equal": "<=",
    "contains": "LIKE",
    "not_contains": "NOT LIKE",
    "starts_with": "LIKE",
    "ends_with": "LIKE",
    "in": "IN",
    "not_in": "NOT IN",
    "is_empty": "IS NULL",
    "is_not_empty": "IS NOT NULL",
    "between": "BETWEEN",
    "not_between": "NOT BETWEEN",
    "before": "<",
    "after": ">",
}

# Relative operators that resolve to one calendar day (offset from today).
_RELATIVE_DAY_OFFSETS = {"today": 0, "yesterday": -1, "tomorrow": 1}

# Relative operators that resolve to an inclusive date range: (unit, offset).
_RELATIVE_RANGES = {
    "this_week": ("week", 0),
    "last_week": ("week", -1),
    "next_week": ("week", 1),
    "this_month": ("month", 0),
    "last_month": ("month", -1),
    "next_month": ("month", 1),
    "this_quarter": ("quarter", 0),
    "last_quarter": ("quarter", -1),
    "next_quarter": ("quarter", 1),
    "this_year": ("year", 0),
    "last_year": ("year", -1),
}

# (prefix, suffix) wrapped around the escaped value for LIKE-style operators.
_LIKE_WRAPPERS = {
    "contains": ("%", "%"),
    "not_contains": ("%", "%"),
    "starts_with": ("", "%"),
    "ends_with": ("%", ""),
}

_NUMERIC_TYPES = frozenset({"integer", "decimal", "float", "number", "int", "bigint"})
_BOOLEAN_TYPES = frozenset({"boolean", "bool"})

# Plain decimal literal with optional sign, fraction and exponent.
_NUMERIC_LITERAL = re.compile(r"[+-]?(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)(?:[eE][+-]?[0-9]+)?")


def _sql_text(value):
    """Return `value` as text with single quotes doubled (None becomes NULL)."""
    if value is None:
        return "NULL"
    return str(value).replace("'", "''")


def _sql_number(value):
    """Return `value` as an unquoted numeric literal, or raise ValueError."""
    if isinstance(value, bool):
        return str(int(value))
    text = str(value).strip()
    # The literal is emitted without quotes, so anything that is not strictly
    # a number (including nan/inf or injected SQL) must be rejected here.
    if not _NUMERIC_LITERAL.fullmatch(text):
        raise ValueError(f"Expected a numeric value, got {value!r}")
    return text


def _sql_boolean(value):
    """Return `value` as an unquoted boolean literal, or raise ValueError."""
    if isinstance(value, bool):
        return "1" if value else "0"
    text = str(value).strip()
    if text.lower() not in ("0", "1", "true", "false"):
        raise ValueError(f"Expected a boolean value, got {value!r}")
    return text


def _shift_months(first_of_month, months):
    """Return the first day of the month `months` away from `first_of_month`."""
    index = first_of_month.year * 12 + first_of_month.month - 1 + months
    return date(index // 12, index % 12 + 1, 1)


def _relative_date_range(unit, offset, today):
    """Return the inclusive (start, end) dates of a week/month/quarter/year."""
    if unit == "week":
        # Weeks run Monday..Sunday (date.weekday() is 0 on Monday).
        start = today - timedelta(days=today.weekday()) + timedelta(weeks=offset)
        return start, start + timedelta(days=6)
    if unit == "year":
        year = today.year + offset
        return date(year, 1, 1), date(year, 12, 31)
    span = 3 if unit == "quarter" else 1
    first_month = today.month - (today.month - 1) % span
    start = _shift_months(date(today.year, first_month, 1), span * offset)
    return start, _shift_months(start, span) - timedelta(days=1)


def resolve_operator(op, value, field_type=None):
    """Translate a logical filter operator and value into SQL fragments.

    Args:
        op: operator name or alias (e.g. ``"="``, ``"greater than"``,
            ``"is blank"``, ``"last_month"``). Case and surrounding
            whitespace are ignored; inner spaces are treated as underscores.
        value: operand. A two-item list/tuple for ``between``/``not_between``,
            a list/tuple (or single value) for ``in``/``not_in``; ignored for
            ``is_empty``/``is_not_empty`` and for relative-date operators.
        field_type: optional field type name. Numeric types
            (``integer``, ``decimal``, ``float``, ``number``, ``int``,
            ``bigint``) and boolean types (``boolean``, ``bool``) are emitted
            unquoted after validation; everything else is emitted as a quoted
            string with single quotes doubled.

    Returns:
        tuple[str, str]: ``(sql_operator, sql_value)``. ``sql_value`` is the
        empty string for ``IS NULL`` / ``IS NOT NULL``.

    Raises:
        ValueError: for an unsupported operator, a malformed BETWEEN/IN
            operand, or a value that is not a valid literal for a numeric or
            boolean field.
    """
    op = op.lower().strip().replace(" ", "_")
    op = _OPERATOR_ALIASES.get(op, op)

    # Relative-date operators are computed from today's date and ignore `value`.
    if op in _RELATIVE_DAY_OFFSETS:
        day = date.today() + timedelta(days=_RELATIVE_DAY_OFFSETS[op])
        return "=", f"'{day}'"
    if op in _RELATIVE_RANGES:
        unit, offset = _RELATIVE_RANGES[op]
        start, end = _relative_date_range(unit, offset, date.today())
        return "BETWEEN", f"'{start}' AND '{end}'"

    if op not in _SQL_OPERATORS:
        raise ValueError(f"Unsupported operator: {op}")
    sql_op = _SQL_OPERATORS[op]

    if op in ("is_empty", "is_not_empty"):
        return sql_op, ""

    if op in _LIKE_WRAPPERS:
        prefix, suffix = _LIKE_WRAPPERS[op]
        return sql_op, f"'{prefix}{_sql_text(value)}{suffix}'"

    is_numeric = field_type in _NUMERIC_TYPES
    is_boolean = field_type in _BOOLEAN_TYPES

    if op in ("between", "not_between"):
        if not isinstance(value, (list, tuple)) or len(value) != 2:
            raise ValueError("BETWEEN operator requires array of 2 values")
        if is_numeric:
            low, high = _sql_number(value[0]), _sql_number(value[1])
            return sql_op, f"{low} AND {high}"
        return sql_op, f"'{_sql_text(value[0])}' AND '{_sql_text(value[1])}'"

    if op in ("in", "not_in"):
        items = list(value) if isinstance(value, (list, tuple)) else [value]
        if not items:
            # "IN ()" is not valid SQL.
            raise ValueError("IN operator requires at least one value")
        if is_numeric:
            rendered = [_sql_number(item) for item in items]
        else:
            rendered = [f"'{_sql_text(item)}'" for item in items]
        return sql_op, f"({', '.join(rendered)})"

    # Plain comparison operators.
    if is_numeric:
        return sql_op, _sql_number(value)
    if is_boolean:
        return sql_op, _sql_boolean(value)
    return sql_op, f"'{_sql_text(value)}'"
| # ========================= | |
| # JOIN RESOLUTION (FIXED) | |
| # ========================= | |
def resolve_join_path(start_table, end_table):
    """Look up the join path that leads from `start_table` to `end_table`.

    The join graph is first probed with the key ``"<start>__<end>"``; if no
    such key exists, every path is scanned for matching ``start_table`` and
    ``end_table`` entries.

    Raises:
        ValueError: if no path connects the two tables.
    """
    graph = load_metadata()["joins"]

    # Fast path: the graph is keyed by "<start>__<end>".
    try:
        return graph[f"{start_table}__{end_table}"]
    except KeyError:
        pass

    # Slow path: compare the endpoints recorded on each path.
    found = next(
        (
            candidate
            for candidate in graph.values()
            if candidate["start_table"] == start_table
            and candidate["end_table"] == end_table
        ),
        None,
    )
    if found is None:
        raise ValueError(
            f"No join path found from {start_table} to {end_table}"
        )
    return found
def build_join_sql(base_table, join_path):
    """Render the JOIN clauses for a join path, one clause per line.

    Steps are emitted in ascending ``step`` order (missing numbers count as
    0). The first step always joins against `base_table`; a later step joins
    against the previous step's alias when its ``from_previous_step`` flag is
    truthy, otherwise against `base_table`. Any ``extra_conditions`` on a
    step are ANDed onto its ON clause, qualified by the step's alias.
    """
    ordered = sorted(join_path["steps"], key=lambda entry: entry.get("step", 0))

    clauses = []
    # Pair every step with the one before it; the first step has no predecessor.
    for earlier, current in zip([None] + ordered, ordered):
        alias = current["alias"]
        table = current["table"]
        kind = current["join_type"].upper()

        chained = earlier is not None and current.get("from_previous_step", False)
        left_side = earlier["alias"] if chained else base_table

        predicates = [
            f"{left_side}.{current['base_column']} = {alias}.{current['foreign_column']}"
        ]
        predicates.extend(
            f"{alias}.{extra['column']} {extra['operator']} {extra['value']}"
            for extra in (current.get("extra_conditions") or ())
        )

        clauses.append(f"{kind} JOIN {table} {alias} ON {' AND '.join(predicates)}")

    return "\n".join(clauses)
| # ========================= | |
| # FIELD RESOLUTION | |
| # ========================= | |
# Alternative field names mapped to the field name looked up in fields.json.
FIELD_ALIASES = {
    "join_date": "date_of_joining",
    "joining_date": "date_of_joining",
    "joined": "date_of_joining",
    "hire_date": "date_of_joining",
    "emp_code": "employee_code",
    "emp_name": "full_name",
    "dept": "department",
}


def resolve_field(field_name, module):
    """Return the metadata entry for `field_name` within `module`.

    The name is lower-cased, stripped, has spaces replaced by underscores and
    is then passed through FIELD_ALIASES before being looked up.

    Raises:
        ValueError: if the field is unknown, belongs to a different module,
            or has no ``table``/``column`` mapping.
    """
    catalog = load_metadata()["fields"]

    canonical = field_name.lower().strip().replace(" ", "_")
    canonical = FIELD_ALIASES.get(canonical, canonical)

    if canonical not in catalog:
        raise ValueError(f"Unknown field: {canonical}")
    spec = catalog[canonical]

    if spec["module"] != module:
        raise ValueError(
            f"Field '{canonical}' does not belong to module '{module}'"
        )

    # Both halves of the physical mapping are required to build SQL.
    unmapped = [key for key in ("table", "column") if key not in spec]
    if unmapped:
        raise ValueError(
            f"Field '{canonical}' is missing table/column mapping"
        )

    return spec
| # ========================= | |
| # JSON SAFETY | |
| # ========================= | |
def safe_json_loads(text):
    """Parse JSON from an LLM reply, tolerating markdown fences and stray prose.

    Attempts, in order: the whole text; the object inside a ```/```json
    fence; the span from the first ``{`` to the last ``}``; and finally the
    first object that decodes starting at the first ``{`` (which copes with
    trailing text that itself contains braces).

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: with the message "LLM returned invalid JSON" when none of
            the attempts yields valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    if not isinstance(text, str):
        raise ValueError("LLM returned invalid JSON")

    candidates = []
    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    braced = re.search(r"\{.*\}", text, re.DOTALL)
    if braced:
        candidates.append(braced.group())

    # A failing candidate must not abort the remaining fallbacks.
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    # Last resort: decode one object at the first "{" and ignore what follows.
    # Only the first brace is tried so an inner fragment of a broken outer
    # object is never returned by mistake.
    first_brace = text.find("{")
    if first_brace != -1:
        try:
            parsed, _end = json.JSONDecoder().raw_decode(text, first_brace)
            return parsed
        except json.JSONDecodeError:
            pass

    raise ValueError("LLM returned invalid JSON")
| # ========================= | |
| # INTENT PARSING (LLM) | |
| # ========================= | |
def parse_intent(question, retries=2):
    """Ask the LLM to turn a natural-language question into a JSON query plan.

    The prompt lists every module with up to 20 of its fields, the operators
    the model may use, the expected JSON shape, and the user's question.

    Args:
        question: the user's natural-language question.
        retries: number of attempts; at least one attempt is always made.

    Returns:
        dict: the plan, with ``"module"`` lower-cased and stripped when
        present, and ``"filters"`` / ``"select"`` defaulting to empty lists.

    Raises:
        ValueError: if no attempt produced a usable JSON object.
    """
    meta = load_metadata()

    # One "module: field, field, ..." line per module that has fields.
    schema_lines = []
    for module in meta["modules"]:
        module_fields = [
            name
            for name, spec in meta["fields"].items()
            if spec["module"] == module
        ][:20]  # cap per module to keep the prompt small
        if module_fields:
            schema_lines.append(f"{module}: {', '.join(module_fields)}")
    schema_description = "\n".join(schema_lines)

    # Plain (non f-string) constant so the JSON braces need no escaping.
    plan_shape = (
        '{"module": "<module>", "select": ["<field>", ...], '
        '"filters": [{"field": "<field>", "operator": "<operator>", '
        '"value": <value>}]}'
    )
    operators = (
        "equals, not_equals, greater_than, less_than, greater_or_equal, "
        "less_or_equal, contains, not_contains, starts_with, ends_with, "
        "in, not_in, between, not_between, before, after, is_empty, "
        "is_not_empty, today, yesterday, tomorrow, this_week, this_month, "
        "this_year"
    )
    # The prompt must describe the JSON plan consumed downstream and must
    # actually embed the schema and the question.
    prompt = "\n".join([
        "You translate a user's question into a JSON query plan. Do NOT write SQL.",
        "",
        "CRITICAL RULES (follow strictly):",
        "1. Use ONLY the modules and fields listed in the metadata.",
        "2. If the user asks for a field or concept NOT present in the metadata, IGNORE that part.",
        "3. Do NOT invent modules, fields, or operators.",
        "4. Choose exactly one module.",
        "5. Return ONLY the JSON object. No markdown. No comments. No extra text.",
        "",
        "Metadata (module: fields):",
        schema_description,
        "",
        "Allowed operators:",
        operators,
        "",
        "Output format:",
        plan_shape,
        '- "select" may be empty to return all columns; "filters" may be empty.',
        '- "value" is a list for in / not_in, a two-item list for between / not_between,',
        "  and null for operators that take no value.",
        "",
        "User question:",
        str(question),
    ])

    attempts = max(1, retries)  # never fall through and return None silently
    for attempt in range(attempts):
        try:
            res = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "system",
                        "content": "Return ONLY valid minified JSON. No text. No explanation."
                    },
                    {"role": "user", "content": prompt}
                ],
                temperature=0
            )
            content = (res.choices[0].message.content or "").strip()
            plan = safe_json_loads(content)
            if not isinstance(plan, dict):
                raise ValueError("query plan must be a JSON object")

            # Normalize and stabilize the plan's shape for the SQL builder.
            if "module" in plan:
                if not isinstance(plan["module"], str):
                    raise ValueError("query plan 'module' must be a string")
                plan["module"] = plan["module"].lower().strip()
            plan.setdefault("filters", [])
            plan.setdefault("select", [])
            return plan
        except Exception as e:
            # Boundary around the remote call: retry, then surface one error
            # type with the underlying cause attached.
            if attempt == attempts - 1:
                raise ValueError(f"LLM failed to return valid JSON: {str(e)}") from e
| # ========================= | |
| # SQL GENERATOR (FIXED) | |
| # ========================= | |
def build_sql(plan):
    """Build a SELECT statement from a query plan.

    Args:
        plan: dict with ``"module"`` (required), optional ``"select"`` (list
            of field names) and optional ``"filters"`` (list of dicts with
            ``"field"``, ``"operator"`` and ``"value"``).

    Returns:
        str: the SQL text, always ending in ``LIMIT 100``.

    Raises:
        ValueError: for an unknown module or field, a missing join path, or
            an invalid operator/value (raised by the helpers this calls).
    """
    meta = load_metadata()

    module = plan["module"].lower().strip()
    if module not in meta["modules"]:
        raise ValueError(f"Unknown module: {module}")
    base_table = meta["modules"][module]["base_table"]

    joins = []
    # Maps a table (or join alias) to the name columns must be qualified
    # with. Joined tables are aliased in the JOIN clause, and SQL does not
    # allow the original table name to be referenced once it is aliased.
    qualifiers = {base_table: base_table}

    def qualifier_for(table):
        """Join `table` on first use and return the name to qualify columns with."""
        if table not in qualifiers:
            join_path = resolve_join_path(base_table, table)
            joins.append(build_join_sql(base_table, join_path))
            for step in join_path["steps"]:
                qualifiers.setdefault(step["table"], step["alias"])
                qualifiers.setdefault(step["alias"], step["alias"])
            # Guard against a path whose steps never name `table`, so the
            # same path is not appended again for the next field.
            qualifiers.setdefault(table, table)
        return qualifiers[table]

    # NOTE(review): two different join paths that share an intermediate table
    # still emit that JOIN twice; de-duplicating needs per-step join building.

    # ---------- SELECT ----------
    select_fields = plan.get("select", [])
    if select_fields:
        select_columns = []
        for name in select_fields:
            field = resolve_field(name, module)
            # Selected fields may live on joined tables too, not only filters.
            owner = qualifier_for(field["table"])
            # Same normalization resolve_field applies, so "Full Name" yields
            # a valid single-token alias.
            label = name.lower().strip().replace(" ", "_")
            select_columns.append(f"{owner}.{field['column']} AS {label}")
        select_sql = ", ".join(select_columns)
    else:
        select_sql = f"{base_table}.*"

    # ---------- FILTERS ----------
    where_clauses = []
    for condition in plan.get("filters", []):
        field = resolve_field(condition["field"], module)
        owner = qualifier_for(field["table"])
        column = field["column"]
        sql_op, sql_value = resolve_operator(
            condition["operator"], condition["value"], field.get("type")
        )
        if sql_value:
            where_clauses.append(f"{owner}.{column} {sql_op} {sql_value}")
        else:
            # IS NULL / IS NOT NULL carry no right-hand side.
            where_clauses.append(f"{owner}.{column} {sql_op}")

    # ---------- FINAL SQL ----------
    sql_parts = [f"SELECT {select_sql}", f"FROM {base_table}"]
    sql_parts.extend(joins)
    if where_clauses:
        sql_parts.append(f"WHERE {' AND '.join(where_clauses)}")
    sql_parts.append("LIMIT 100")

    return "\n".join(sql_parts).strip()
| # ========================= | |
| # VALIDATION | |
| # ========================= | |
def validate_sql(sql):
    """Reject anything that is not a plain SELECT statement.

    The check is case-insensitive: the statement must start with ``select``
    (ignoring leading whitespace) and must not contain any data-modifying or
    DDL keyword as a whole word.

    Returns:
        The original `sql`, unchanged, when it passes.

    Raises:
        ValueError: if the statement is not a SELECT or contains a forbidden
            keyword (the first one in the deny-list order is reported).
    """
    lowered = sql.lower()

    if not lowered.strip().startswith("select"):
        raise ValueError("Only SELECT allowed")

    deny_list = ("drop", "delete", "update", "insert", "truncate", "alter", "create")
    offender = next(
        (word for word in deny_list if re.search(rf'\b{word}\b', lowered)),
        None,
    )
    if offender is not None:
        raise ValueError(f"Unsafe SQL: '{offender}' not allowed")

    return sql
| # ========================= | |
| # MAIN ENTRY POINT | |
| # ========================= | |
def run(question):
    """Turn a natural-language question into a validated SQL query.

    Parses the question into a plan, checks the plan has the minimum shape,
    builds the SQL and validates it.

    Returns:
        dict: ``{"query_plan": <plan>, "sql": <validated SQL string>}``.

    Raises:
        ValueError: if the plan is not a dict or has no ``"module"``, or if
            any later stage rejects the plan or the SQL.
    """
    plan = parse_intent(question)

    # Minimum shape the SQL builder needs.
    if not isinstance(plan, dict):
        raise ValueError("Invalid intent format")
    if "module" not in plan:
        raise ValueError("Unable to determine module from question")

    # Optional sections default to empty.
    for section in ("filters", "select"):
        plan.setdefault(section, [])

    checked_sql = validate_sql(build_sql(plan))
    return {"query_plan": plan, "sql": checked_sql}
| # ========================= | |
| # TEST | |
| # ========================= | |
if __name__ == "__main__":
    # Smoke test: run a few sample questions and print the SQL or the error.
    banner = "=" * 80
    sample_questions = (
        "Show all employees",
        "Find departments with more than 50 employees",
        "Show employees in departments 1, 2, 3",
        "List employees who joined this month",
    )
    for sample in sample_questions:
        print(f"\n{banner}")
        print(f"Q: {sample}")
        print(banner)
        try:
            outcome = run(sample)
            print("SQL:", outcome["sql"])
        except Exception as exc:
            # Keep going so every sample is exercised.
            print("ERROR:", exc)