# Source: text2sql-demo/src/schema_linker.py (commit dc59b01, "Add src folder")
"""
Simple schema linking for Spider-style Text-to-SQL.

Goal:
- Given (question, db_id), select a small set of relevant tables/columns
  to include in the prompt (RAG-style schema retrieval).

Design constraints:
- Pure Python (no heavy external deps).
- Robust to missing/odd schemas: never crash.
"""
from __future__ import annotations
import json
import os
import re
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
# Maximal runs of ASCII letters/digits; every other character acts as a separator.
_ALNUM_RE = re.compile(r"[A-Za-z0-9]+")
# A lowercase letter immediately followed by an uppercase one (camelCase boundary).
_CAMEL_RE = re.compile(r"([a-z])([A-Z])")


def _normalize_identifier(text: str) -> str:
    """
    Normalize a schema identifier (or free text) for matching.

    Steps, in order:
    - coerce falsy input (``None``, ``""``) to the empty string
    - replace underscores with spaces
    - insert a space at every lowercase->uppercase boundary (camelCase / PascalCase)
    - lowercase the result

    Returns the normalized string; punctuation is left in place.
    """
    text = str(text or "")
    text = text.replace("_", " ")
    # The split must happen before lowercasing, otherwise the case boundary is lost.
    text = _CAMEL_RE.sub(r"\1 \2", text)
    return text.lower()


def _tokenize(text: str) -> List[str]:
    """
    Split text into lowercase ASCII alphanumeric tokens.

    The text is first passed through `_normalize_identifier`, so
    ``"AlbumTitle"`` and ``"album_title"`` both yield ``["album", "title"]``.
    Returns an empty list for empty/None input.
    """
    return _ALNUM_RE.findall(_normalize_identifier(text))
@dataclass(frozen=True)
class TableSchema:
    """Immutable description of one table: its name and its column names."""

    # Table name as it appears in the schema source.
    table_name: str
    # Column names belonging to this table, in source order.
    columns: Tuple[str, ...]
