# math_roots / app.py
# Last commit by thearn: "fixed label" (135bcee)
import streamlit as st
import asyncio
from typing import Dict, Any, Optional
from streamlit_agraph import agraph, Config
from src.network import make_payload, get_graph
from src.graph import build_tree_structure, create_hierarchical_view, tree_to_dot
def get_id_from_input(val: str) -> Optional[int]:
    """Parse a user-entered Mathematics Genealogy Project ID.

    Args:
        val: Raw text from the ID input box. Surrounding whitespace is
            tolerated, because ``int()`` strips it.

    Returns:
        The parsed integer, or ``None`` when *val* is not an integer literal
        (for example ``""``, ``"abc"``, ``"3.5"``, or ``None``).
    """
    try:
        return int(val)
    except (TypeError, ValueError):
        # ValueError: non-integer text. TypeError: None or a non-string value.
        return None
def display_tree_summary(graph: Dict[str, Any], root_id: int) -> None:
    """Render headline metrics for the ancestor tree rooted at *root_id*.

    Shows three things: the number of mathematicians in the tree, the largest
    node depth (labelled "Generations Back"), and the root mathematician's
    name. Nothing is rendered when ``build_tree_structure`` returns an empty or
    falsy tree.

    Args:
        graph: Graph data as fetched by the app; passed unchanged to
            ``build_tree_structure``.
        root_id: MGP ID of the mathematician the tree is rooted at.
    """
    tree = build_tree_structure(graph, root_id)
    if not tree:
        return

    # The tree is non-empty here, so max() always has at least one value.
    max_depth = max(node["depth"] for node in tree.values())

    # Two columns give the metrics and the (possibly long) root name more room.
    metrics_col, root_col = st.columns(2)
    with metrics_col:
        st.metric("Total Mathematicians", len(tree))
        st.metric("Generations Back", max_depth)
    with root_col:
        st.write("**Root Mathematician:**")
        st.write(tree.get(root_id, {}).get("name", "Unknown"))
# Preset choices offered in the selectbox: (display name, MGP ID).
_PRESET_MATHEMATICIANS = [
    ("Tristan Hearn", 162833),
    ("Alexander Grothendieck", 31245),
    ("Emmy Noether", 6967),
    ("David Hilbert", 7298),
    ("Sophie Germain", 55175),
    ("Carl Friedrich Gauss", 18231),
]
_DEFAULT_PRESET_INDEX = 0  # Tristan Hearn

# Number of search hits listed individually before the rest are summarised.
_MAX_SEARCH_RESULTS = 10


