sqlbot / db /relationships.py
sqlbot
Initial Hugging Face sqlbot setup
28035e9
"""Relationship discovery between database tables.
Detects relationships via:
1. Explicit foreign-key constraints
2. Matching column names across tables
3. ID-like suffix patterns (*_id, *_key)
4. Fuzzy name matching (cust_id ≈ customer_id)
"""
from dataclasses import dataclass
from difflib import SequenceMatcher
from sqlalchemy import text
from db.connection import get_engine
from db.schema import get_schema
@dataclass
class Relationship:
    """A link between one column of ``table_a`` and one column of ``table_b``.

    Attributes:
        table_a: Name of the first table.
        column_a: Column of ``table_a`` taking part in the link.
        table_b: Name of the second table.
        column_b: Column of ``table_b`` taking part in the link.
        confidence: Score in the range 0.0 – 1.0.
        source: How the link was found: "fk", "exact_match", "id_pattern"
            or "fuzzy".
    """

    table_a: str
    column_a: str
    table_b: str
    column_b: str
    confidence: float
    source: str
def discover_relationships() -> list[Relationship]:
    """Collect explicit and implicit relationships and de-duplicate them.

    Foreign-key relationships are gathered first, then the implicit
    (name-based) ones; the combined list is passed through ``_deduplicate``.
    """
    candidates = [*_fk_relationships(), *_implicit_relationships()]
    return _deduplicate(candidates)
# ── Explicit FK relationships ───────────────────────────────────────────────
def _fk_relationships() -> list[Relationship]:
    """Return the relationships declared by FOREIGN KEY constraints in ``public``.

    Each referencing column is paired with the referenced column that sits at
    the same position in the target key, so a composite foreign key over N
    columns yields exactly N relationships.  Joining
    ``constraint_column_usage`` on the constraint name alone would instead
    pair every source column with every target column (N*N rows).

    Returns:
        One ``Relationship`` per foreign-key column, with ``confidence=1.0``
        and ``source="fk"``.  ``table_a``/``column_a`` is the referencing
        side, ``table_b``/``column_b`` the referenced side.
    """
    # NOTE(review): a foreign key that references a bare unique index (no
    # named PRIMARY KEY / UNIQUE constraint) has a NULL unique_constraint_name
    # in referential_constraints and is therefore not reported here - confirm
    # that this is acceptable for the target databases.
    query = text("""
        SELECT
            kcu.table_name  AS source_table,
            kcu.column_name AS source_column,
            ref.table_name  AS target_table,
            ref.column_name AS target_column
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON kcu.constraint_name = tc.constraint_name
         AND kcu.constraint_schema = tc.constraint_schema
         AND kcu.table_schema = tc.table_schema
         AND kcu.table_name = tc.table_name
        JOIN information_schema.referential_constraints rc
          ON rc.constraint_name = tc.constraint_name
         AND rc.constraint_schema = tc.constraint_schema
        JOIN information_schema.key_column_usage ref
          ON ref.constraint_name = rc.unique_constraint_name
         AND ref.constraint_schema = rc.unique_constraint_schema
         AND ref.ordinal_position = kcu.position_in_unique_constraint
        WHERE tc.constraint_type = 'FOREIGN KEY'
          AND tc.table_schema = 'public'
          AND ref.table_schema = tc.table_schema
        ORDER BY kcu.table_name, tc.constraint_name, kcu.ordinal_position
    """)
    with get_engine().connect() as conn:
        # The list is fully built before the connection is released.
        return [
            Relationship(
                table_a=source_table, column_a=source_column,
                table_b=target_table, column_b=target_column,
                confidence=1.0, source="fk",
            )
            for source_table, source_column, target_table, target_column
            in conn.execute(query)
        ]
# ── Implicit relationships ──────────────────────────────────────────────────
def _is_id_like(column: str) -> bool:
    """Return True for a bare ``id`` column or one ending in ``_id``/``_key``."""
    # A plain ``endswith("id")`` would also accept names such as ``is_paid``
    # or ``is_valid``, which are not identifiers.
    return column == "id" or column.endswith(("_id", "_key"))


def _id_base(column: str) -> str:
    """Drop the last ``_suffix`` of *column* (``customer_id`` -> ``customer``)."""
    return column.rsplit("_", 1)[0]


def _table_stems(table: str) -> set[str]:
    """Return *table* lower-cased plus naive singular forms of it."""
    name = table.lower()
    stems = {name}
    if name.endswith("ies"):
        stems.add(name[:-3] + "y")   # categories -> category
    if name.endswith("es"):
        stems.add(name[:-2])         # addresses -> address
    if name.endswith("s"):
        stems.add(name[:-1])         # customers -> customer
    stems.discard("")
    return stems


