med / visualization.py
mgbam's picture
Update visualization.py
7352f83 verified
import re
import tempfile
import os
from pyvis.network import Network
from retrieval import collection # Import the collection from retrieval.py
def extract_key_terms(text: str):
    """
    Pull candidate key terms out of free text.

    A term is any run of two or more ASCII letters that begins with an
    uppercase letter and sits between word boundaries. Matches come back
    as a list in order of appearance, duplicates included.

    This is a crude heuristic; a dedicated medical NER model is the
    better choice for production use.
    """
    capitalized_word = re.compile(r"\b[A-Z][a-zA-Z]+\b")
    return [hit.group(0) for hit in capitalized_word.finditer(text)]
def create_medical_graph(query: str, docs: list) -> str:
    """
    Build a Pyvis network for a query and its retrieved abstracts.

    Graph layout:
    - A central 'QUERY' node (red star) labelled with the query text.
    - One blue node per entry in ``docs``, linked to the query node.
    - One green node per distinct key term of each abstract, linked to
      that abstract's node.

    Args:
        query: The user query shown on the central node.
        docs: Abstract texts; each one is passed to ``extract_key_terms``.

    Returns:
        The HTML of the generated graph as a string.
    """
    net = Network(height="600px", width="100%", directed=False)
    net.add_node("QUERY", label=f"Query: {query}", color="red", shape="star")

    for i, doc in enumerate(docs):
        doc_id = f"Doc_{i}"
        net.add_node(doc_id, label=f"Abstract {i+1}", color="blue")
        net.add_edge("QUERY", doc_id)

        # Deduplicate, then sort: raw set iteration order for strings can
        # differ between interpreter runs, which would reorder the nodes.
        for term in sorted(set(extract_key_terms(doc))):
            # Prefix with the abstract's id so the same term appearing in
            # two abstracts yields two separate nodes.
            term_id = f"{doc_id}_{term}"
            net.add_node(term_id, label=term, color="green")
            net.add_edge(doc_id, term_id)

    # write_html takes a path, so reserve a temp file name and close the
    # handle before the file is written to by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
        temp_filename = tmp.name

    try:
        net.write_html(temp_filename, open_browser=False, notebook=False)
        with open(temp_filename, "r", encoding="utf-8") as f:
            html_content = f.read()
    finally:
        # delete=False means nothing else cleans this up; remove it even
        # when writing or reading back fails.
        os.remove(temp_filename)

    return html_content