# raredx/backend/scripts/hello_world.py
# Uploaded by Aswin92 (folder upload using huggingface_hub).
# Commit 89c6379 (verified).
"""
hello_world.py
--------------
Week 1 Milestone: query the graph store and ChromaDB simultaneously.

Primary:  Neo4j (Docker) + ChromaDB HTTP (Docker)
Fallback: LocalGraphStore (JSON) + ChromaDB PersistentClient (embedded)

Demonstrates the RareDx core pattern:
    1. Retrieve structured disease data from the graph store (Neo4j / JSON)
    2. Find semantically related diseases in ChromaDB (BioLORD-2023 vectors)
    3. Merge and display results

Usage:
    python hello_world.py [disease_name]
    python hello_world.py "Marfan syndrome"
"""
import os
import sys
import io
import time
# Force UTF-8 output on Windows terminals.
# The streams are reconfigured in place rather than re-wrapped around
# ``.buffer``: reconfigure() keeps every setting that is not passed (notably
# line buffering on a terminal), and the getattr guard skips streams that are
# None or that were replaced by an object without reconfigure().
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, "reconfigure", None)
    if _reconfigure is not None:
        _reconfigure(encoding="utf-8", errors="replace")
import concurrent.futures
from pathlib import Path
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
# Load variables from the ".env" file located at Path(__file__).parents[2]
# before any os.getenv() call below reads them.
load_dotenv(Path(__file__).parents[2] / ".env")
# Neo4j connection settings; each falls back to the literal default shown.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
# NOTE(review): a password default is hard-coded here — presumably a local-dev
# value; confirm it is never relied on outside development.
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "raredx_password")
# Host and port used for the ChromaDB HTTP client (port is parsed as an int).
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
# Name of the ChromaDB collection that is queried.
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
# Model identifier passed to SentenceTransformer().
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
# Directory handed to chromadb.PersistentClient when the HTTP client fails.
CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"
# Disease to look up: first command-line argument, else "Marfan syndrome".
QUERY_DISEASE = sys.argv[1] if len(sys.argv) > 1 else "Marfan syndrome"
# ---------------------------------------------------------------------------
# Graph store queries (Neo4j primary, LocalGraphStore fallback)
# ---------------------------------------------------------------------------
def fetch_from_graph(disease_name: str) -> tuple[dict | None, str]:
"""Returns (disease_dict or None, backend_label)."""
# Try Neo4j first
try:
from neo4j import GraphDatabase
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
driver.verify_connectivity()
with driver.session() as session:
result = session.run(
"""
MATCH (d:Disease)
WHERE toLower(d.name) CONTAINS toLower($name)
OPTIONAL MATCH (d)-[:HAS_SYNONYM]->(s:Synonym)
RETURN
d.orpha_code AS orpha_code,
d.name AS name,
d.definition AS definition,
d.expert_link AS expert_link,
collect(s.text) AS synonyms
LIMIT 1
""",
name=disease_name,
)
record = result.single()
driver.close()
if record:
return dict(record), "Neo4j (Docker)"
return None, "Neo4j (Docker)"
except Exception:
pass # fall through to local store
# LocalGraphStore fallback
try:
from graph_store import LocalGraphStore
store = LocalGraphStore()
disease = store.find_disease_by_name(disease_name)
return disease, "LocalGraphStore (JSON)"
except Exception as exc:
print(f" Graph store error: {exc}")
return None, "unavailable"
# ---------------------------------------------------------------------------
# ChromaDB semantic search (HTTP primary, embedded fallback)
# ---------------------------------------------------------------------------
def get_chroma_client() -> chromadb.ClientAPI:
    """Return a ChromaDB client, preferring the HTTP server.

    An HTTP client for CHROMA_HOST:CHROMA_PORT is created and probed with
    ``heartbeat()``. If creating or probing it raises anything, an embedded
    ``PersistentClient`` rooted at CHROMA_PERSIST_DIR is returned instead.
    Telemetry is disabled for both clients.
    """
    try:
        # A fresh Settings object is built per client rather than shared.
        remote = chromadb.HttpClient(
            settings=Settings(anonymized_telemetry=False),
            host=CHROMA_HOST,
            port=CHROMA_PORT,
        )
        # Constructing the client is not treated as proof the server is up;
        # the heartbeat call is the actual reachability check.
        remote.heartbeat()
    except Exception:
        embedded_path = str(CHROMA_PERSIST_DIR)
        return chromadb.PersistentClient(
            path=embedded_path,
            settings=Settings(anonymized_telemetry=False),
        )
    else:
        return remote
