"""
Mortality Graph Construction for THRML Integration
==================================================

This module converts Morbid AI's MortalityRecord data structure into
THRML-compatible probabilistic graphical models.
"""
|
|
|
|
|
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import networkx as nx
import pandas as pd

from thrml.pgm import CategoricalNode, SpinNode
from thrml.block_management import Block
from thrml.factor import AbstractFactor
|
|
|
|
|
|
|
|
@dataclass
class MortalityRecord:
    """One Morbid AI mortality row for a (country, year, sex, age) cell.

    Field names are camelCase because they are read as-is by
    ``MortalityGraphBuilder.__init__``, which copies them into snake_case
    DataFrame columns. Positional construction follows the declaration
    order below.
    """

    # --- demographic key -------------------------------------------------
    country: str
    year: int
    sex: int  # builder docs list the codes as 1=male, 2=female, 3=both
    age: int

    # --- mortality measures ----------------------------------------------
    # NOTE(review): these look like standard life-table columns; units and
    # valid ranges are not established anywhere in this file -- confirm
    # against the data source.
    deathRate: float
    deathProbability: float  # the builder bins this over [0, 1]
    survivors: float
    deaths: float
    lifeExpectancy: float  # the builder bins this over [0, 100]
|
|
|
|
|
|
|
|
class MortalityGraphBuilder:
    """
    Builds THRML-compatible probabilistic graphical models from mortality data.

    This class creates heterogeneous graphs that capture complex interactions
    between demographic factors (age, country, sex, year) and mortality outcomes.

    Every distinct country, year and age value gets its own ``CategoricalNode``
    and every distinct sex value its own ``SpinNode``. Life expectancy and
    death probability are discretized into fixed-width bins, one
    ``CategoricalNode`` per bin.
    """

    # Column order of ``self.df``. Declared explicitly so the frame has the
    # expected columns even when it is built from an empty record list.
    _COLUMNS = (
        'country',
        'year',
        'sex',
        'age',
        'death_rate',
        'death_probability',
        'survivors',
        'deaths',
        'life_expectancy',
    )

    def __init__(self, mortality_data: List[MortalityRecord]):
        """
        Initialize with mortality data.

        Args:
            mortality_data: List of MortalityRecord objects. May be empty, in
                which case all value lists and demographic node maps are empty.
        """
        self.mortality_data = mortality_data

        rows = [
            {
                'country': record.country,
                'year': record.year,
                'sex': record.sex,
                'age': record.age,
                'death_rate': record.deathRate,
                'death_probability': record.deathProbability,
                'survivors': record.survivors,
                'deaths': record.deaths,
                'life_expectancy': record.lifeExpectancy,
            }
            for record in mortality_data
        ]
        # Passing ``columns`` keeps the schema stable for empty input; without
        # it the column lookups below raise KeyError on an empty frame.
        self.df = pd.DataFrame(rows, columns=list(self._COLUMNS))

        # Sorted distinct values of each demographic dimension.
        self.countries = sorted(self.df['country'].unique())
        self.years = sorted(self.df['year'].unique())
        self.sexes = sorted(self.df['sex'].unique())
        self.ages = sorted(self.df['age'].unique())

        self._create_node_mappings()

    def _create_node_mappings(self):
        """Create mappings from data values to graph nodes."""
        self.country_nodes = {country: CategoricalNode() for country in self.countries}
        self.year_nodes = {year: CategoricalNode() for year in self.years}
        self.sex_nodes = {sex: SpinNode() for sex in self.sexes}
        self.age_nodes = {age: CategoricalNode() for age in self.ages}

        # Bin edges: 20 life-expectancy bins over [0, 100] and 10
        # death-probability bins over [0, 1].
        self.life_exp_bins = jnp.linspace(0, 100, 21)
        self.death_prob_bins = jnp.linspace(0, 1, 11)

        # N edges delimit N - 1 bins; one node per bin, keyed by bin index.
        self.life_expectancy_nodes = {
            i: CategoricalNode() for i in range(len(self.life_exp_bins) - 1)
        }
        self.death_probability_nodes = {
            i: CategoricalNode() for i in range(len(self.death_prob_bins) - 1)
        }

    def build_mortality_graph(self) -> nx.Graph:
        """
        Build NetworkX graph representing mortality factor interactions.

        Node names are ``"<prefix>_<value>"`` (e.g. ``"country_US"``,
        ``"life_exp_3"``). Every node carries its THRML node object in the
        ``thrml_node`` attribute.

        Returns:
            NetworkX graph with nodes representing demographic factors
            and edges representing interactions.
        """
        G = nx.Graph()

        # Demographic nodes: attribute ``value`` holds the raw data value.
        demographic_groups = (
            ("country", self.country_nodes),
            ("year", self.year_nodes),
            ("sex", self.sex_nodes),
            ("age", self.age_nodes),
        )
        for kind, mapping in demographic_groups:
            for value, node in mapping.items():
                G.add_node(f"{kind}_{value}", type=kind, value=value, thrml_node=node)

        # Outcome nodes: attribute ``bin_idx`` holds the bin index.
        for bin_idx, node in self.life_expectancy_nodes.items():
            G.add_node(f"life_exp_{bin_idx}", type="life_expectancy",
                       bin_idx=bin_idx, thrml_node=node)

        for bin_idx, node in self.death_probability_nodes.items():
            G.add_node(f"death_prob_{bin_idx}", type="death_probability",
                       bin_idx=bin_idx, thrml_node=node)

        self._add_demographic_interactions(G)
        self._add_outcome_interactions(G)

        return G

    def _add_demographic_interactions(self, G: nx.Graph):
        """Add edges between demographic factor nodes."""
        # Every age is linked to every sex.
        for age in self.ages:
            for sex in self.sexes:
                G.add_edge(f"age_{age}", f"sex_{sex}", interaction_type="age_sex")

        # Every country is linked to every year.
        for country in self.countries:
            for year in self.years:
                G.add_edge(f"country_{country}", f"year_{year}",
                           interaction_type="country_year")

        # Only every 5th age value is linked to the countries.
        for age in self.ages[::5]:
            for country in self.countries:
                G.add_edge(f"age_{age}", f"country_{country}",
                           interaction_type="age_country")

    def _add_outcome_interactions(self, G: nx.Graph):
        """Add edges between demographic factors and mortality outcomes."""
        # Only every 10th age value is linked to the life-expectancy bins.
        for age in self.ages[::10]:
            for le_bin in self.life_expectancy_nodes:
                G.add_edge(f"age_{age}", f"life_exp_{le_bin}",
                           interaction_type="age_life_expectancy")

        for country in self.countries:
            for dp_bin in self.death_probability_nodes:
                G.add_edge(f"country_{country}", f"death_prob_{dp_bin}",
                           interaction_type="country_death_probability")

    def create_sampling_blocks(self, strategy: str = "two_color") -> List[Block]:
        """
        Create sampling blocks for THRML block Gibbs sampling.

        Categorical and spin nodes are never mixed inside one block. Groups
        with no nodes are skipped, so the result may be shorter than the
        strategy's nominal block count (and is empty when there are no nodes).

        NOTE(review): "two_color" places all country, year and age nodes in
        the same block, although ``_add_demographic_interactions`` connects
        country-year and age-country pairs. If the sampler requires nodes in a
        block to be mutually non-adjacent, this is not a valid coloring --
        confirm against the THRML block requirements.

        Args:
            strategy: Blocking strategy ("two_color", "demographic", "outcome")

        Returns:
            List of Block objects for THRML sampling

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        categorical_demographic = [
            *self.country_nodes.values(),
            *self.year_nodes.values(),
            *self.age_nodes.values(),
        ]
        spin_demographic = list(self.sex_nodes.values())
        life_exp_nodes = list(self.life_expectancy_nodes.values())
        death_prob_nodes = list(self.death_probability_nodes.values())

        if strategy == "two_color":
            categorical_nodes = categorical_demographic + life_exp_nodes + death_prob_nodes
            if categorical_nodes and spin_demographic:
                groups = [categorical_nodes, spin_demographic]
            elif categorical_nodes:
                # No spin nodes: alternate the categorical nodes instead.
                groups = [categorical_nodes[::2], categorical_nodes[1::2]]
            else:
                groups = [spin_demographic]
        elif strategy == "demographic":
            groups = [
                categorical_demographic,
                spin_demographic,
                life_exp_nodes + death_prob_nodes,
            ]
        elif strategy == "outcome":
            groups = [
                life_exp_nodes,
                death_prob_nodes,
                categorical_demographic,
                spin_demographic,
            ]
        else:
            raise ValueError(f"Unknown blocking strategy: {strategy}")

        return [Block(group) for group in groups if group]

    def create_interaction_factors(self) -> List[Dict]:
        """
        Create interaction factors for the energy-based model.

        Age/sex factors are produced for every 10th age value and
        country/year factors for every 2nd year value.

        Returns:
            List of simplified factor objects representing pairwise
            interactions. Each is a dict with keys ``'nodes'`` (the two THRML
            nodes), ``'strength'`` (a 2x2 array) and ``'type'``
            (``'age_sex'`` or ``'country_year'``).
        """
        factors = []

        for age in self.ages[::10]:
            for sex in self.sexes:
                factors.append({
                    'nodes': [self.age_nodes[age], self.sex_nodes[sex]],
                    'strength': self._compute_age_sex_interaction(age, sex),
                    'type': 'age_sex'
                })

        for country in self.countries:
            for year in self.years[::2]:
                factors.append({
                    'nodes': [self.country_nodes[country], self.year_nodes[year]],
                    'strength': self._compute_country_year_interaction(country, year),
                    'type': 'country_year'
                })

        return factors

    @staticmethod
    def _default_interaction() -> jnp.ndarray:
        """Return the weak default coupling used when the data gives no estimate."""
        return jnp.array([[0.1, 0.0], [0.0, 0.1]])

    @staticmethod
    def _symmetric_interaction(strength: float, off_diagonal_divisor: float) -> jnp.ndarray:
        """Return a 2x2 matrix with ``strength`` on the diagonal and
        ``-strength / off_diagonal_divisor`` off the diagonal."""
        off_diagonal = -strength / off_diagonal_divisor
        return jnp.array([[strength, off_diagonal], [off_diagonal, strength]])

    def _compute_age_sex_interaction(self, age: int, sex: int) -> jnp.ndarray:
        """
        Compute interaction strength between age and sex from data.

        The strength is ten times the mean death rate of the matching rows,
        capped at 1.0. When there are no matching rows, or their mean is NaN
        (all rates missing), the default weak coupling is returned.
        """
        subset = self.df[(self.df['age'] == age) & (self.df['sex'] == sex)]
        if len(subset) == 0:
            return self._default_interaction()

        avg_death_rate = subset['death_rate'].mean()
        if pd.isna(avg_death_rate):
            # min(nan, 1.0) would propagate NaN into the factor matrix.
            return self._default_interaction()

        strength = min(float(avg_death_rate) * 10, 1.0)
        return self._symmetric_interaction(strength, 2)

    def _compute_country_year_interaction(self, country: str, year: int) -> jnp.ndarray:
        """
        Compute interaction strength between country and year.

        The strength is the sample variance of life expectancy over the
        matching rows divided by 100, capped at 1.0. When there are no
        matching rows, or the variance is undefined (a single row, or missing
        values), the default weak coupling is returned.
        """
        subset = self.df[(self.df['country'] == country) & (self.df['year'] == year)]
        if len(subset) == 0:
            return self._default_interaction()

        # Series.var() uses ddof=1, so it is NaN for a single observation.
        life_exp_var = subset['life_expectancy'].var()
        if pd.isna(life_exp_var):
            return self._default_interaction()

        strength = min(float(life_exp_var) / 100, 1.0)
        return self._symmetric_interaction(strength, 3)

    def get_mortality_prediction_nodes(self,
                                       age: int,
                                       country: str,
                                       sex: int) -> Dict[str, Any]:
        """
        Get the relevant nodes for mortality prediction given demographics.

        Args:
            age: Age value
            country: Country name
            sex: Sex value (1=male, 2=female, 3=both)

        Returns:
            Dictionary mapping node types to THRML nodes. ``'age_node'``,
            ``'country_node'`` and ``'sex_node'`` are ``None`` when the value
            was not present in the data. The two outcome entries are the
            builder's own bin-index -> node dicts (not copies).
        """
        return {
            'age_node': self.age_nodes.get(age),
            'country_node': self.country_nodes.get(country),
            'sex_node': self.sex_nodes.get(sex),
            'life_expectancy_nodes': self.life_expectancy_nodes,
            'death_probability_nodes': self.death_probability_nodes
        }

    @staticmethod
    def _clamped_bin(value: float, edges: jnp.ndarray) -> int:
        """Return the index of the bin containing ``value``, clamped to the
        first/last bin for values on or beyond the outer edges."""
        last_bin = len(edges) - 2
        # digitize returns 0 below the first edge and len(edges) at or above
        # the last edge, so the raw index can be -1 or last_bin + 1.
        raw_index = int(jnp.digitize(value, edges)) - 1
        return max(0, min(raw_index, last_bin))

    def discretize_life_expectancy(self, life_exp: float) -> int:
        """
        Convert continuous life expectancy to discrete bin index.

        The result is always a valid key of ``life_expectancy_nodes``; values
        below the first edge map to bin 0 and values at or above the last edge
        map to the last bin.
        """
        return self._clamped_bin(life_exp, self.life_exp_bins)

    def discretize_death_probability(self, death_prob: float) -> int:
        """
        Convert continuous death probability to discrete bin index.

        The result is always a valid key of ``death_probability_nodes``; in
        particular a probability of exactly 1.0 maps to the last bin.
        """
        return self._clamped_bin(death_prob, self.death_prob_bins)

    def continuous_from_bin(self, bin_idx: int, bin_type: str) -> float:
        """
        Convert bin index back to continuous value (bin center).

        Args:
            bin_idx: Bin index.
            bin_type: "life_expectancy" or "death_probability".

        Returns:
            The bin midpoint as a Python float, or 0.0 when ``bin_type`` is
            unknown or ``bin_idx`` is out of range.
        """
        if bin_type == "life_expectancy":
            edges = self.life_exp_bins
        elif bin_type == "death_probability":
            edges = self.death_prob_bins
        else:
            return 0.0

        if not 0 <= bin_idx < len(edges) - 1:
            return 0.0
        return float((edges[bin_idx] + edges[bin_idx + 1]) / 2)