def _init_session_state() -> None:
    """Seed the session-state keys main() reads, without overwriting existing values."""
    defaults = {
        "mgp_id_str": str(_PRESET_MATHEMATICIANS[_DEFAULT_PRESET_INDEX][1]),
        "graph_data": None,
        "root_id": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _fetch_graph(mgp_id: int, progress_placeholder: Any) -> bool:
    """Fetch the ancestor graph for *mgp_id* and store it in session state.

    Progress updates and the final status are written to *progress_placeholder*.

    Returns:
        True on success. On failure the error is printed and shown in the
        placeholder, session state is left untouched, and False is returned.
    """
    payload = make_payload(mgp_id)

    def progress_cb(progress: Dict[str, Any]) -> None:
        progress_placeholder.info(
            f"Queued: {progress['queued']} | Fetching: {progress['fetching']} | Done: {progress['done']}"
        )

    try:
        # asyncio.run creates a fresh event loop and closes it afterwards, so
        # repeated fetches do not leak one unclosed loop per click.
        graph = asyncio.run(get_graph(payload, progress_cb))
    except Exception as e:  # UI boundary: report any fetch failure to the user
        print(f"Error: {e}")
        progress_placeholder.error(f"Error: {e}")
        return False

    st.session_state["graph_data"] = graph
    st.session_state["root_id"] = mgp_id
    progress_placeholder.success("Data fetched successfully!")
    return True


def _force_sidebar_open() -> None:
    """Try to un-collapse the sidebar by injecting JavaScript (best effort)."""
    # NOTE(review): browsers do not execute <script> tags inserted as HTML, so
    # this is probably a no-op under st.markdown. Consider
    # st.set_page_config(initial_sidebar_state="expanded") instead; confirm.
    st.markdown(
        """
        <script>
        try {
            window.parent.document.querySelector('section[data-testid="stSidebar"]').style.transform = "none";
        } catch (e) {}
        </script>
        """,
        unsafe_allow_html=True,
    )


def _render_timeline_sidebar(graph: Dict[str, Any]) -> None:
    """List every node that has an integer year in a sidebar table, newest first."""
    import pandas as pd

    rows = []
    for node_id, node in graph.get("nodes", {}).items():
        # Nodes whose year cannot be parsed as an integer are dropped below.
        try:
            year = int(node.get("year", None))
        except (TypeError, ValueError, OverflowError):
            year = None
        rows.append(
            {
                "Name": node.get("name", ""),
                "Year": year,
                "Institution": node.get("institution", ""),
                "node_id": node_id,
            }
        )

    # Explicit columns keep the frame well-formed when there are no nodes.
    # Without them, dropna(subset=["Year"]) raises KeyError on an empty frame.
    df = pd.DataFrame(rows, columns=["Name", "Year", "Institution", "node_id"])
    # Missing years make the column float; cast back to int once they are gone.
    df = (
        df.dropna(subset=["Year"])
        .astype({"Year": int})
        .sort_values("Year", ascending=False)
    )

    st.sidebar.title("Timeline")
    st.sidebar.dataframe(
        df[["Year", "Name", "Institution"]],
        use_container_width=True,
        height=1000,
    )


def _render_pdf_download(dot: str) -> None:
    """Offer the DOT graph rendered as PDF for download; warn if rendering fails."""
    try:
        # Imported here so that a missing graphviz package only disables the
        # PDF export instead of breaking the whole page.
        from graphviz import Source

        pdf_bytes = Source(dot).pipe(format="pdf")
    except Exception as e:  # best effort: any failure just disables the export
        st.warning(f"Could not generate PDF: {e}")
        return

    if pdf_bytes:
        st.download_button(
            label="Download Graph as PDF",
            data=pdf_bytes,
            file_name="math_genealogy_tree.pdf",
            mime="application/pdf",
        )


def _render_hierarchical_view(graph: Dict[str, Any], root_id: int) -> None:
    """Render the interactive agraph tree, with a slider limiting the depth shown."""
    st.write("**Hierarchical Tree View** - Best for exploring direct lineages")

    tree = build_tree_structure(graph, root_id)
    max_available_depth = max(node["depth"] for node in tree.values()) if tree else 0
    if max_available_depth > 0:
        depth_filter = st.slider(
            "Show generations back:",
            min_value=0,
            max_value=max_available_depth,
            value=min(3, max_available_depth),
            help="Limit the number of generations to display for better readability",
        )
    else:
        depth_filter = 0

    nodes_list, edges_list = create_hierarchical_view(graph, root_id, depth_filter)
    if not nodes_list:
        st.warning("No data to display with current filters.")
        return

    # Settings chosen for dark-mode compatibility.
    config = Config(
        width=800,
        height=600,
        directed=True,
        physics=True,
        hierarchical=True,
        nodeHighlightBehavior=True,
        highlightColor="#F7A7A6",
        collapsible=False,
        node={
            "font": {
                "color": "black",  # keep label text black for readability
                "size": 12,
                "face": "arial",
            },
            "borderWidth": 2,
            "borderWidthSelected": 3,
        },
    )
    selected = agraph(nodes=nodes_list, edges=edges_list, config=config)

    # NOTE(review): accept both a bare node id and a dict with an "id" key.
    # Confirm the return shape of agraph() for the installed streamlit-agraph.
    if isinstance(selected, dict):
        selected_node_id = selected.get("id")
    else:
        selected_node_id = selected or None
    if selected_node_id:
        st.session_state["selected_node_id"] = selected_node_id
        # debug output
        st.write(f"DEBUG: Selected node ID: {selected_node_id}")


def _render_search(graph: Dict[str, Any], max_results: int = _MAX_SEARCH_RESULTS) -> None:
    """Case-insensitive substring search over node names; lists up to *max_results* hits."""
    st.divider()
    st.subheader("Search Mathematicians")
    search_term = st.text_input("Search by name:", placeholder="e.g., Gauss, Euler, Newton")
    needle = search_term.strip().casefold()
    if not needle:
        return

    matches = []
    for node_id, node in graph.get("nodes", {}).items():
        name = node.get("name") or ""  # tolerate a missing or None name
        if needle not in name.casefold():
            continue
        year = node.get("year")
        institution = node.get("institution")
        matches.append(
            {
                "id": node_id,
                "name": name,
                "year": "N/A" if year is None else year,
                "institution": "N/A" if institution is None else institution,
            }
        )

    if not matches:
        st.write("No matches found.")
        return

    st.write(f"Found {len(matches)} match(es):")
    for match in matches[:max_results]:
        st.write(f"• **{match['name']}** ({match['year']}) - {match['institution']} (ID: {match['id']})")
    if len(matches) > max_results:
        st.write(f"... and {len(matches) - max_results} more")


def main():
    """Streamlit entry point: choose a mathematician, fetch and explore their ancestor tree.

    On every rerun this renders the input widgets. Pressing the fetch button
    (re)fetches the graph into ``st.session_state``. Whenever graph data is
    present, the page also renders:

    - the timeline sidebar;
    - the summary metrics;
    - the PDF export;
    - the chosen visualisation;
    - the name search.
    """
    st.title("Math Genealogy Ancestor Tree")
    st.write("Interactive visualization of academic advisor relationships from the Mathematics Genealogy Project")

    names = [f"{name} ({mid})" for name, mid in _PRESET_MATHEMATICIANS]
    _init_session_state()

    # input section
    st.subheader("Select Mathematician")
    # Bound to session_state["mgp_id_str"], which the selectbox callback overwrites.
    st.text_input(
        "Enter MGP ID (integer):",
        key="mgp_id_str",
        help="You can type a custom ID or use the selection below.",
    )

    def on_select():
        # Copy the chosen preset's ID into the ID text box.
        idx = st.session_state["mathematician_idx"]
        st.session_state["mgp_id_str"] = str(_PRESET_MATHEMATICIANS[idx][1])

    st.selectbox(
        "Or select a mathematician:",
        range(len(names)),
        format_func=lambda i: names[i],
        index=_DEFAULT_PRESET_INDEX,
        key="mathematician_idx",
        on_change=on_select,
    )

    progress_placeholder = st.empty()

    # fetch data
    if st.button("Fetch Ancestor Tree", type="primary"):
        mgp_id = get_id_from_input(st.session_state["mgp_id_str"])
        if mgp_id is None:
            st.error("Please enter a valid integer MGP ID.")
            return
        if not _fetch_graph(mgp_id, progress_placeholder):
            return

    # Everything below needs fetched data.
    graph = st.session_state["graph_data"]
    if graph is None:
        return
    root_id = st.session_state["root_id"]

    _force_sidebar_open()
    _render_timeline_sidebar(graph)
    st.divider()

    display_tree_summary(graph, root_id)
    st.divider()

    dot = tree_to_dot(graph)
    _render_pdf_download(dot)

    st.subheader("Choose Visualization")
    viz_option = st.radio(
        "Select visualization type:",
        ["Interactive Hierarchical Tree", "Traditional Graph (Graphviz)"],
        help="Different views for exploring the genealogy tree",
    )
    if viz_option == "Interactive Hierarchical Tree":
        _render_hierarchical_view(graph, root_id)
    else:  # Traditional Graph
        st.write("**Traditional Graph View** - Standard graphviz layout")
        st.graphviz_chart(dot)

    _render_search(graph)
if __name__ == "__main__":
    # Run the app only when this file is executed directly, not when imported.
    main()