File size: 15,186 Bytes
1cc16b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
"""
Mortality Graph Construction for THRML Integration
=================================================

This module converts Morbid AI's MortalityRecord data structure into
THRML-compatible probabilistic graphical models.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import networkx as nx
import pandas as pd

from thrml.block_management import Block
from thrml.factor import AbstractFactor
from thrml.pgm import CategoricalNode, SpinNode


@dataclass
class MortalityRecord:
    """A single row of a Morbid AI mortality (life-table) record.

    Field names keep the camelCase spelling of the Morbid AI structure this
    class mirrors. The trailing comments give the life-table symbol each
    value corresponds to.
    """

    # --- which population slice this row describes ---
    country: str
    year: int
    sex: int  # coded: 1 = male, 2 = female, 3 = both
    age: int

    # --- life-table values for that slice ---
    deathRate: float         # m(x), death rate
    deathProbability: float  # q(x), probability of death
    survivors: float         # l(x), survivors
    deaths: float            # d(x), deaths
    lifeExpectancy: float    # e(x), life expectancy


class MortalityGraphBuilder:
    """
    Builds THRML-compatible probabilistic graphical models from mortality data.

    This class creates heterogeneous graphs that capture interactions between
    demographic factors (age, country, sex, year) and mortality outcomes.

    Node layout:
        * one ``CategoricalNode`` per distinct country, year and age value;
        * one ``SpinNode`` per distinct sex code;
        * one ``CategoricalNode`` per life-expectancy bin (20 equal-width bins
          over [0, 100]) and per death-probability bin (10 equal-width bins
          over [0, 1]).
    """

    def __init__(self, mortality_data: "List[MortalityRecord]"):
        """
        Initialize with mortality data.

        Args:
            mortality_data: List of MortalityRecord objects. May be empty; then
                no demographic nodes are created, but the outcome-bin nodes
                still are.
        """
        self.mortality_data = mortality_data

        # Explicit ``columns`` keeps the schema even for an empty input; a
        # DataFrame built from ``[]`` alone has no columns and the
        # ``self.df['country']`` lookups below would raise KeyError.
        columns = [
            'country', 'year', 'sex', 'age', 'death_rate',
            'death_probability', 'survivors', 'deaths', 'life_expectancy',
        ]
        rows = [
            (
                record.country,
                record.year,
                record.sex,
                record.age,
                record.deathRate,
                record.deathProbability,
                record.survivors,
                record.deaths,
                record.lifeExpectancy,
            )
            for record in mortality_data
        ]
        self.df = pd.DataFrame(rows, columns=columns)

        # Sorted distinct values of each demographic dimension; each value
        # becomes one graph node.
        self.countries = sorted(self.df['country'].unique())
        self.years = sorted(self.df['year'].unique())
        self.sexes = sorted(self.df['sex'].unique())
        self.ages = sorted(self.df['age'].unique())

        self._create_node_mappings()

    def _create_node_mappings(self):
        """Create the THRML node for every demographic value and outcome bin."""
        self.country_nodes = {country: CategoricalNode() for country in self.countries}
        self.year_nodes = {year: CategoricalNode() for year in self.years}
        # Each sex code (1/2/3) gets its own SpinNode (binary-like representation).
        self.sex_nodes = {sex: SpinNode() for sex in self.sexes}
        self.age_nodes = {age: CategoricalNode() for age in self.ages}

        # Bin edges used to discretize the continuous outcomes.
        self.life_exp_bins = jnp.linspace(0, 100, 21)  # 21 edges -> 20 bins
        self.death_prob_bins = jnp.linspace(0, 1, 11)  # 11 edges -> 10 bins

        # One outcome node per bin, keyed by 0-based bin index.
        self.life_expectancy_nodes = {
            i: CategoricalNode() for i in range(len(self.life_exp_bins) - 1)
        }
        self.death_probability_nodes = {
            i: CategoricalNode() for i in range(len(self.death_prob_bins) - 1)
        }

    def build_mortality_graph(self) -> "nx.Graph":
        """
        Build a NetworkX graph representing mortality factor interactions.

        Node names are ``"<type>_<value>"`` for demographic nodes and
        ``"life_exp_<bin>"`` / ``"death_prob_<bin>"`` for outcome nodes; every
        node carries its THRML node in the ``thrml_node`` attribute.

        Returns:
            NetworkX graph with nodes representing demographic factors and
            outcome bins, and edges representing interactions.
        """
        G = nx.Graph()

        demographic_groups = (
            ("country", self.country_nodes),
            ("year", self.year_nodes),
            ("sex", self.sex_nodes),
            ("age", self.age_nodes),
        )
        for node_type, nodes in demographic_groups:
            for value, node in nodes.items():
                G.add_node(f"{node_type}_{value}", type=node_type, value=value,
                           thrml_node=node)

        for bin_idx, node in self.life_expectancy_nodes.items():
            G.add_node(f"life_exp_{bin_idx}", type="life_expectancy",
                       bin_idx=bin_idx, thrml_node=node)

        for bin_idx, node in self.death_probability_nodes.items():
            G.add_node(f"death_prob_{bin_idx}", type="death_probability",
                       bin_idx=bin_idx, thrml_node=node)

        self._add_demographic_interactions(G)
        self._add_outcome_interactions(G)

        return G

    def _add_demographic_interactions(self, G: "nx.Graph"):
        """Add edges between demographic factor nodes (mutates ``G`` in place)."""
        # Age-Sex interactions (biological mortality differences): all pairs.
        for age in self.ages:
            for sex in self.sexes:
                G.add_edge(f"age_{age}", f"sex_{sex}", interaction_type="age_sex")

        # Country-Year interactions (temporal mortality trends): all pairs.
        for country in self.countries:
            for year in self.years:
                G.add_edge(f"country_{country}", f"year_{year}",
                           interaction_type="country_year")

        # Age-Country interactions: only every 5th distinct age, to limit edge count.
        for age in self.ages[::5]:
            for country in self.countries:
                G.add_edge(f"age_{age}", f"country_{country}",
                           interaction_type="age_country")

    def _add_outcome_interactions(self, G: "nx.Graph"):
        """Add edges between demographic nodes and outcome-bin nodes (mutates ``G``)."""
        # Every 10th distinct age is linked to every life-expectancy bin.
        for age in self.ages[::10]:
            for le_bin in self.life_expectancy_nodes:
                G.add_edge(f"age_{age}", f"life_exp_{le_bin}",
                           interaction_type="age_life_expectancy")

        # Every country is linked to every death-probability bin.
        for country in self.countries:
            for dp_bin in self.death_probability_nodes:
                G.add_edge(f"country_{country}", f"death_prob_{dp_bin}",
                           interaction_type="country_death_probability")

    def create_sampling_blocks(self, strategy: str = "two_color") -> "List[Block]":
        """
        Create sampling blocks for THRML block Gibbs sampling.

        Every block contains a single node kind (all categorical or all spin).

        Args:
            strategy: Blocking strategy:
                * ``"two_color"``: one block of all categorical nodes and one of
                  all spin nodes. If there are no spin nodes, the categorical
                  nodes are split alternately into two blocks.
                * ``"demographic"``: categorical demographic nodes, spin (sex)
                  nodes, then all outcome nodes. Empty groups are omitted.
                * ``"outcome"``: life-expectancy nodes, death-probability nodes,
                  categorical demographic nodes, then spin nodes. Empty groups
                  are omitted.

        Returns:
            List of Block objects for THRML sampling.

        Raises:
            ValueError: If ``strategy`` is not one of the names above.
        """
        categorical_demographic = (list(self.country_nodes.values()) +
                                   list(self.year_nodes.values()) +
                                   list(self.age_nodes.values()))
        spin_demographic = list(self.sex_nodes.values())
        life_exp_nodes = list(self.life_expectancy_nodes.values())
        death_prob_nodes = list(self.death_probability_nodes.values())

        if strategy == "two_color":
            # NOTE(review): country/year/age/outcome nodes are connected to each
            # other in build_mortality_graph yet share one block here. If THRML
            # requires nodes within a block to be mutually non-adjacent, this
            # is not a valid two-coloring -- confirm.
            categorical_nodes = categorical_demographic + life_exp_nodes + death_prob_nodes
            if categorical_nodes and spin_demographic:
                return [Block(categorical_nodes), Block(spin_demographic)]
            if categorical_nodes:
                return [Block(categorical_nodes[::2]), Block(categorical_nodes[1::2])]
            return [Block(spin_demographic)]

        if strategy == "demographic":
            groups = [categorical_demographic, spin_demographic,
                      life_exp_nodes + death_prob_nodes]
        elif strategy == "outcome":
            groups = [life_exp_nodes, death_prob_nodes,
                      categorical_demographic, spin_demographic]
        else:
            raise ValueError(f"Unknown blocking strategy: {strategy}")

        return [Block(nodes) for nodes in groups if nodes]

    def create_interaction_factors(self) -> List[Dict]:
        """
        Create simplified interaction factors for the energy-based model.

        Factors are plain dicts (not THRML factor objects) with keys
        ``'nodes'`` (pair of THRML nodes), ``'strength'`` (2x2 matrix) and
        ``'type'`` (``'age_sex'`` or ``'country_year'``).

        Returns:
            List of factor dicts: age-sex pairs for every 10th distinct age,
            then country-year pairs for every 2nd distinct year.
        """
        factors = []

        for age in self.ages[::10]:
            for sex in self.sexes:
                factors.append({
                    'nodes': [self.age_nodes[age], self.sex_nodes[sex]],
                    'strength': self._compute_age_sex_interaction(age, sex),
                    'type': 'age_sex'
                })

        for country in self.countries:
            for year in self.years[::2]:
                factors.append({
                    'nodes': [self.country_nodes[country], self.year_nodes[year]],
                    'strength': self._compute_country_year_interaction(country, year),
                    'type': 'country_year'
                })

        return factors

    def _compute_age_sex_interaction(self, age: int, sex: int) -> jnp.ndarray:
        """
        Compute a 2x2 interaction matrix between an age and a sex value.

        The strength is ``min(10 * mean death rate, 1.0)``. The default weak
        matrix ``[[0.1, 0], [0, 0.1]]`` is returned when there are no rows for
        the pair, or when every death rate is missing. Previously the
        all-missing case produced a NaN matrix.
        """
        subset = self.df[(self.df['age'] == age) & (self.df['sex'] == sex)]
        if len(subset) == 0:
            return jnp.array([[0.1, 0.0], [0.0, 0.1]])

        # mean() skips NaN, so it is NaN only when every death rate is missing.
        avg_death_rate = subset['death_rate'].mean()
        if pd.isna(avg_death_rate):
            return jnp.array([[0.1, 0.0], [0.0, 0.1]])

        strength = min(avg_death_rate * 10, 1.0)  # higher death rate -> stronger, capped at 1.0
        return jnp.array([[strength, -strength / 2], [-strength / 2, strength]])

    def _compute_country_year_interaction(self, country: str, year: int) -> jnp.ndarray:
        """
        Compute a 2x2 interaction matrix between a country and a year.

        The strength is ``min(var(life expectancy) / 100, 1.0)``, using the
        pandas sample variance (ddof=1). The default weak matrix
        ``[[0.1, 0], [0, 0.1]]`` is returned when the pair has no rows, or
        when the variance is undefined (fewer than two non-missing values).
        Previously a single-row pair produced a NaN matrix, because
        ``min(nan, 1.0)`` evaluates to nan.
        """
        subset = self.df[(self.df['country'] == country) & (self.df['year'] == year)]
        if len(subset) == 0:
            return jnp.array([[0.1, 0.0], [0.0, 0.1]])

        life_exp_var = subset['life_expectancy'].var()
        if pd.isna(life_exp_var):
            return jnp.array([[0.1, 0.0], [0.0, 0.1]])

        strength = min(life_exp_var / 100, 1.0)  # normalize and cap at 1.0
        return jnp.array([[strength, -strength / 3], [-strength / 3, strength]])

    def get_mortality_prediction_nodes(self,
                                       age: int,
                                       country: str,
                                       sex: int) -> Dict[str, Any]:
        """
        Get the relevant nodes for mortality prediction given demographics.

        Args:
            age: Age value
            country: Country name
            sex: Sex value (1=male, 2=female, 3=both)

        Returns:
            Dictionary with ``'age_node'``, ``'country_node'`` and ``'sex_node'``
            (each ``None`` if the value was not present in the data), plus the
            full ``'life_expectancy_nodes'`` and ``'death_probability_nodes'``
            bin-index -> node mappings.
        """
        return {
            'age_node': self.age_nodes.get(age),
            'country_node': self.country_nodes.get(country),
            'sex_node': self.sex_nodes.get(sex),
            'life_expectancy_nodes': self.life_expectancy_nodes,
            'death_probability_nodes': self.death_probability_nodes
        }

    @staticmethod
    def _bin_index(value: float, edges: jnp.ndarray) -> int:
        """Return the 0-based bin of ``value`` for ``edges``, clamped to a valid bin."""
        # digitize returns len(edges) for values at/above the last edge (e.g.
        # q(x) == 1.0 or e(x) == 100) and 0 below the first edge; minus one,
        # both would name a bin that has no node.
        idx = int(jnp.digitize(value, edges)) - 1
        return min(max(idx, 0), len(edges) - 2)

    def discretize_life_expectancy(self, life_exp: float) -> int:
        """
        Convert a continuous life expectancy to a bin index in ``[0, 19]``.

        Values below 0 map to bin 0; values at or above 100 map to the last bin.
        """
        return self._bin_index(life_exp, self.life_exp_bins)

    def discretize_death_probability(self, death_prob: float) -> int:
        """
        Convert a continuous death probability to a bin index in ``[0, 9]``.

        Values below 0 map to bin 0; values at or above 1.0 map to the last bin.
        """
        return self._bin_index(death_prob, self.death_prob_bins)

    def continuous_from_bin(self, bin_idx: int, bin_type: str) -> float:
        """
        Convert a bin index back to a continuous value (the bin center).

        Args:
            bin_idx: 0-based bin index.
            bin_type: ``"life_expectancy"`` or ``"death_probability"``.

        Returns:
            The bin center as a Python float, or ``0.0`` if ``bin_type`` is
            unknown or ``bin_idx`` is out of range.
        """
        edges_by_type = {
            "life_expectancy": self.life_exp_bins,
            "death_probability": self.death_prob_bins,
        }
        edges = edges_by_type.get(bin_type)
        if edges is None or not 0 <= bin_idx < len(edges) - 1:
            return 0.0
        return float((edges[bin_idx] + edges[bin_idx + 1]) / 2)