Spaces:
Running
Running
| """Relationship discovery between database tables. | |
| Detects relationships via: | |
| 1. Explicit foreign-key constraints | |
| 2. Matching column names across tables | |
| 3. ID-like suffix patterns (*_id, *_key) | |
| 4. Fuzzy name matching (cust_id → customer_id) | |
| """ | |
| from dataclasses import dataclass | |
| from difflib import SequenceMatcher | |
| from sqlalchemy import text | |
| from db.connection import get_engine | |
| from db.schema import get_schema | |
@dataclass
class Relationship:
    """A link, explicit or inferred, between one column in each of two tables.

    Declared as a dataclass so instances can be built with keyword
    arguments (as every producer in this module does) and get a
    field-based ``__repr__`` and ``__eq__``.
    """

    table_a: str
    column_a: str
    table_b: str
    column_b: str
    confidence: float  # 0.0 to 1.0
    source: str  # "fk", "exact_match", "id_pattern", "fuzzy"
def discover_relationships() -> list[Relationship]:
    """Collect explicit and implicit relationships and de-duplicate them.

    Foreign-key relationships are gathered first, followed by the
    implicitly inferred ones; the combined list is then passed through
    ``_deduplicate``.
    """
    combined: list[Relationship] = [
        *_fk_relationships(),
        *_implicit_relationships(),
    ]
    return _deduplicate(combined)
| # ── Explicit FK relationships ─────────────────────────────────────────────── | |
def _fk_relationships() -> list[Relationship]:
    """Read declared foreign keys of the ``public`` schema.

    Queries ``information_schema`` for FOREIGN KEY constraints and turns
    each result row into a ``Relationship`` whose ``a`` side is the
    source (constrained) table/column and whose ``b`` side is the target
    table/column, with confidence 1.0 and source ``"fk"``.
    """
    # NOTE(review): constraint_column_usage is joined on constraint name and
    # schema only; for multi-column foreign keys this presumably pairs every
    # source column with every target column - confirm against the database.
    fk_sql = text("""
        SELECT
            tc.table_name AS source_table,
            kcu.column_name AS source_column,
            ccu.table_name AS target_table,
            ccu.column_name AS target_column
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = 'public'
    """)
    with get_engine().connect() as conn:
        fk_rows = conn.execute(fk_sql).fetchall()

    # Column order follows the SELECT list above.
    return [
        Relationship(
            table_a=src_table, column_a=src_column,
            table_b=dst_table, column_b=dst_column,
            confidence=1.0, source="fk",
        )
        for src_table, src_column, dst_table, dst_column in fk_rows
    ]
| # ── Implicit relationships ────────────────────────────────────────────────── | |
def _id_like_base(column: str) -> str | None:
    """Return the base of an ID-like column name, or None if it is not ID-like.

    A column is ID-like when it ends with "_id", "_key" or "id". Its base is
    the text before the last underscore, or the whole name when it contains
    no underscore.
    """
    if not column.endswith(("_id", "_key", "id")):
        return None
    return column.rsplit("_", 1)[0] if "_" in column else column


def _implicit_relationships() -> list[Relationship]:
    """Infer relationships between every pair of tables from column names.

    For each unordered pair of tables returned by ``get_schema()`` three
    heuristics are applied, in this order:

    1. identical column names (confidence 0.85, source "exact_match");
    2. two different ID-like columns sharing the same base
       (confidence 0.75, source "id_pattern");
    3. different column names whose ``SequenceMatcher`` ratio is at least
       0.75 (confidence ``round(ratio * 0.8, 2)``, source "fuzzy").

    A column pair may be reported by more than one heuristic; no
    de-duplication happens here.

    Columns are visited in sorted order so that the returned list is the
    same on every run (iterating raw sets of strings depends on hash
    randomisation).
    """
    schema = get_schema()
    tables = list(schema.keys())

    # Per-table data is computed once instead of once per table pair.
    column_sets: dict[str, set[str]] = {}
    sorted_columns: dict[str, list[str]] = {}
    id_like: dict[str, list[tuple[str, str]]] = {}
    for table in tables:
        names = {c["column_name"] for c in schema[table]}
        column_sets[table] = names
        sorted_columns[table] = sorted(names)
        id_like[table] = [
            (name, base)
            for name in sorted_columns[table]
            if (base := _id_like_base(name)) is not None
        ]

    rels: list[Relationship] = []
    for i, t1 in enumerate(tables):
        for t2 in tables[i + 1:]:
            # 1. Exact column-name matches
            for col in sorted(column_sets[t1] & column_sets[t2]):
                rels.append(Relationship(
                    table_a=t1, column_a=col,
                    table_b=t2, column_b=col,
                    confidence=0.85, source="exact_match",
                ))

            # 2. ID-pattern matching: two distinct ID-like columns with the
            #    same base, e.g. "customer_id" <-> "customer_key".
            #    NOTE(review): table names are not consulted, so a bare "id"
            #    in t1 never pairs with "<t1>_id" in t2 - confirm whether
            #    that pairing was intended.
            for c1, base1 in id_like[t1]:
                for c2, base2 in id_like[t2]:
                    if c1 != c2 and base1 == base2:
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            confidence=0.75, source="id_pattern",
                        ))

            # 3. Fuzzy matching over all non-identical column pairs
            for c1 in sorted_columns[t1]:
                for c2 in sorted_columns[t2]:
                    if c1 == c2:
                        continue  # identical names are handled by step 1
                    ratio = SequenceMatcher(None, c1, c2).ratio()
                    if ratio >= 0.75:
                        rels.append(Relationship(
                            table_a=t1, column_a=c1,
                            table_b=t2, column_b=c2,
                            # Scaled so a fuzzy hit always stays below 0.8.
                            confidence=round(ratio * 0.8, 2),
                            source="fuzzy",
                        ))
    return rels
def _deduplicate(rels: list[Relationship]) -> list[Relationship]:
    """Collapse relationships that connect the same two columns.

    The two endpoints are treated as unordered, so A->B and B->A count as
    the same pair. For each pair the entry with the strictly highest
    confidence is kept; on a tie the earliest one wins. Results are
    returned in the order in which each pair was first seen.
    """
    winners: dict[tuple, Relationship] = {}
    for rel in rels:
        end_a = (rel.table_a, rel.column_a)
        end_b = (rel.table_b, rel.column_b)
        # Normalise endpoint order so direction does not affect the key.
        pair = (end_a, end_b) if end_a <= end_b else (end_b, end_a)
        current = winners.get(pair)
        if current is None or rel.confidence > current.confidence:
            winners[pair] = rel
    return list(winners.values())
def format_relationships(rels: list[Relationship] | None = None) -> str:
    """Render relationships as text, one per line, highest confidence first.

    Args:
        rels: Relationships to render. When ``None``, they are obtained
            from ``discover_relationships()``.

    Returns:
        A newline-joined string with one line per relationship, or a fixed
        "nothing found" sentence when there are no relationships.
    """
    if rels is None:
        rels = discover_relationships()
    if not rels:
        return "No explicit or inferred relationships found between tables."

    # Stable descending sort: entries with equal confidence keep input order.
    ranked = sorted(rels, key=lambda rel: rel.confidence, reverse=True)
    return "\n".join(
        f"{rel.table_a}.{rel.column_a} <-> {rel.table_b}.{rel.column_b} "
        f"(confidence: {rel.confidence:.0%}, source: {rel.source})"
        for rel in ranked
    )