|
import re |
|
import tempfile |
|
import os |
|
from pyvis.network import Network |
|
from retrieval import collection |
|
|
|
def extract_key_terms(text: str):
    """
    Pull candidate key terms out of free text.

    A term is any run of two or more ASCII letters that begins with an
    uppercase letter and is delimited by word boundaries. Matches are
    returned in order of appearance; duplicates are kept.

    This is a deliberately naive heuristic. For production, consider
    using a dedicated medical NER.
    """
    capitalized_word = r"\b[A-Z][a-zA-Z]+\b"
    return [match.group(0) for match in re.finditer(capitalized_word, text)]
|
|
|
def create_medical_graph(query: str, docs: list) -> str:
    """
    Build a Pyvis network for a query and its retrieved abstracts.

    Graph layout:
    - A central 'QUERY' node (red star) labelled with the query text.
    - One blue node per entry in ``docs``, linked to the query node.
    - One green node per distinct key term of each abstract, linked to
      that abstract's node.

    Args:
        query: Query text shown in the central node's label.
        docs: Abstract texts; each entry is passed to
            ``extract_key_terms``.

    Returns:
        The HTML of the generated graph.

    Raises:
        Whatever ``Network.write_html`` or the file read raises; the
        temporary file is removed in either case.
    """
    net = Network(height="600px", width="100%", directed=False)
    net.add_node("QUERY", label=f"Query: {query}", color="red", shape="star")

    for i, doc in enumerate(docs):
        doc_id = f"Doc_{i}"
        net.add_node(doc_id, label=f"Abstract {i+1}", color="blue")
        net.add_edge("QUERY", doc_id)

        # De-duplicate, then sort: raw set iteration order for strings
        # varies between interpreter runs, which made the generated HTML
        # differ for identical inputs.
        for term in sorted(set(extract_key_terms(doc))):
            # Prefix with the abstract's id so the same term found in two
            # abstracts yields two separate nodes.
            term_id = f"{doc_id}_{term}"
            net.add_node(term_id, label=term, color="green")
            net.add_edge(doc_id, term_id)

    # Only reserve a path here; the handle is closed before Pyvis writes
    # to it, and delete=False keeps the file around for that write.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
        temp_filename = tmp.name

    try:
        net.write_html(temp_filename, open_browser=False, notebook=False)
        with open(temp_filename, "r", encoding="utf-8") as f:
            return f.read()
    finally:
        # Clean up even when rendering or reading fails, so failed calls
        # do not leave stray .html files in the temp directory.
        os.remove(temp_filename)
|
|