h3ir committed on
Commit
1cc16b2
·
verified ·
1 Parent(s): 35679ce

Upload mortality_graph.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mortality_graph.py +355 -0
mortality_graph.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mortality Graph Construction for THRML Integration
3
+ =================================================
4
+
5
+ This module converts Morbid AI's MortalityRecord data structure into
6
+ THRML-compatible probabilistic graphical models.
7
+ """
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import networkx as nx
12
+ from typing import List, Dict, Tuple, Optional
13
+ import pandas as pd
14
+ from dataclasses import dataclass
15
+
16
+ from thrml.pgm import CategoricalNode, SpinNode
17
+ from thrml.block_management import Block
18
+ from thrml.factor import AbstractFactor
19
+
20
+
21
@dataclass
class MortalityRecord:
    """Morbid AI mortality record structure.

    A plain value object holding one row of mortality data: the demographic
    key (country, year, sex, age) followed by the mortality measures for that
    key. The symbols in the field comments (m(x), q(x), ...) are the labels
    the original author attached to each measure.
    """

    # --- demographic key -------------------------------------------------
    country: str
    year: int
    sex: int  # coded value: 1=male, 2=female, 3=both
    age: int

    # --- mortality measures ----------------------------------------------
    deathRate: float  # m(x)
    deathProbability: float  # q(x)
    survivors: float  # l(x)
    deaths: float  # d(x)
    lifeExpectancy: float  # e(x)
33
+
34
+
35
class MortalityGraphBuilder:
    """
    Builds THRML-compatible probabilistic graphical models from mortality data.

    This class creates heterogeneous graphs that capture complex interactions
    between demographic factors (age, country, sex, year) and mortality outcomes.

    One THRML node is created per distinct country, year, sex and age value
    found in the data, plus one node per life-expectancy bin (20 bins over
    0..100) and per death-probability bin (10 bins over 0..1).
    """

    def __init__(self, mortality_data: "List[MortalityRecord]"):
        """
        Initialize with mortality data.

        Args:
            mortality_data: List of MortalityRecord objects. May be empty, in
                which case only the outcome-bin nodes are created.
        """
        self.mortality_data = mortality_data

        rows = [
            {
                'country': record.country,
                'year': record.year,
                'sex': record.sex,
                'age': record.age,
                'death_rate': record.deathRate,
                'death_probability': record.deathProbability,
                'survivors': record.survivors,
                'deaths': record.deaths,
                'life_expectancy': record.lifeExpectancy,
            }
            for record in mortality_data
        ]
        # Columns are named explicitly so that an empty input still yields a
        # frame with the expected columns instead of a KeyError further down.
        self.df = pd.DataFrame(
            rows,
            columns=[
                'country', 'year', 'sex', 'age', 'death_rate',
                'death_probability', 'survivors', 'deaths', 'life_expectancy',
            ],
        )

        # Extract unique values for graph construction
        self.countries = sorted(self.df['country'].unique())
        self.years = sorted(self.df['year'].unique())
        self.sexes = sorted(self.df['sex'].unique())
        self.ages = sorted(self.df['age'].unique())

        # Create node mappings
        self._create_node_mappings()

    def _create_node_mappings(self):
        """Create mappings from data values (and bin indices) to graph nodes."""
        self.country_nodes = {country: CategoricalNode() for country in self.countries}
        self.year_nodes = {year: CategoricalNode() for year in self.years}
        self.sex_nodes = {sex: SpinNode() for sex in self.sexes}  # Binary-like representation
        self.age_nodes = {age: CategoricalNode() for age in self.ages}

        # Bin edges used to discretize the continuous outcomes; N + 1 edges
        # delimit N bins.
        self.life_exp_bins = jnp.linspace(0, 100, 21)  # 20 bins for life expectancy
        self.death_prob_bins = jnp.linspace(0, 1, 11)  # 10 bins for death probability

        # One outcome node per bin, keyed by bin index.
        self.life_expectancy_nodes = {
            i: CategoricalNode() for i in range(len(self.life_exp_bins) - 1)
        }
        self.death_probability_nodes = {
            i: CategoricalNode() for i in range(len(self.death_prob_bins) - 1)
        }

    def build_mortality_graph(self) -> "nx.Graph":
        """
        Build NetworkX graph representing mortality factor interactions.

        Returns:
            NetworkX graph with nodes representing demographic factors
            and edges representing interactions. Every graph node carries its
            THRML node in the ``thrml_node`` attribute.
        """
        G = nx.Graph()

        # Demographic nodes: graph-node name is "<type>_<value>".
        demographic_groups = (
            ("country", self.country_nodes),
            ("year", self.year_nodes),
            ("sex", self.sex_nodes),
            ("age", self.age_nodes),
        )
        for node_type, mapping in demographic_groups:
            for value, node in mapping.items():
                G.add_node(f"{node_type}_{value}", type=node_type,
                           value=value, thrml_node=node)

        # Outcome nodes: graph-node name is "<prefix>_<bin index>".
        outcome_groups = (
            ("life_exp", "life_expectancy", self.life_expectancy_nodes),
            ("death_prob", "death_probability", self.death_probability_nodes),
        )
        for prefix, node_type, mapping in outcome_groups:
            for bin_idx, node in mapping.items():
                G.add_node(f"{prefix}_{bin_idx}", type=node_type,
                           bin_idx=bin_idx, thrml_node=node)

        # Add edges representing factor interactions
        self._add_demographic_interactions(G)
        self._add_outcome_interactions(G)

        return G

    def _add_demographic_interactions(self, G: "nx.Graph"):
        """Add edges between demographic factor nodes (mutates ``G``)."""
        # Age-Sex interactions (biological mortality differences)
        for age in self.ages:
            for sex in self.sexes:
                G.add_edge(f"age_{age}", f"sex_{sex}", interaction_type="age_sex")

        # Country-Year interactions (temporal mortality trends by country)
        for country in self.countries:
            for year in self.years:
                G.add_edge(f"country_{country}", f"year_{year}",
                           interaction_type="country_year")

        # Age-Country interactions (demographic mortality patterns)
        for age in self.ages[::5]:  # Sample every 5th age to reduce complexity
            for country in self.countries:
                G.add_edge(f"age_{age}", f"country_{country}",
                           interaction_type="age_country")

    def _add_outcome_interactions(self, G: "nx.Graph"):
        """Add edges between demographic factors and mortality outcomes (mutates ``G``)."""
        # Connect age groups to life expectancy bins
        for age in self.ages[::10]:  # Sample to reduce complexity
            for le_bin in range(len(self.life_expectancy_nodes)):
                G.add_edge(f"age_{age}", f"life_exp_{le_bin}",
                           interaction_type="age_life_expectancy")

        # Connect demographic factors to death probability
        for country in self.countries:
            for dp_bin in range(len(self.death_probability_nodes)):
                G.add_edge(f"country_{country}", f"death_prob_{dp_bin}",
                           interaction_type="country_death_probability")

    def create_sampling_blocks(self, strategy: str = "two_color") -> "List[Block]":
        """
        Create sampling blocks for THRML block Gibbs sampling.

        Each block contains nodes of a single THRML node type (categorical or
        spin); empty groups are never turned into a Block.

        NOTE(review): "two_color" groups nodes by node type, not by a graph
        colouring -- e.g. country and year nodes share a block although
        ``_add_demographic_interactions`` connects them. Confirm this is
        acceptable for the sampler in use.

        Args:
            strategy: Blocking strategy ("two_color", "demographic", "outcome")

        Returns:
            List of Block objects for THRML sampling

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        categorical_demo = [
            *self.country_nodes.values(),
            *self.year_nodes.values(),
            *self.age_nodes.values(),
        ]
        spin_demo = list(self.sex_nodes.values())
        life_exp_nodes = list(self.life_expectancy_nodes.values())
        death_prob_nodes = list(self.death_probability_nodes.values())

        if strategy == "two_color":
            # Simple two-color blocking with homogeneous node types
            categorical_nodes = categorical_demo + life_exp_nodes + death_prob_nodes
            if categorical_nodes and spin_demo:
                return [Block(categorical_nodes), Block(spin_demo)]
            if categorical_nodes:
                # No spin nodes: split the categorical nodes alternately.
                return [Block(categorical_nodes[::2]), Block(categorical_nodes[1::2])]
            return [Block(spin_demo)]

        if strategy == "demographic":
            # Block by demographic factor types - separate by node type
            groups = [categorical_demo, spin_demo, life_exp_nodes + death_prob_nodes]
        elif strategy == "outcome":
            # Block by outcome type - keep node types separate
            groups = [life_exp_nodes, death_prob_nodes, categorical_demo, spin_demo]
        else:
            raise ValueError(f"Unknown blocking strategy: {strategy}")

        return [Block(group) for group in groups if group]

    def create_interaction_factors(self) -> List[Dict]:
        """
        Create interaction factors for the energy-based model.

        Only every 10th age and every 2nd year are used, to limit the number
        of factors.

        Returns:
            List of simplified factor objects (plain dicts with 'nodes',
            'strength' and 'type' keys) representing pairwise interactions.
            These are placeholders, not THRML factor instances.
        """
        factors = []

        # Age-Sex interaction factors
        for age in self.ages[::10]:  # Sample to manage complexity
            age_node = self.age_nodes[age]
            for sex in self.sexes:
                factors.append({
                    'nodes': [age_node, self.sex_nodes[sex]],
                    'strength': self._compute_age_sex_interaction(age, sex),
                    'type': 'age_sex'
                })

        # Country-Year interaction factors
        for country in self.countries:
            country_node = self.country_nodes[country]
            for year in self.years[::2]:  # Sample years
                factors.append({
                    'nodes': [country_node, self.year_nodes[year]],
                    'strength': self._compute_country_year_interaction(country, year),
                    'type': 'country_year'
                })

        return factors

    @staticmethod
    def _default_interaction() -> jnp.ndarray:
        """Return the weak default 2x2 interaction used when data is unusable."""
        return jnp.array([[0.1, 0.0], [0.0, 0.1]])

    def _compute_age_sex_interaction(self, age: int, sex: int) -> jnp.ndarray:
        """Compute interaction strength between age and sex from data.

        Returns a symmetric 2x2 matrix whose magnitude is ten times the mean
        death rate of the matching rows, capped at 1.0. Falls back to the weak
        default matrix when there are no matching rows or no usable death rate.
        """
        # Filter data for this age-sex combination
        subset = self.df[(self.df['age'] == age) & (self.df['sex'] == sex)]
        if len(subset) == 0:
            # Default weak interaction if no data
            return self._default_interaction()

        # Use death rate as proxy for interaction strength
        avg_death_rate = subset['death_rate'].mean()
        if pd.isna(avg_death_rate):
            # All death rates missing: min(nan, 1.0) would propagate NaN.
            return self._default_interaction()

        # Higher death rates = stronger interaction
        strength = min(avg_death_rate * 10, 1.0)  # Cap at 1.0
        return jnp.array([[strength, -strength/2], [-strength/2, strength]])

    def _compute_country_year_interaction(self, country: str, year: int) -> jnp.ndarray:
        """Compute interaction strength between country and year.

        Returns a symmetric 2x2 matrix whose magnitude is the sample variance
        of life expectancy over the matching rows divided by 100, capped at
        1.0. Falls back to the weak default matrix when no rows match.
        """
        subset = self.df[(self.df['country'] == country) & (self.df['year'] == year)]
        if len(subset) == 0:
            return self._default_interaction()

        # Use life expectancy variance as interaction strength
        life_exp_var = subset['life_expectancy'].var()
        if pd.isna(life_exp_var):
            # Sample variance is undefined (NaN) for a single observation;
            # treat it as "no spread" rather than emitting a NaN matrix.
            life_exp_var = 0.0
        strength = min(life_exp_var / 100, 1.0)  # Normalize and cap

        return jnp.array([[strength, -strength/3], [-strength/3, strength]])

    def get_mortality_prediction_nodes(self,
                                       age: int,
                                       country: str,
                                       sex: int) -> Dict[str, object]:
        """
        Get the relevant nodes for mortality prediction given demographics.

        Args:
            age: Age value
            country: Country name
            sex: Sex value (1=male, 2=female, 3=both)

        Returns:
            Dictionary mapping node types to THRML nodes. The demographic
            entries are None when the value was not present in the data; the
            two outcome entries are the full bin-index -> node mappings.
        """
        return {
            'age_node': self.age_nodes.get(age),
            'country_node': self.country_nodes.get(country),
            'sex_node': self.sex_nodes.get(sex),
            'life_expectancy_nodes': self.life_expectancy_nodes,
            'death_probability_nodes': self.death_probability_nodes
        }

    @staticmethod
    def _bin_index(value: float, edges) -> int:
        """Return the index of the bin containing ``value``, clamped to a valid bin.

        ``edges`` holds N + 1 ascending edges for N bins. Values below the
        first edge map to bin 0 and values at or above the last edge map to
        bin N - 1, so the result always identifies an existing bin.
        """
        last_bin = len(edges) - 2
        raw = int(jnp.digitize(value, edges)) - 1
        return max(0, min(raw, last_bin))

    def discretize_life_expectancy(self, life_exp: float) -> int:
        """Convert continuous life expectancy to discrete bin index.

        The result is always a valid key of ``life_expectancy_nodes``;
        out-of-range inputs (including exactly the top edge) are clamped.
        """
        return self._bin_index(life_exp, self.life_exp_bins)

    def discretize_death_probability(self, death_prob: float) -> int:
        """Convert continuous death probability to discrete bin index.

        The result is always a valid key of ``death_probability_nodes``;
        a probability of exactly 1.0 falls into the last bin.
        """
        return self._bin_index(death_prob, self.death_prob_bins)

    def continuous_from_bin(self, bin_idx: int, bin_type: str) -> float:
        """Convert bin index back to continuous value (bin center).

        Args:
            bin_idx: Zero-based bin index.
            bin_type: "life_expectancy" or "death_probability".

        Returns:
            The bin's midpoint as a Python float, or 0.0 when ``bin_type`` is
            unknown or ``bin_idx`` is out of range.
        """
        if bin_type == "life_expectancy":
            edges = self.life_exp_bins
        elif bin_type == "death_probability":
            edges = self.death_prob_bins
        else:
            return 0.0

        if 0 <= bin_idx < len(edges) - 1:
            # float() converts the 0-d jax array to the annotated return type.
            return float((edges[bin_idx] + edges[bin_idx + 1]) / 2)
        return 0.0