def fetch_from_chromadb(
    query_text: str,
    model: SentenceTransformer,
    n_results: int = 5,
) -> tuple[list[dict], str]:
    """Return the nearest neighbours of ``query_text`` in the Chroma collection.

    Args:
        query_text: Text that is embedded and used as the query.
        model: Encoder whose ``encode`` output is sent as the query embedding
            (called with ``normalize_embeddings=True``).
        n_results: Number of neighbours requested from the collection.

    Returns:
        ``(hits, backend_label)``. Each hit is a dict with the keys
        ``orpha_code``, ``name``, ``definition``, ``synonyms`` (all read from
        the stored metadata) and ``cosine_similarity`` (``1 - distance``,
        rounded to 4 decimals).
    """
    chroma = get_chroma_client()
    # NOTE(review): the label is derived from a private attribute probe;
    # confirm "_api" really distinguishes the HTTP client from the embedded
    # one in the installed chromadb version.
    if hasattr(chroma, "_api"):
        backend = "ChromaDB HTTP"
    else:
        backend = "ChromaDB Embedded"

    collection = chroma.get_collection(COLLECTION_NAME)
    query_vectors = model.encode([query_text], normalize_embeddings=True)
    response = collection.query(
        query_embeddings=query_vectors.tolist(),
        n_results=n_results,
        include=["documents", "metadatas", "distances"],
    )

    # Only one query was sent, so only the first result row is read.
    # NOTE(review): "1 - distance" is a cosine similarity only if the
    # collection was created with the cosine distance space — verify.
    neighbours = [
        {
            "orpha_code": metadata.get("orpha_code"),
            "name": metadata.get("name"),
            "definition": metadata.get("definition", ""),
            "synonyms": metadata.get("synonyms", ""),
            "cosine_similarity": round(1 - distance, 4),
        }
        for metadata, distance in zip(
            response["metadatas"][0], response["distances"][0]
        )
    ]
    return neighbours, backend
# ---------------------------------------------------------------------------
# Display
# ---------------------------------------------------------------------------
# ANSI SGR escape sequences used to style terminal output.
BOLD = "\033[1m"  # bold / increased intensity
CYAN = "\033[96m"  # bright cyan foreground
GREEN = "\033[92m"  # bright green foreground
YELLOW= "\033[93m"  # bright yellow foreground
DIM = "\033[2m"  # faint / decreased intensity
RESET = "\033[0m"  # clear all styling
# Horizontal rule (62 dashes) printed under section headings.
LINE = "-" * 62
def _wrap(text: str, width: int = 72, indent: str = " ") -> str:
words = text.split()
lines, cur = [], []
for w in words:
cur.append(w)
if len(" ".join(cur)) > width:
lines.append(indent + " ".join(cur[:-1]))
cur = [w]
if cur:
lines.append(indent + " ".join(cur))
return "\n".join(lines)
def print_graph_result(disease: dict | None, backend: str) -> None:
    """Print the graph-store section of the report.

    Args:
        disease: Record with ``orpha_code`` and ``name`` keys and optional
            ``synonyms`` (iterable of strings), ``definition`` and
            ``expert_link`` entries; None prints a "no match" notice.
        backend: Label of the backend that answered, shown in the heading.
    """
    heading = f"\n{BOLD}{CYAN}[ Graph Store — {backend} ]{RESET}"
    print(heading)
    print(LINE)

    if disease is None:
        print(f" {YELLOW}No match found.{RESET}")
        return

    # Mandatory fields: a missing key raises KeyError.
    print(f" {BOLD}OrphaCode :{RESET} ORPHA:{disease['orpha_code']}")
    print(f" {BOLD}Name :{RESET} {disease['name']}")

    # Optional fields are shown only when present and non-empty.
    synonyms = disease.get("synonyms")
    if synonyms:
        joined = ", ".join(synonyms)
        print(f" {BOLD}Synonyms :{RESET} {joined}")

    definition = disease.get("definition")
    if definition:
        print(f" {BOLD}Definition :{RESET}")
        print(_wrap(definition))

    link = disease.get("expert_link")
    if link:
        print(f" {BOLD}OrphaNet :{RESET} {DIM}{link}{RESET}")