def _references_table(column: str, table: str) -> bool:
    """Return True if *column* looks like ``<table>_id`` or ``<table>_key``."""
    return (
        column != "id"
        and _is_id_like(column)
        and _id_base(column).lower() in _table_stems(table)
    )


def _implicit_relationships() -> list[Relationship]:
    """Infer relationships between every pair of tables from column names.

    The schema returned by ``get_schema()`` is read as a mapping of table
    name to an iterable of dicts carrying a ``"column_name"`` key.  For each
    unordered pair of tables three heuristics are applied, in this order:

    1. Identical column names                -> 0.85, ``"exact_match"``.
    2. ID-like columns (bare ``id``, ``*_id``, ``*_key``) that either share
       the same base name, or where a bare ``id`` on one table faces a
       ``<table>_id`` / ``<table>_key`` column on the other
                                             -> 0.75, ``"id_pattern"``.
    3. ``SequenceMatcher`` ratio >= 0.75 between differing names
                                             -> ``ratio * 0.8``, ``"fuzzy"``.

    A column pair may be reported by more than one heuristic; callers are
    expected to de-duplicate.  Columns are visited in schema order so the
    result is stable from run to run.
    """
    schema = get_schema()
    # Ordered, duplicate-free column names per table, computed once.
    columns = {
        table: list(dict.fromkeys(col["column_name"] for col in cols))
        for table, cols in schema.items()
    }
    tables = list(columns)
    rels: list[Relationship] = []

    for i, t1 in enumerate(tables):
        cols1 = columns[t1]
        ids1 = [c for c in cols1 if _is_id_like(c)]
        for t2 in tables[i + 1:]:
            cols2 = columns[t2]
            names2 = set(cols2)

            # 1. Exact column-name matches
            for col in cols1:
                if col in names2:
                    rels.append(Relationship(
                        table_a=t1, column_a=col,
                        table_b=t2, column_b=col,
                        confidence=0.85, source="exact_match",
                    ))

            # 2. ID-pattern matching (e.g. "id" in t1 ↔ "t1_id" in t2)
            ids2 = [c for c in cols2 if _is_id_like(c)]
            for c1 in ids1:
                for c2 in ids2:
                    if c1 == c2:
                        continue  # already reported as an exact match
                    if (
                        _id_base(c1) == _id_base(c2)
                        or (c1 == "id" and _references_table(c2, t1))
                        or (c2 == "id" and _references_table(c1, t2))
                    ):
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=0.75, source="id_pattern",
                        ))

            # 3. Fuzzy matching over all differing column pairs
            for c1 in cols1:
                for c2 in cols2:
                    if c1 == c2:
                        continue
                    ratio = SequenceMatcher(None, c1, c2).ratio()
                    if ratio >= 0.75:
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            # Scaled down so a fuzzy hit never reaches 0.85.
                            confidence=round(ratio * 0.8, 2),
                            source="fuzzy",
                        ))
    return rels
def _deduplicate(rels: list[Relationship]) -> list[Relationship]:
    """Collapse duplicates, keeping the strongest entry per column pair.

    Two relationships are duplicates when they join the same two
    ``(table, column)`` endpoints, in either direction.  On equal confidence
    the entry seen first is kept.  The result follows the order in which each
    column pair was first encountered.
    """
    strongest: dict[tuple, Relationship] = {}
    for rel in rels:
        end_a = (rel.table_a, rel.column_a)
        end_b = (rel.table_b, rel.column_b)
        # Order the endpoints so that A<->B and B<->A share one key.
        pair = (end_a, end_b) if end_a <= end_b else (end_b, end_a)
        current = strongest.get(pair)
        if current is None or rel.confidence > current.confidence:
            strongest[pair] = rel
    return list(strongest.values())
def format_relationships(rels: list[Relationship] | None = None) -> str:
    """Render relationships as text, one per line, for prompt injection.

    Args:
        rels: Relationships to render.  When ``None``,
            ``discover_relationships()`` is called to obtain them.

    Returns:
        A fixed "nothing found" sentence when there are no relationships,
        otherwise one line per relationship ordered by descending confidence.
    """
    found = discover_relationships() if rels is None else rels
    if not found:
        return "No explicit or inferred relationships found between tables."

    ranked = sorted(found, key=lambda rel: -rel.confidence)
    return "\n".join(
        f"{rel.table_a}.{rel.column_a} <-> {rel.table_b}.{rel.column_b} "
        f"(confidence: {rel.confidence:.0%}, source: {rel.source})"
        for rel in ranked
    )