Singularity / scripts /3.2_visualize_340b.py
SlappAI's picture
Dev Scripts
f5b1acc
raw
history blame
3.27 kB
import requests
import networkx as nx
import matplotlib.pyplot as plt
# Base URL of the graph API; the /traverse_node and /load_graph endpoints
# used below are appended to it.
base_url = "http://localhost:5000"
def fetch_relationships(node_id, direction="down"):
    """Fetch the traversal hierarchy for a node from the graph API.

    Sends a GET request to ``{base_url}/traverse_node`` and returns the
    ``traversal_path`` entry of the JSON reply.

    Args:
        node_id: Identifier of the node to start the traversal from.
        direction: Traversal direction forwarded to the API ("up" or "down").

    Returns:
        The value stored under "traversal_path" in the response body, or an
        empty dict when that key is absent or null.
    """
    # Let requests build the query string so that ids containing "&", "#",
    # "+", "=" or spaces (e.g. "340B Program") are percent-encoded instead of
    # being spliced raw into the URL, where they would corrupt the query.
    response = requests.get(
        f"{base_url}/traverse_node",
        params={"node_id": node_id, "direction": direction},
    )
    # "or {}" also covers an explicit JSON null, which .get()'s default
    # argument would pass through as None.
    return response.json().get("traversal_path") or {}
def build_graph_from_relationships(node_id):
    """Assemble a directed NetworkX graph around ``node_id``.

    The node's hierarchy is fetched in both directions ("down" for
    descendants, then "up" for ancestors) and each result is merged into one
    ``nx.DiGraph`` via ``add_nodes_and_edges``.

    Args:
        node_id: Identifier of the node whose relationships are traversed.

    Returns:
        The populated ``nx.DiGraph``.
    """
    # Both hierarchies are fetched before anything is added to the graph.
    hierarchies = [
        fetch_relationships(node_id, direction=way) for way in ("down", "up")
    ]

    graph = nx.DiGraph()
    # Each hierarchy gets its own visited set (the helper's default), so the
    # shared root node is processed in both passes.
    for hierarchy in hierarchies:
        add_nodes_and_edges(graph, hierarchy)
    return graph
def add_nodes_and_edges(G, node, visited=None):
    """Recursively copy a traversal hierarchy into a graph.

    ``node`` is a dict with a "node_id" and optional "descendants" and
    "ancestors" lists of dicts of the same shape; each listed entry may carry
    a "relationship" string used as the edge label (default "related_to").
    Edges point from parent to child: ``node -> descendant`` and
    ``ancestor -> node``.

    Args:
        G: Graph object providing ``add_node`` and ``add_edge``.
        node: Hierarchy dict to process.
        visited: Set of node ids already expanded in this traversal; created
            on the first call and shared with the recursive calls so that
            cycles and repeated nodes are expanded only once.

    Returns:
        None. ``G`` is modified in place.
    """
    if visited is None:
        visited = set()

    node_id = node.get("node_id")
    if not node_id or node_id in visited:
        return
    visited.add(node_id)

    G.add_node(node_id, label=node_id)

    # "or []" tolerates an explicit null list in the payload.
    for child in node.get("descendants") or []:
        child_id = child.get("node_id")
        if not child_id:
            # An entry without an id cannot be an edge endpoint (NetworkX
            # rejects None as a node), so skip it like the guard above does.
            continue
        relationship = child.get("relationship", "related_to")
        # The edge is added even when the child was already visited, so links
        # back to earlier nodes are kept; only the expansion is skipped.
        G.add_edge(node_id, child_id, label=relationship)
        add_nodes_and_edges(G, child, visited)

    for ancestor in node.get("ancestors") or []:
        ancestor_id = ancestor.get("node_id")
        if not ancestor_id:
            continue
        relationship = ancestor.get("relationship", "related_to")
        G.add_edge(ancestor_id, node_id, label=relationship)
        add_nodes_and_edges(G, ancestor, visited)
def visualize_graph(G, title="Graph Structure and Relationships"):
    """Draw ``G`` with matplotlib and show the figure.

    Nodes are placed with NetworkX's spring layout and drawn with their
    labels; edges are drawn as arrows and annotated with each edge's "label"
    attribute.

    Args:
        G: NetworkX graph to draw. Every edge must carry a "label" attribute;
            an edge without one raises ``KeyError``.
        title: Text shown as the figure title.
    """
    plt.figure(figsize=(12, 8))
    layout = nx.spring_layout(G)

    # Nodes and their text labels.
    nx.draw_networkx_nodes(
        G, layout, node_size=3000, node_color="skyblue", alpha=0.8
    )
    nx.draw_networkx_labels(G, layout, font_size=10, font_color="black")

    # Edges, then a caption per edge keyed by its (source, target) pair.
    nx.draw_networkx_edges(G, layout, edge_color="gray", arrows=True)
    captions = {}
    for source, target, attrs in G.edges(data=True):
        captions[(source, target)] = attrs["label"]
    nx.draw_networkx_edge_labels(
        G, layout, edge_labels=captions, font_color="red"
    )

    # Final presentation: title, hidden axes, then display the figure.
    plt.title(title)
    plt.axis("off")
    plt.show()
# Step 1: Load Graph (Specify the graph to load, e.g., PHSA/340B section)
# POSTs the graph file path to the API's /load_graph endpoint and prints the
# JSON body of the reply.
print("\n--- Loading Graph ---")
graph_data = {"graph_file": "graphs/PHSA/phsa_sec_340b.json"}
response = requests.post(f"{base_url}/load_graph", json=graph_data)
print("Load Graph Response:", response.json())
# Step 2: Build and visualize the graph for 340B Program
# Traverses the "340B Program" node in both directions through the API, builds
# a directed graph from the results, and shows it in a matplotlib figure.
print("\n--- Building Graph for Visualization ---")
G = build_graph_from_relationships("340B Program")
visualize_graph(G, title="340B Program - Inferred Contextual Relationships")