class SchemaLinker:
    """
    Loads Spider `tables.json` and (optionally) SQLite schemas from disk.
    Provides a lightweight table scoring function based on token overlap.
    """

    def __init__(self, tables_json_path: str, db_root: Optional[str] = None):
        """
        Args:
            tables_json_path: Path to a Spider-style `tables.json` file.
            db_root: Optional directory laid out as `<db_root>/<db_id>/<db_id>.sqlite`.
                When omitted, no SQLite schema is ever read.

        Raises:
            OSError: if `tables_json_path` cannot be opened.
            ValueError: if the file is not valid JSON.
        """
        self.tables_json_path = tables_json_path
        self.db_root = db_root
        self._tables_by_db: Dict[str, List[TableSchema]] = {}
        self._sqlite_schema_cache: Dict[str, Dict[str, List[str]]] = {}
        self._load_tables_json()

    def _load_tables_json(self) -> None:
        """
        Parse `tables.json` into `self._tables_by_db` (db_id -> list of TableSchema).

        Malformed entries (non-dict, or without `db_id`) and malformed column
        records are skipped instead of raising.
        """
        with open(self.tables_json_path, encoding="utf-8") as f:
            entries = json.load(f)

        tables_by_db: Dict[str, List[TableSchema]] = {}
        for entry in entries:
            if not isinstance(entry, dict) or "db_id" not in entry:
                continue
            db_id = entry["db_id"]
            table_names: List[str] = entry.get("table_names_original") or entry.get("table_names") or []
            col_names: List[Sequence] = entry.get("column_names_original") or entry.get("column_names") or []

            columns_by_table: List[List[str]] = [[] for _ in table_names]
            for col in col_names:
                # Spider format: [table_idx, col_name]
                if not isinstance(col, (list, tuple)) or len(col) < 2:
                    continue
                table_idx, col_name = col[0], col[1]
                # A negative index marks the special "*" column; out-of-range
                # or non-integer indices are ignored as well.
                if not isinstance(table_idx, int) or not 0 <= table_idx < len(columns_by_table):
                    continue
                columns_by_table[table_idx].append(str(col_name))

            tables_by_db[db_id] = [
                TableSchema(table_name=str(tname), columns=tuple(cols))
                for tname, cols in zip(table_names, columns_by_table)
            ]
        self._tables_by_db = tables_by_db

    def _db_path(self, db_id: str) -> Optional[str]:
        """Return `<db_root>/<db_id>/<db_id>.sqlite` if it exists, else None."""
        if not self.db_root:
            return None
        path = os.path.join(self.db_root, db_id, f"{db_id}.sqlite")
        return path if os.path.exists(path) else None

    def _load_sqlite_schema(self, db_id: str) -> Dict[str, List[str]]:
        """
        Load actual SQLite schema (table -> columns). Cached per db_id.

        Returns an empty dict when there is no database file or it cannot be
        read. Internal `sqlite_*` tables are excluded.
        """
        cached = self._sqlite_schema_cache.get(db_id)
        if cached is not None:
            return cached

        schema: Dict[str, List[str]] = {}
        db_path = self._db_path(db_id)
        if db_path:
            try:
                with closing(sqlite3.connect(db_path)) as conn:
                    cursor = conn.cursor()
                    rows = cursor.execute(
                        "SELECT name FROM sqlite_master WHERE type='table';"
                    ).fetchall()
                    for (raw_name,) in rows:
                        name = str(raw_name)
                        # Names starting with "sqlite_" are SQLite bookkeeping
                        # tables (e.g. sqlite_sequence), not user schema.
                        if name.lower().startswith("sqlite_"):
                            continue
                        # Quote the identifier so names containing spaces,
                        # quotes, or reserved words do not break the PRAGMA.
                        quoted = '"' + name.replace('"', '""') + '"'
                        columns = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
                        # Index 1 of each table_info row is the column name.
                        schema[name] = [str(col[1]) for col in columns]
            except Exception:
                # Deliberate best-effort: an unreadable/corrupt database is
                # treated as "no SQLite schema" rather than an error.
                schema = {}

        self._sqlite_schema_cache[db_id] = schema
        return schema

    def get_schema(self, db_id: str) -> List[TableSchema]:
        """
        Returns a list of table schemas for this db.
        Prefers `tables.json` (Spider canonical), but can fallback to SQLite if needed.

        The returned list is a fresh copy; mutating it does not affect the linker.
        """
        tables = self._tables_by_db.get(db_id)
        if tables:
            return list(tables)
        sqlite_schema = self._load_sqlite_schema(db_id)
        return [TableSchema(table_name=t, columns=tuple(cols)) for t, cols in sqlite_schema.items()]

    def score_tables(self, question: str, db_id: str) -> List[Tuple[float, TableSchema]]:
        """
        Score each table using token overlap:
        - table token overlap (higher weight)
        - column token overlap (lower weight)

        Returns (score, table) pairs sorted by score descending; ties are
        ordered by table name descending.
        """
        q_tokens = set(_tokenize(question))
        # Loop-invariant: normalized question text used for the substring boost.
        q_text = _normalize_identifier(question)

        scored: List[Tuple[float, TableSchema]] = []
        for t in self.get_schema(db_id):
            table_tokens = set(_tokenize(t.table_name))
            col_tokens = {tok for c in t.columns for tok in _tokenize(c)}
            table_overlap = len(q_tokens & table_tokens)
            col_overlap = len(q_tokens & col_tokens)
            # Simple weighted overlap (tuned to bias table matches).
            score = 3.0 * table_overlap + 1.0 * col_overlap
            # Small boost for substring mentions (helps e.g. "album" vs "albums").
            if t.table_name and _normalize_identifier(t.table_name) in q_text:
                score += 0.5
            scored.append((score, t))

        scored.sort(key=lambda x: (x[0], x[1].table_name), reverse=True)
        return scored

    def select_top_tables(self, question: str, db_id: str, top_k: int = 4) -> List[TableSchema]:
        """
        Return up to `top_k` (at least 1) best-scoring tables for the question.

        Returns an empty list when the database has no known tables. When no
        table scores above zero, the first `top_k` tables in schema order are
        returned instead.
        """
        scored = self.score_tables(question, db_id)
        if not scored:
            return []
        top_k = max(1, int(top_k))
        # If everything scores 0, still return a stable selection.
        if scored[0][0] <= 0:
            return self.get_schema(db_id)[:top_k]
        return [t for _, t in scored[:top_k]]

    def columns_for_selected_tables(self, db_id: str, selected_tables: Iterable[TableSchema]) -> Dict[str, List[str]]:
        """
        Returns only columns belonging to selected tables.
        Prefer SQLite columns (actual DB) if available; fallback to tables.json.

        SQLite table names are matched exactly first, then case-insensitively.
        Every returned list is a fresh copy, so callers may mutate it freely.
        """
        sqlite_schema = self._load_sqlite_schema(db_id)
        by_lower = {name.lower(): cols for name, cols in sqlite_schema.items()}

        out: Dict[str, List[str]] = {}
        for t in selected_tables:
            cols = sqlite_schema.get(t.table_name) or by_lower.get(str(t.table_name).lower())
            out[t.table_name] = list(cols) if cols else list(t.columns)
        return out

    def format_relevant_schema(self, question: str, db_id: str, top_k: int = 4) -> Tuple[List[str], Dict[str, List[str]]]:
        """
        Returns:
        - lines: ["table(col1, col2)", ...]
        - selected: {table: [cols...], ...}
        """
        selected_tables = self.select_top_tables(question, db_id, top_k=top_k)
        selected = self.columns_for_selected_tables(db_id, selected_tables)
        lines = [f"{table_name}({', '.join(cols)})" for table_name, cols in selected.items()]
        return lines, selected