| """ChromaDB tools for NYC code lookup — with re-ranking, budget tracking, and caching."""
|
| from __future__ import annotations
|
|
|
| import hashlib
|
| from collections import Counter
|
|
|
| import chromadb
|
| from chromadb.utils import embedding_functions
|
|
|
| from config import (
|
| CHROMA_COLLECTION_NAME,
|
| CHROMA_DB_PATH,
|
| DISCOVER_N_RESULTS,
|
| EMBEDDING_MODEL_NAME,
|
| FETCH_MAX_SECTIONS,
|
| RERANK_TOP_K,
|
| )
|
|
|
|
|
|
|
|
|
|
|
|
|
# Singleton ChromaDB collection handle; populated lazily by get_collection().
_collection = None

# True once warmup_collection() has successfully loaded the collection;
# read back via is_warmed_up().
_warmup_done = False
|
|
|
|
|
def warmup_collection() -> bool:
    """Eagerly connect to ChromaDB and load the embedding model.

    Intended for app startup: the heavy model download/load then happens
    up front (where a progress spinner can be shown) instead of stalling
    the first user query.

    Returns True when the collection is available, False on any failure.
    """
    global _warmup_done
    try:
        get_collection()
    except Exception:
        # Any failure (missing DB, model download error, ...) leaves us cold.
        _warmup_done = False
        return False
    _warmup_done = True
    return True
|
|
|
|
|
def is_warmed_up() -> bool:
    """Report whether warmup_collection() has completed successfully."""
    return bool(_warmup_done)
|
|
|
|
|
def get_collection():
    """Return the shared ChromaDB collection, creating it on first use.

    Singleton: the persistent client and embedding function are built once
    and cached in the module-level `_collection`.
    """
    global _collection
    if _collection is not None:
        return _collection

    client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
    embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
    )
    _collection = client.get_collection(
        name=CHROMA_COLLECTION_NAME,
        embedding_function=embedder,
    )
    return _collection
|
|
|
|
|
|
|
|
|
|
|
|
|
class QueryCache:
    """Simple cache to avoid re-querying semantically identical topics.

    Two queries that contain the same set of words (case-insensitive,
    order-insensitive) map to the same cache entry.
    """

    def __init__(self):
        # Normalized query string -> previously computed result.
        self._cache: dict[str, str] = {}

    def _normalize(self, query: str) -> str:
        # Canonical key: unique lowercased words, sorted alphabetically.
        return " ".join(sorted({word for word in query.lower().split()}))

    def get(self, query: str) -> str | None:
        """Return the cached result for *query*, or None on a miss."""
        return self._cache.get(self._normalize(query))

    def put(self, query: str, result: str) -> None:
        """Store *result* under the normalized form of *query*."""
        self._cache[self._normalize(query)] = result