def print_chroma_results(hits: list[dict], backend: str) -> None:
    """Print the ChromaDB section of the report as a ranked list.

    Each hit is shown with a 20-cell bar proportional to its similarity, the
    numeric similarity, its ORPHA code and name, and its synonyms if any.

    Args:
        hits: Dicts with ``cosine_similarity``, ``orpha_code`` and ``name``
            keys and an optional ``synonyms`` entry; an empty list prints a
            "no results" notice.
        backend: Label of the ChromaDB backend, shown in the heading.
    """
    print(f"\n{BOLD}{GREEN}[ ChromaDB — BioLORD-2023 Semantic Neighbours | {backend} ]{RESET}")
    print(LINE)
    if not hits:
        print(f" {YELLOW}No results.{RESET}")
        return

    bar_width = 20
    for rank, hit in enumerate(hits, 1):
        sim = hit["cosine_similarity"]
        # Clamp to the bar's range: a similarity below 0 or above 1 would
        # otherwise make one of the two segments negative-length and the
        # other longer than the bar, breaking the column alignment.
        bar_len = max(0, min(bar_width, int(sim * bar_width)))
        bar = "█" * bar_len + "░" * (bar_width - bar_len)
        print(f" {rank}. [{bar}] {sim:.4f} ORPHA:{hit['orpha_code']} {hit['name']}")
        if hit.get("synonyms"):
            print(f" {DIM}Also: {hit['synonyms']}{RESET}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Run the Week 1 milestone: query both stores in parallel and report.

    Loads the embedding model, submits the graph-store lookup and the
    ChromaDB search for QUERY_DISEASE to a two-worker thread pool, prints
    both result sections and a summary.

    Raises:
        SystemExit: With status 1 when either backend produced no result.
    """
    print("=" * 62)
    print("RareDx — Week 1 Hello World Milestone")
    print("=" * 62)
    print(f"\nQuery: {BOLD}{QUERY_DISEASE}{RESET}\n")

    # Load BioLORD (needed before spawning threads so it is not loaded twice)
    print("Loading BioLORD-2023...")
    t0 = time.perf_counter()
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Model ready in {time.perf_counter() - t0:.1f}s")

    # Parallel queries
    print("\nQuerying graph store and ChromaDB simultaneously...")
    t_start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        graph_fut = pool.submit(fetch_from_graph, QUERY_DISEASE)
        chroma_fut = pool.submit(fetch_from_chromadb, QUERY_DISEASE, model, 5)
        disease, graph_backend = graph_fut.result()
        try:
            hits, chroma_backend = chroma_fut.result()
        except Exception as exc:
            # Report a ChromaDB failure (e.g. missing collection) the same
            # way graph-store failures are reported, so the run still ends
            # with the PARTIAL summary instead of a traceback.
            print(f" ChromaDB error: {exc}")
            hits, chroma_backend = [], "unavailable"
    elapsed = time.perf_counter() - t_start
    print(f" Both queries completed in {elapsed:.2f}s")

    # Display
    print_graph_result(disease, graph_backend)
    print_chroma_results(hits, chroma_backend)

    # Summary
    graph_ok = disease is not None
    chroma_ok = len(hits) > 0
    graph_status = "OK" if graph_ok else "MISS"
    chroma_status = "OK" if chroma_ok else "MISS"
    print(f"\n{LINE}")
    print(f"{BOLD}Week 1 Milestone Summary{RESET}")
    print(LINE)
    # Pad the status so it is separated from, and aligned with, the backend.
    print(f" Graph store : {graph_status:<4} {graph_backend}")
    print(f" ChromaDB : {chroma_status:<4} {chroma_backend}")
    print()
    if graph_ok and chroma_ok:
        # Name the backends that actually answered; the graph side may have
        # been served by the local JSON fallback rather than Neo4j.
        print(
            f" {BOLD}{GREEN}PASSED{RESET} — "
            f"{graph_backend} + {chroma_backend} both responding."
        )
    else:
        print(f" {YELLOW}PARTIAL{RESET} — one or more backends had no results.")
        sys.exit(1)
    print()


if __name__ == "__main__":
    main()