Upload mortality_graph.py with huggingface_hub
Browse files- mortality_graph.py +355 -0
mortality_graph.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mortality Graph Construction for THRML Integration
|
| 3 |
+
=================================================
|
| 4 |
+
|
| 5 |
+
This module converts Morbid AI's MortalityRecord data structure into
|
| 6 |
+
THRML-compatible probabilistic graphical models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import jax
|
| 10 |
+
import jax.numpy as jnp
|
| 11 |
+
import networkx as nx
|
| 12 |
+
from typing import List, Dict, Tuple, Optional
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
from thrml.pgm import CategoricalNode, SpinNode
|
| 17 |
+
from thrml.block_management import Block
|
| 18 |
+
from thrml.factor import AbstractFactor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class MortalityRecord:
|
| 23 |
+
"""Morbid AI mortality record structure"""
|
| 24 |
+
country: str
|
| 25 |
+
year: int
|
| 26 |
+
sex: int # 1=male, 2=female, 3=both
|
| 27 |
+
age: int
|
| 28 |
+
deathRate: float # m(x)
|
| 29 |
+
deathProbability: float # q(x)
|
| 30 |
+
survivors: float # l(x)
|
| 31 |
+
deaths: float # d(x)
|
| 32 |
+
lifeExpectancy: float # e(x)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MortalityGraphBuilder:
|
| 36 |
+
"""
|
| 37 |
+
Builds THRML-compatible probabilistic graphical models from mortality data.
|
| 38 |
+
|
| 39 |
+
This class creates heterogeneous graphs that capture complex interactions
|
| 40 |
+
between demographic factors (age, country, sex, year) and mortality outcomes.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, mortality_data: List[MortalityRecord]):
|
| 44 |
+
"""
|
| 45 |
+
Initialize with mortality data.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
mortality_data: List of MortalityRecord objects
|
| 49 |
+
"""
|
| 50 |
+
self.mortality_data = mortality_data
|
| 51 |
+
self.df = pd.DataFrame([
|
| 52 |
+
{
|
| 53 |
+
'country': record.country,
|
| 54 |
+
'year': record.year,
|
| 55 |
+
'sex': record.sex,
|
| 56 |
+
'age': record.age,
|
| 57 |
+
'death_rate': record.deathRate,
|
| 58 |
+
'death_probability': record.deathProbability,
|
| 59 |
+
'survivors': record.survivors,
|
| 60 |
+
'deaths': record.deaths,
|
| 61 |
+
'life_expectancy': record.lifeExpectancy
|
| 62 |
+
} for record in mortality_data
|
| 63 |
+
])
|
| 64 |
+
|
| 65 |
+
# Extract unique values for graph construction
|
| 66 |
+
self.countries = sorted(self.df['country'].unique())
|
| 67 |
+
self.years = sorted(self.df['year'].unique())
|
| 68 |
+
self.sexes = sorted(self.df['sex'].unique())
|
| 69 |
+
self.ages = sorted(self.df['age'].unique())
|
| 70 |
+
|
| 71 |
+
# Create node mappings
|
| 72 |
+
self._create_node_mappings()
|
| 73 |
+
|
| 74 |
+
def _create_node_mappings(self):
|
| 75 |
+
"""Create mappings from data values to graph nodes."""
|
| 76 |
+
self.country_nodes = {country: CategoricalNode() for country in self.countries}
|
| 77 |
+
self.year_nodes = {year: CategoricalNode() for year in self.years}
|
| 78 |
+
self.sex_nodes = {sex: SpinNode() for sex in self.sexes} # Binary-like representation
|
| 79 |
+
self.age_nodes = {age: CategoricalNode() for age in self.ages}
|
| 80 |
+
|
| 81 |
+
# Create outcome nodes for mortality metrics
|
| 82 |
+
self.life_expectancy_nodes = {}
|
| 83 |
+
self.death_probability_nodes = {}
|
| 84 |
+
|
| 85 |
+
# Discretize life expectancy and death probability for categorical representation
|
| 86 |
+
self.life_exp_bins = jnp.linspace(0, 100, 21) # 20 bins for life expectancy
|
| 87 |
+
self.death_prob_bins = jnp.linspace(0, 1, 11) # 10 bins for death probability
|
| 88 |
+
|
| 89 |
+
for i in range(len(self.life_exp_bins) - 1):
|
| 90 |
+
self.life_expectancy_nodes[i] = CategoricalNode()
|
| 91 |
+
|
| 92 |
+
for i in range(len(self.death_prob_bins) - 1):
|
| 93 |
+
self.death_probability_nodes[i] = CategoricalNode()
|
| 94 |
+
|
| 95 |
+
def build_mortality_graph(self) -> nx.Graph:
|
| 96 |
+
"""
|
| 97 |
+
Build NetworkX graph representing mortality factor interactions.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
NetworkX graph with nodes representing demographic factors
|
| 101 |
+
and edges representing interactions.
|
| 102 |
+
"""
|
| 103 |
+
G = nx.Graph()
|
| 104 |
+
|
| 105 |
+
# Add nodes with attributes
|
| 106 |
+
for country, node in self.country_nodes.items():
|
| 107 |
+
G.add_node(f"country_{country}", type="country", value=country, thrml_node=node)
|
| 108 |
+
|
| 109 |
+
for year, node in self.year_nodes.items():
|
| 110 |
+
G.add_node(f"year_{year}", type="year", value=year, thrml_node=node)
|
| 111 |
+
|
| 112 |
+
for sex, node in self.sex_nodes.items():
|
| 113 |
+
G.add_node(f"sex_{sex}", type="sex", value=sex, thrml_node=node)
|
| 114 |
+
|
| 115 |
+
for age, node in self.age_nodes.items():
|
| 116 |
+
G.add_node(f"age_{age}", type="age", value=age, thrml_node=node)
|
| 117 |
+
|
| 118 |
+
# Add outcome nodes
|
| 119 |
+
for bin_idx, node in self.life_expectancy_nodes.items():
|
| 120 |
+
G.add_node(f"life_exp_{bin_idx}", type="life_expectancy",
|
| 121 |
+
bin_idx=bin_idx, thrml_node=node)
|
| 122 |
+
|
| 123 |
+
for bin_idx, node in self.death_probability_nodes.items():
|
| 124 |
+
G.add_node(f"death_prob_{bin_idx}", type="death_probability",
|
| 125 |
+
bin_idx=bin_idx, thrml_node=node)
|
| 126 |
+
|
| 127 |
+
# Add edges representing factor interactions
|
| 128 |
+
self._add_demographic_interactions(G)
|
| 129 |
+
self._add_outcome_interactions(G)
|
| 130 |
+
|
| 131 |
+
return G
|
| 132 |
+
|
| 133 |
+
def _add_demographic_interactions(self, G: nx.Graph):
|
| 134 |
+
"""Add edges between demographic factor nodes."""
|
| 135 |
+
# Age-Sex interactions (biological mortality differences)
|
| 136 |
+
for age in self.ages:
|
| 137 |
+
for sex in self.sexes:
|
| 138 |
+
G.add_edge(f"age_{age}", f"sex_{sex}", interaction_type="age_sex")
|
| 139 |
+
|
| 140 |
+
# Country-Year interactions (temporal mortality trends by country)
|
| 141 |
+
for country in self.countries:
|
| 142 |
+
for year in self.years:
|
| 143 |
+
G.add_edge(f"country_{country}", f"year_{year}",
|
| 144 |
+
interaction_type="country_year")
|
| 145 |
+
|
| 146 |
+
# Age-Country interactions (demographic mortality patterns)
|
| 147 |
+
for age in self.ages[::5]: # Sample every 5th age to reduce complexity
|
| 148 |
+
for country in self.countries:
|
| 149 |
+
G.add_edge(f"age_{age}", f"country_{country}",
|
| 150 |
+
interaction_type="age_country")
|
| 151 |
+
|
| 152 |
+
def _add_outcome_interactions(self, G: nx.Graph):
|
| 153 |
+
"""Add edges between demographic factors and mortality outcomes."""
|
| 154 |
+
# Connect age groups to life expectancy bins
|
| 155 |
+
for age in self.ages[::10]: # Sample to reduce complexity
|
| 156 |
+
for le_bin in range(len(self.life_expectancy_nodes)):
|
| 157 |
+
G.add_edge(f"age_{age}", f"life_exp_{le_bin}",
|
| 158 |
+
interaction_type="age_life_expectancy")
|
| 159 |
+
|
| 160 |
+
# Connect demographic factors to death probability
|
| 161 |
+
for country in self.countries:
|
| 162 |
+
for dp_bin in range(len(self.death_probability_nodes)):
|
| 163 |
+
G.add_edge(f"country_{country}", f"death_prob_{dp_bin}",
|
| 164 |
+
interaction_type="country_death_probability")
|
| 165 |
+
|
| 166 |
+
def create_sampling_blocks(self, strategy: str = "two_color") -> List[Block]:
|
| 167 |
+
"""
|
| 168 |
+
Create sampling blocks for THRML block Gibbs sampling.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
strategy: Blocking strategy ("two_color", "demographic", "outcome")
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
List of Block objects for THRML sampling
|
| 175 |
+
"""
|
| 176 |
+
all_nodes = []
|
| 177 |
+
|
| 178 |
+
# Collect all THRML nodes
|
| 179 |
+
all_nodes.extend(list(self.country_nodes.values()))
|
| 180 |
+
all_nodes.extend(list(self.year_nodes.values()))
|
| 181 |
+
all_nodes.extend(list(self.sex_nodes.values()))
|
| 182 |
+
all_nodes.extend(list(self.age_nodes.values()))
|
| 183 |
+
all_nodes.extend(list(self.life_expectancy_nodes.values()))
|
| 184 |
+
all_nodes.extend(list(self.death_probability_nodes.values()))
|
| 185 |
+
|
| 186 |
+
if strategy == "two_color":
|
| 187 |
+
# Simple two-color blocking with homogeneous node types
|
| 188 |
+
categorical_nodes = (list(self.country_nodes.values()) +
|
| 189 |
+
list(self.year_nodes.values()) +
|
| 190 |
+
list(self.age_nodes.values()) +
|
| 191 |
+
list(self.life_expectancy_nodes.values()) +
|
| 192 |
+
list(self.death_probability_nodes.values()))
|
| 193 |
+
spin_nodes = list(self.sex_nodes.values())
|
| 194 |
+
|
| 195 |
+
# Create separate blocks for different node types
|
| 196 |
+
if categorical_nodes and spin_nodes:
|
| 197 |
+
return [Block(categorical_nodes), Block(spin_nodes)]
|
| 198 |
+
elif categorical_nodes:
|
| 199 |
+
return [Block(categorical_nodes[::2]), Block(categorical_nodes[1::2])]
|
| 200 |
+
else:
|
| 201 |
+
return [Block(spin_nodes)]
|
| 202 |
+
|
| 203 |
+
elif strategy == "demographic":
|
| 204 |
+
# Block by demographic factor types - separate by node type
|
| 205 |
+
categorical_demographic = (list(self.country_nodes.values()) +
|
| 206 |
+
list(self.year_nodes.values()) +
|
| 207 |
+
list(self.age_nodes.values()))
|
| 208 |
+
spin_demographic = list(self.sex_nodes.values())
|
| 209 |
+
outcome_nodes = (list(self.life_expectancy_nodes.values()) +
|
| 210 |
+
list(self.death_probability_nodes.values()))
|
| 211 |
+
|
| 212 |
+
blocks = []
|
| 213 |
+
if categorical_demographic:
|
| 214 |
+
blocks.append(Block(categorical_demographic))
|
| 215 |
+
if spin_demographic:
|
| 216 |
+
blocks.append(Block(spin_demographic))
|
| 217 |
+
if outcome_nodes:
|
| 218 |
+
blocks.append(Block(outcome_nodes))
|
| 219 |
+
return blocks
|
| 220 |
+
|
| 221 |
+
elif strategy == "outcome":
|
| 222 |
+
# Block by outcome type - keep node types separate
|
| 223 |
+
life_exp_nodes = list(self.life_expectancy_nodes.values())
|
| 224 |
+
death_prob_nodes = list(self.death_probability_nodes.values())
|
| 225 |
+
categorical_demo = (list(self.country_nodes.values()) +
|
| 226 |
+
list(self.year_nodes.values()) +
|
| 227 |
+
list(self.age_nodes.values()))
|
| 228 |
+
spin_demo = list(self.sex_nodes.values())
|
| 229 |
+
|
| 230 |
+
blocks = []
|
| 231 |
+
if life_exp_nodes:
|
| 232 |
+
blocks.append(Block(life_exp_nodes))
|
| 233 |
+
if death_prob_nodes:
|
| 234 |
+
blocks.append(Block(death_prob_nodes))
|
| 235 |
+
if categorical_demo:
|
| 236 |
+
blocks.append(Block(categorical_demo))
|
| 237 |
+
if spin_demo:
|
| 238 |
+
blocks.append(Block(spin_demo))
|
| 239 |
+
return blocks
|
| 240 |
+
|
| 241 |
+
else:
|
| 242 |
+
raise ValueError(f"Unknown blocking strategy: {strategy}")
|
| 243 |
+
|
| 244 |
+
def create_interaction_factors(self) -> List[Dict]:
|
| 245 |
+
"""
|
| 246 |
+
Create interaction factors for the energy-based model.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
List of simplified factor objects representing
|
| 250 |
+
pairwise and higher-order interactions
|
| 251 |
+
"""
|
| 252 |
+
factors = []
|
| 253 |
+
|
| 254 |
+
# Age-Sex interaction factors
|
| 255 |
+
for age in self.ages[::10]: # Sample to manage complexity
|
| 256 |
+
for sex in self.sexes:
|
| 257 |
+
age_node = self.age_nodes[age]
|
| 258 |
+
sex_node = self.sex_nodes[sex]
|
| 259 |
+
|
| 260 |
+
# Create interaction matrix based on mortality data
|
| 261 |
+
interaction_strength = self._compute_age_sex_interaction(age, sex)
|
| 262 |
+
# For now, create a simplified factor representation
|
| 263 |
+
# In a full implementation, would create proper THRML factors
|
| 264 |
+
factors.append({
|
| 265 |
+
'nodes': [age_node, sex_node],
|
| 266 |
+
'strength': interaction_strength,
|
| 267 |
+
'type': 'age_sex'
|
| 268 |
+
})
|
| 269 |
+
|
| 270 |
+
# Country-Year interaction factors
|
| 271 |
+
for country in self.countries:
|
| 272 |
+
for year in self.years[::2]: # Sample years
|
| 273 |
+
country_node = self.country_nodes[country]
|
| 274 |
+
year_node = self.year_nodes[year]
|
| 275 |
+
|
| 276 |
+
interaction_strength = self._compute_country_year_interaction(country, year)
|
| 277 |
+
# Simplified factor representation
|
| 278 |
+
factors.append({
|
| 279 |
+
'nodes': [country_node, year_node],
|
| 280 |
+
'strength': interaction_strength,
|
| 281 |
+
'type': 'country_year'
|
| 282 |
+
})
|
| 283 |
+
|
| 284 |
+
return factors
|
| 285 |
+
|
| 286 |
+
def _compute_age_sex_interaction(self, age: int, sex: int) -> jnp.ndarray:
|
| 287 |
+
"""Compute interaction strength between age and sex from data."""
|
| 288 |
+
# Filter data for this age-sex combination
|
| 289 |
+
subset = self.df[(self.df['age'] == age) & (self.df['sex'] == sex)]
|
| 290 |
+
|
| 291 |
+
if len(subset) == 0:
|
| 292 |
+
# Default weak interaction if no data
|
| 293 |
+
return jnp.array([[0.1, 0.0], [0.0, 0.1]])
|
| 294 |
+
|
| 295 |
+
# Use death rate as proxy for interaction strength
|
| 296 |
+
avg_death_rate = subset['death_rate'].mean()
|
| 297 |
+
|
| 298 |
+
# Create 2x2 interaction matrix
|
| 299 |
+
# Higher death rates = stronger interaction
|
| 300 |
+
strength = min(avg_death_rate * 10, 1.0) # Cap at 1.0
|
| 301 |
+
return jnp.array([[strength, -strength/2], [-strength/2, strength]])
|
| 302 |
+
|
| 303 |
+
def _compute_country_year_interaction(self, country: str, year: int) -> jnp.ndarray:
|
| 304 |
+
"""Compute interaction strength between country and year."""
|
| 305 |
+
subset = self.df[(self.df['country'] == country) & (self.df['year'] == year)]
|
| 306 |
+
|
| 307 |
+
if len(subset) == 0:
|
| 308 |
+
return jnp.array([[0.1, 0.0], [0.0, 0.1]])
|
| 309 |
+
|
| 310 |
+
# Use life expectancy variance as interaction strength
|
| 311 |
+
life_exp_var = subset['life_expectancy'].var()
|
| 312 |
+
strength = min(life_exp_var / 100, 1.0) # Normalize and cap
|
| 313 |
+
|
| 314 |
+
return jnp.array([[strength, -strength/3], [-strength/3, strength]])
|
| 315 |
+
|
| 316 |
+
def get_mortality_prediction_nodes(self,
|
| 317 |
+
age: int,
|
| 318 |
+
country: str,
|
| 319 |
+
sex: int) -> Dict[str, any]:
|
| 320 |
+
"""
|
| 321 |
+
Get the relevant nodes for mortality prediction given demographics.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
age: Age value
|
| 325 |
+
country: Country name
|
| 326 |
+
sex: Sex value (1=male, 2=female, 3=both)
|
| 327 |
+
|
| 328 |
+
Returns:
|
| 329 |
+
Dictionary mapping node types to THRML nodes
|
| 330 |
+
"""
|
| 331 |
+
return {
|
| 332 |
+
'age_node': self.age_nodes.get(age),
|
| 333 |
+
'country_node': self.country_nodes.get(country),
|
| 334 |
+
'sex_node': self.sex_nodes.get(sex),
|
| 335 |
+
'life_expectancy_nodes': self.life_expectancy_nodes,
|
| 336 |
+
'death_probability_nodes': self.death_probability_nodes
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
def discretize_life_expectancy(self, life_exp: float) -> int:
|
| 340 |
+
"""Convert continuous life expectancy to discrete bin index."""
|
| 341 |
+
return int(jnp.digitize(life_exp, self.life_exp_bins)) - 1
|
| 342 |
+
|
| 343 |
+
def discretize_death_probability(self, death_prob: float) -> int:
|
| 344 |
+
"""Convert continuous death probability to discrete bin index."""
|
| 345 |
+
return int(jnp.digitize(death_prob, self.death_prob_bins)) - 1
|
| 346 |
+
|
| 347 |
+
def continuous_from_bin(self, bin_idx: int, bin_type: str) -> float:
|
| 348 |
+
"""Convert bin index back to continuous value (bin center)."""
|
| 349 |
+
if bin_type == "life_expectancy":
|
| 350 |
+
if 0 <= bin_idx < len(self.life_exp_bins) - 1:
|
| 351 |
+
return (self.life_exp_bins[bin_idx] + self.life_exp_bins[bin_idx + 1]) / 2
|
| 352 |
+
elif bin_type == "death_probability":
|
| 353 |
+
if 0 <= bin_idx < len(self.death_prob_bins) - 1:
|
| 354 |
+
return (self.death_prob_bins[bin_idx] + self.death_prob_bins[bin_idx + 1]) / 2
|
| 355 |
+
return 0.0
|