|
|
|
|
|
|
|
|
|
|
|
|
|
def discover_code_locations(query: str, cache: QueryCache | None = None) -> str:
    """Semantic search over NYC codes with hierarchical re-ranking.

    Parameters
    ----------
    query : str
        Natural-language search text.
    cache : QueryCache, optional
        If provided, a hit returns the stored report prefixed with
        "[CACHED RESULT]"; a miss stores the freshly built report.

    Returns
    -------
    str
        A formatted report of the most relevant code sections, or a
        "no results" message.
    """
    if cache is not None:
        cached = cache.get(query)
        if cached is not None:
            return f"[CACHED RESULT]\n{cached}"

    collection = get_collection()
    results = collection.query(
        query_texts=[query],
        n_results=DISCOVER_N_RESULTS,
        include=["metadatas", "documents", "distances"],
    )

    if not results["metadatas"][0]:
        return "No results found. Try a different query phrasing."

    metas = results["metadatas"][0]
    docs = results["documents"][0]
    distances = results["distances"][0]

    # Re-rank: lower embedding distance is better (negate so higher = better),
    # boost shallow/general sections (fewer dots in the section ID), and boost
    # sections flagged as containing exceptions.
    # NOTE: all metadata reads use .get with a default so a malformed record
    # cannot raise KeyError (the original mixed .get and bare indexing).
    ranked = []
    for meta, doc, dist in zip(metas, docs, distances):
        score = -dist
        depth = meta.get("section_full", "").count(".")
        score += max(0, 3 - depth) * 0.05
        if meta.get("has_exceptions", False):
            score += 0.1
        ranked.append((score, meta, doc))

    ranked.sort(key=lambda x: x[0], reverse=True)
    top_results = ranked[:RERANK_TOP_K]

    # Summarize which (code, chapter) pairs dominate the top hits.
    counts = Counter(
        f"{m.get('code_type', '?')} | Ch. {m.get('parent_major', '?')}"
        for _, m, _ in top_results
    )
    chapter_summary = "\n".join(
        f"- {pair} ({count} hits)" for pair, count in counts.most_common(5)
    )

    section_reports = []
    for _score, m, doc in top_results:
        exceptions_tag = " [HAS EXCEPTIONS]" if m.get("has_exceptions", False) else ""
        xrefs = m.get("cross_references", "")
        xref_tag = f"\n Cross-refs: {xrefs}" if xrefs else ""

        report = (
            f"ID: {m.get('section_full', '?')} | Code: {m.get('code_type', '?')} "
            f"| Chapter: {m.get('parent_major', '?')}"
            f"{exceptions_tag}{xref_tag}\n"
            f"Snippet: {doc[:500]}"
        )
        section_reports.append(report)

    output = (
        "### CODE DISCOVERY REPORT ###\n"
        f"MOST RELEVANT CHAPTERS:\n{chapter_summary}\n\n"
        "TOP RELEVANT SECTIONS:\n"
        + "\n---\n".join(section_reports)
    )

    if cache is not None:
        cache.put(query, output)

    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_full_chapter(
    code_type: str,
    chapter_id: str,
    section_filter: str | None = None,
) -> str:
    """Retrieve sections from a specific chapter, with optional keyword filtering.

    Parameters
    ----------
    code_type : str
        One of: Administrative, Building, FuelGas, Mechanical, Plumbing
    chapter_id : str
        The parent_major chapter ID (e.g., "10", "602")
    section_filter : str, optional
        If provided, only return sections containing this keyword

    Returns
    -------
    str
        Formatted chapter text, or a human-readable "not found" /
        error message (retrieval failures are reported as text, not raised).
    """
    collection = get_collection()

    try:
        chapter_data = collection.get(
            where={
                "$and": [
                    {"code_type": {"$eq": code_type}},
                    {"parent_major": {"$eq": chapter_id}},
                ]
            },
            include=["documents", "metadatas"],
        )

        if not chapter_data["documents"]:
            return f"No documentation found for {code_type} Chapter {chapter_id}."

        pairs = list(zip(chapter_data["metadatas"], chapter_data["documents"]))

        # Optional case-insensitive keyword filter over document text.
        if section_filter:
            filter_lower = section_filter.lower()
            pairs = [(m, d) for m, d in pairs if filter_lower in d.lower()]
            if not pairs:
                return (
                    f"No sections in {code_type} Chapter {chapter_id} "
                    f"match filter '{section_filter}'."
                )

        pairs.sort(key=lambda x: x[0]["section_full"])
        total_sections = len(pairs)
        pairs = pairs[:FETCH_MAX_SECTIONS]

        header = f"## {code_type.upper()} CODE - CHAPTER {chapter_id}"
        if total_sections > FETCH_MAX_SECTIONS:
            header += f" (showing {FETCH_MAX_SECTIONS} of {total_sections} sections)"
        if section_filter:
            header += f" [filtered by: '{section_filter}']"
        header += "\n\n"

        # Accumulate parts and join once at the end instead of repeated
        # string concatenation (which is quadratic on large chapters).
        parts = [header]
        for meta, doc in pairs:
            # Continuation chunks ("[CONT.]:") can repeat text; de-duplicate
            # on the cleaned block content itself, preserving first-seen order.
            # (Direct string equality replaces the original md5 comparison —
            # same result, no hashing round-trip.)
            stripped = (b.strip() for b in doc.split("[CONT.]:"))
            clean_doc = " ".join(dict.fromkeys(b for b in stripped if b))

            exceptions_tag = ""
            if meta.get("has_exceptions", False):
                exceptions_tag = f" [CONTAINS {meta.get('exception_count', '?')} EXCEPTION(S)]"

            parts.append(
                f"### SECTION {meta['section_full']}{exceptions_tag}\n"
                f"{clean_doc}\n\n---\n\n"
            )

        return "".join(parts)

    except Exception as e:
        # Boundary behavior: surface retrieval problems to the caller as
        # text (these strings are consumed directly by the agent/user).
        return f"Error retrieving chapter content: {e!s}"
|
|
|