# thermal-mortality-model / mortality_graph.py
# Uploaded by h3ir — commit 1cc16b2 (verified): "Upload mortality_graph.py with huggingface_hub"
"""
Mortality Graph Construction for THRML Integration
=================================================
This module converts Morbid AI's MortalityRecord data structure into
THRML-compatible probabilistic graphical models.
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import networkx as nx
import pandas as pd

from thrml.block_management import Block
from thrml.factor import AbstractFactor
from thrml.pgm import CategoricalNode, SpinNode
@dataclass
class MortalityRecord:
    """Morbid AI mortality record structure.

    One observation for a (country, year, sex, age) combination. Field names
    are camelCase; MortalityGraphBuilder maps them to snake_case DataFrame
    columns. The trailing x-notation comments name the life-table quantity
    each field holds at age x.
    """
    country: str
    year: int
    sex: int  # 1=male, 2=female, 3=both
    age: int
    deathRate: float  # m(x)
    deathProbability: float  # q(x)
    survivors: float  # l(x)
    deaths: float  # d(x)
    lifeExpectancy: float  # e(x)
class MortalityGraphBuilder:
    """
    Builds THRML-compatible probabilistic graphical models from mortality data.

    This class creates heterogeneous graphs that capture interactions between
    demographic factors (age, country, sex, year) and discretized mortality
    outcomes (life-expectancy bins and death-probability bins).
    """

    def __init__(self, mortality_data: List[MortalityRecord]):
        """
        Initialize with mortality data.

        Args:
            mortality_data: List of MortalityRecord objects. May be empty, in
                which case every demographic value list (countries, years,
                sexes, ages) is empty.
        """
        self.mortality_data = mortality_data
        # Explicit ``columns`` keeps the expected schema for empty input;
        # ``pd.DataFrame([])`` would have no columns and the ``unique()``
        # lookups below would raise KeyError.
        self.df = pd.DataFrame(
            [
                {
                    'country': record.country,
                    'year': record.year,
                    'sex': record.sex,
                    'age': record.age,
                    'death_rate': record.deathRate,
                    'death_probability': record.deathProbability,
                    'survivors': record.survivors,
                    'deaths': record.deaths,
                    'life_expectancy': record.lifeExpectancy,
                }
                for record in mortality_data
            ],
            columns=[
                'country', 'year', 'sex', 'age', 'death_rate',
                'death_probability', 'survivors', 'deaths', 'life_expectancy',
            ],
        )
        # Extract sorted unique values for graph construction
        self.countries = sorted(self.df['country'].unique())
        self.years = sorted(self.df['year'].unique())
        self.sexes = sorted(self.df['sex'].unique())
        self.ages = sorted(self.df['age'].unique())
        # Create node mappings
        self._create_node_mappings()

    def _create_node_mappings(self):
        """Create one THRML node per distinct demographic value and per outcome bin."""
        self.country_nodes = {country: CategoricalNode() for country in self.countries}
        self.year_nodes = {year: CategoricalNode() for year in self.years}
        # Sex uses spin (binary-like) nodes; every other factor is categorical.
        self.sex_nodes = {sex: SpinNode() for sex in self.sexes}
        self.age_nodes = {age: CategoricalNode() for age in self.ages}

        # Discretize the continuous outcomes so they can be represented by
        # categorical nodes: 21 edges -> 20 bins over [0, 100] for life
        # expectancy, 11 edges -> 10 bins over [0, 1] for death probability.
        self.life_exp_bins = jnp.linspace(0, 100, 21)
        self.death_prob_bins = jnp.linspace(0, 1, 11)
        self.life_expectancy_nodes = {
            i: CategoricalNode() for i in range(len(self.life_exp_bins) - 1)
        }
        self.death_probability_nodes = {
            i: CategoricalNode() for i in range(len(self.death_prob_bins) - 1)
        }

    def build_mortality_graph(self) -> nx.Graph:
        """
        Build NetworkX graph representing mortality factor interactions.

        Node names are ``"<kind>_<value>"`` (e.g. ``"age_30"``) for demographic
        nodes and ``"<kind>_<bin index>"`` for outcome nodes; each node stores
        its THRML node under the ``thrml_node`` attribute.

        Returns:
            NetworkX graph with nodes representing demographic factors and
            outcome bins, and edges representing interactions.
        """
        G = nx.Graph()
        # Demographic nodes
        for country, node in self.country_nodes.items():
            G.add_node(f"country_{country}", type="country", value=country, thrml_node=node)
        for year, node in self.year_nodes.items():
            G.add_node(f"year_{year}", type="year", value=year, thrml_node=node)
        for sex, node in self.sex_nodes.items():
            G.add_node(f"sex_{sex}", type="sex", value=sex, thrml_node=node)
        for age, node in self.age_nodes.items():
            G.add_node(f"age_{age}", type="age", value=age, thrml_node=node)
        # Outcome nodes
        for bin_idx, node in self.life_expectancy_nodes.items():
            G.add_node(f"life_exp_{bin_idx}", type="life_expectancy",
                       bin_idx=bin_idx, thrml_node=node)
        for bin_idx, node in self.death_probability_nodes.items():
            G.add_node(f"death_prob_{bin_idx}", type="death_probability",
                       bin_idx=bin_idx, thrml_node=node)
        # Edges representing factor interactions
        self._add_demographic_interactions(G)
        self._add_outcome_interactions(G)
        return G

    def _add_demographic_interactions(self, G: nx.Graph):
        """Add edges between demographic factor nodes (mutates ``G`` in place)."""
        # Age-Sex interactions: every age with every sex
        for age in self.ages:
            for sex in self.sexes:
                G.add_edge(f"age_{age}", f"sex_{sex}", interaction_type="age_sex")
        # Country-Year interactions: every country with every year
        for country in self.countries:
            for year in self.years:
                G.add_edge(f"country_{country}", f"year_{year}",
                           interaction_type="country_year")
        # Age-Country interactions: only every 5th distinct age, to reduce
        # the number of edges
        for age in self.ages[::5]:
            for country in self.countries:
                G.add_edge(f"age_{age}", f"country_{country}",
                           interaction_type="age_country")

    def _add_outcome_interactions(self, G: nx.Graph):
        """Add edges between demographic factors and outcome bins (mutates ``G``)."""
        # Every 10th distinct age is connected to every life-expectancy bin
        for age in self.ages[::10]:
            for le_bin in range(len(self.life_expectancy_nodes)):
                G.add_edge(f"age_{age}", f"life_exp_{le_bin}",
                           interaction_type="age_life_expectancy")
        # Every country is connected to every death-probability bin
        for country in self.countries:
            for dp_bin in range(len(self.death_probability_nodes)):
                G.add_edge(f"country_{country}", f"death_prob_{dp_bin}",
                           interaction_type="country_death_probability")

    def create_sampling_blocks(self, strategy: str = "two_color") -> List[Block]:
        """
        Create sampling blocks for THRML block Gibbs sampling.

        Blocks are formed by node kind only (categorical vs. spin, and for the
        non-default strategies by factor/outcome group); they are not derived
        from the graph's edge structure. Every returned block holds a single
        node kind.

        Args:
            strategy: Blocking strategy ("two_color", "demographic", "outcome").

        Returns:
            List of Block objects for THRML sampling.

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        if strategy == "two_color":
            categorical_nodes = (list(self.country_nodes.values()) +
                                 list(self.year_nodes.values()) +
                                 list(self.age_nodes.values()) +
                                 list(self.life_expectancy_nodes.values()) +
                                 list(self.death_probability_nodes.values()))
            spin_nodes = list(self.sex_nodes.values())
            # One block per node kind when both kinds exist; otherwise split
            # the categorical nodes alternately into two blocks.
            if categorical_nodes and spin_nodes:
                return [Block(categorical_nodes), Block(spin_nodes)]
            elif categorical_nodes:
                return [Block(categorical_nodes[::2]), Block(categorical_nodes[1::2])]
            else:
                return [Block(spin_nodes)]
        elif strategy == "demographic":
            # Demographic categorical, demographic spin, then all outcomes
            categorical_demographic = (list(self.country_nodes.values()) +
                                       list(self.year_nodes.values()) +
                                       list(self.age_nodes.values()))
            spin_demographic = list(self.sex_nodes.values())
            outcome_nodes = (list(self.life_expectancy_nodes.values()) +
                             list(self.death_probability_nodes.values()))
            return [Block(group)
                    for group in (categorical_demographic, spin_demographic, outcome_nodes)
                    if group]
        elif strategy == "outcome":
            # Each outcome type in its own block, then demographics by node kind
            life_exp_nodes = list(self.life_expectancy_nodes.values())
            death_prob_nodes = list(self.death_probability_nodes.values())
            categorical_demo = (list(self.country_nodes.values()) +
                                list(self.year_nodes.values()) +
                                list(self.age_nodes.values()))
            spin_demo = list(self.sex_nodes.values())
            return [Block(group)
                    for group in (life_exp_nodes, death_prob_nodes, categorical_demo, spin_demo)
                    if group]
        else:
            raise ValueError(f"Unknown blocking strategy: {strategy}")

    def create_interaction_factors(self) -> List[Dict]:
        """
        Create interaction factors for the energy-based model.

        Returns:
            List of simplified factor dicts with keys ``'nodes'`` (the two THRML
            nodes), ``'strength'`` (a 2x2 interaction matrix) and ``'type'``
            (``'age_sex'`` or ``'country_year'``). These are placeholders, not
            THRML ``AbstractFactor`` instances.
        """
        factors = []
        # Age-Sex factors, sampling every 10th distinct age
        for age in self.ages[::10]:
            for sex in self.sexes:
                factors.append({
                    'nodes': [self.age_nodes[age], self.sex_nodes[sex]],
                    'strength': self._compute_age_sex_interaction(age, sex),
                    'type': 'age_sex',
                })
        # Country-Year factors, sampling every 2nd distinct year
        for country in self.countries:
            for year in self.years[::2]:
                factors.append({
                    'nodes': [self.country_nodes[country], self.year_nodes[year]],
                    'strength': self._compute_country_year_interaction(country, year),
                    'type': 'country_year',
                })
        return factors

    @staticmethod
    def _default_interaction() -> jnp.ndarray:
        """Weak diagonal 2x2 interaction used when the data cannot inform one."""
        return jnp.array([[0.1, 0.0], [0.0, 0.1]])

    def _compute_age_sex_interaction(self, age: int, sex: int) -> jnp.ndarray:
        """Compute a 2x2 interaction matrix for an (age, sex) pair from the data.

        The strength is ``min(10 * mean death rate, 1.0)``. Falls back to the
        weak default when there are no rows or every death rate is missing.
        """
        subset = self.df[(self.df['age'] == age) & (self.df['sex'] == sex)]
        if len(subset) == 0:
            return self._default_interaction()
        avg_death_rate = subset['death_rate'].mean()
        # All-missing death rates give a NaN mean; min(NaN, 1.0) is NaN and
        # would poison the matrix.
        if pd.isna(avg_death_rate):
            return self._default_interaction()
        # Higher death rates -> stronger interaction, capped at 1.0
        strength = min(avg_death_rate * 10, 1.0)
        return jnp.array([[strength, -strength / 2], [-strength / 2, strength]])

    def _compute_country_year_interaction(self, country: str, year: int) -> jnp.ndarray:
        """Compute a 2x2 interaction matrix for a (country, year) pair.

        The strength is ``min(var(life expectancy) / 100, 1.0)`` using the
        pandas sample variance (ddof=1). Falls back to the weak default when
        there are no rows or the variance is undefined (a single row, or all
        values missing).
        """
        subset = self.df[(self.df['country'] == country) & (self.df['year'] == year)]
        if len(subset) == 0:
            return self._default_interaction()
        life_exp_var = subset['life_expectancy'].var()
        # Sample variance is NaN for fewer than two non-missing values;
        # min(NaN, 1.0) is NaN and would poison the matrix.
        if pd.isna(life_exp_var):
            return self._default_interaction()
        strength = min(life_exp_var / 100, 1.0)  # Normalize and cap
        return jnp.array([[strength, -strength / 3], [-strength / 3, strength]])

    def get_mortality_prediction_nodes(self,
                                       age: int,
                                       country: str,
                                       sex: int) -> Dict[str, Any]:
        """
        Get the relevant nodes for mortality prediction given demographics.

        Args:
            age: Age value.
            country: Country name.
            sex: Sex value (1=male, 2=female, 3=both).

        Returns:
            Dictionary mapping node roles to THRML nodes. The demographic
            entries are ``None`` when the value was not present in the data;
            the outcome entries are the full bin-index -> node dicts.
        """
        return {
            'age_node': self.age_nodes.get(age),
            'country_node': self.country_nodes.get(country),
            'sex_node': self.sex_nodes.get(sex),
            'life_expectancy_nodes': self.life_expectancy_nodes,
            'death_probability_nodes': self.death_probability_nodes,
        }

    @staticmethod
    def _bin_index(value: float, edges: jnp.ndarray) -> int:
        """Return the bin of ``edges`` containing ``value``, clipped to a valid index.

        ``jnp.digitize`` returns ``len(edges)`` for values at or above the last
        edge (e.g. a death probability of exactly 1.0) and 0 below the first
        edge. Without clipping these would become bin ``n_bins`` or ``-1``,
        which are not keys of the outcome-node dicts.
        """
        n_bins = len(edges) - 1
        idx = int(jnp.digitize(value, edges)) - 1
        return min(max(idx, 0), n_bins - 1)

    def discretize_life_expectancy(self, life_exp: float) -> int:
        """Convert a life expectancy to a bin index in ``[0, 19]`` (values outside [0, 100) are clipped)."""
        return self._bin_index(life_exp, self.life_exp_bins)

    def discretize_death_probability(self, death_prob: float) -> int:
        """Convert a death probability to a bin index in ``[0, 9]`` (values outside [0, 1) are clipped)."""
        return self._bin_index(death_prob, self.death_prob_bins)

    def continuous_from_bin(self, bin_idx: int, bin_type: str) -> float:
        """Convert a bin index back to a continuous value (the bin center).

        Args:
            bin_idx: Bin index.
            bin_type: ``"life_expectancy"`` or ``"death_probability"``.

        Returns:
            The bin center as a Python float, or ``0.0`` for an unknown
            ``bin_type`` or an out-of-range ``bin_idx``.
        """
        if bin_type == "life_expectancy":
            edges = self.life_exp_bins
        elif bin_type == "death_probability":
            edges = self.death_prob_bins
        else:
            return 0.0
        if 0 <= bin_idx < len(edges) - 1:
            return float((edges[bin_idx] + edges[bin_idx + 1]) / 2)
        return 0.0