petter2025 committed on
Commit
332a79f
·
verified ·
1 Parent(s): 7e061fe

Create advanced_inference.py

Browse files
Files changed (1) hide show
  1. advanced_inference.py +73 -0
advanced_inference.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hamiltonian Monte Carlo (NUTS) for complex pattern discovery.
3
+ Uses Pyro's NUTS sampler; posterior samples are returned as NumPy arrays for plotting (e.g. with ArviZ).
4
+ """
5
+ import logging
6
+ import pyro
7
+ import pyro.distributions as dist
8
+ from pyro.infer import MCMC, NUTS
9
+ import torch
10
+ import numpy as np
11
+ import pandas as pd
12
+ from typing import Dict, Any, Optional
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
class HMCAnalyzer:
    """Fit a simple Bayesian linear regression with Pyro's NUTS sampler.

    The model is ``y ~ Normal(alpha + beta * x, sigma)`` with priors
    ``alpha ~ Normal(0, 10)``, ``beta ~ Normal(0, 1)`` and
    ``sigma ~ HalfNormal(1)``.
    """

    # Latent sample sites of ``_model`` that ``_summary`` reports on.
    _PARAM_NAMES = ("alpha", "beta", "sigma")

    def __init__(self):
        # MCMC runner created by the most recent ``run_inference`` call.
        self.mcmc = None
        # Posterior samples (site name -> tensor) from the most recent run.
        self.trace = None

    def _model(self, x, y=None):
        """Pyro model: linear regression with unknown observation noise.

        Args:
            x: Predictor values; ``len(x)`` sets the size of the data plate.
            y: Observed responses, or ``None`` to leave "obs" unobserved.
        """
        alpha = pyro.sample("alpha", dist.Normal(0, 10))
        beta = pyro.sample("beta", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.HalfNormal(1))
        mu = alpha + beta * x
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    def run_inference(
        self,
        data: Optional[pd.DataFrame] = None,
        num_samples: int = 500,
        warmup: int = 200,
    ) -> Dict[str, Any]:
        """Run NUTS on the provided data, or on synthetic data if none is given.

        Args:
            data: Frame with numeric columns ``'x'`` and ``'y'``. When
                ``None``, a noisy linear trend of 50 points on [0, 10] is
                generated (alpha=2.0, beta=-0.3, sigma=0.5).
            num_samples: Number of posterior samples to draw.
            warmup: Number of warmup steps passed to ``MCMC``.

        Returns:
            The per-parameter summary produced by ``_summary``.

        Raises:
            KeyError: If ``data`` lacks an ``'x'`` or ``'y'`` column.
        """
        if data is None:
            # Synthetic data: a linear trend with Gaussian noise.
            n_points = 50
            x = torch.linspace(0, 10, n_points)
            true_alpha, true_beta, true_sigma = 2.0, -0.3, 0.5
            y = true_alpha + true_beta * x + torch.randn(n_points) * true_sigma
        else:
            # Fail early with a message that names the missing columns
            # instead of surfacing a bare pandas KeyError.
            missing = [col for col in ("x", "y") if col not in data.columns]
            if missing:
                raise KeyError(
                    f"data is missing required column(s): {missing}"
                )
            x = torch.tensor(data['x'].values, dtype=torch.float32)
            y = torch.tensor(data['y'].values, dtype=torch.float32)

        nuts_kernel = NUTS(self._model)
        self.mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup)
        self.mcmc.run(x, y)

        self.trace = self.mcmc.get_samples()
        return self._summary()

    @staticmethod
    def _to_numpy(samples) -> np.ndarray:
        """Convert posterior samples to a NumPy array.

        ``Tensor.numpy()`` raises for tensors that live off the CPU or that
        require grad, so tensors are detached and moved to the CPU first.
        """
        if isinstance(samples, torch.Tensor):
            return samples.detach().cpu().numpy()
        return np.asarray(samples)

    def _summary(self) -> Dict[str, Any]:
        """Return summary statistics of the posterior samples.

        Returns:
            ``{}`` if no inference has been run; otherwise a dict keyed by
            ``'alpha'``, ``'beta'`` and ``'sigma'``, each holding ``'mean'``,
            ``'std'``, ``'hpd_5'`` and ``'hpd_95'``.

            NOTE: despite their names, ``'hpd_5'`` and ``'hpd_95'`` are the
            5th and 95th percentiles of the samples (an equal-tailed
            interval), not a highest-posterior-density interval. The key
            names are kept for backward compatibility.
        """
        if self.trace is None:
            return {}
        summary = {}
        for key in self._PARAM_NAMES:
            samples = self._to_numpy(self.trace[key])
            lower, upper = np.percentile(samples, [5, 95])
            summary[key] = {
                'mean': float(samples.mean()),
                'std': float(samples.std()),
                'hpd_5': float(lower),
                'hpd_95': float(upper),
            }
        return summary

    def get_trace_data(self) -> Dict[str, np.ndarray]:
        """Return all posterior samples as NumPy arrays for plotting.

        Returns:
            ``{}`` if no inference has been run; otherwise a mapping from
            each sample-site name in the trace to its samples.
        """
        if self.trace is None:
            return {}
        return {k: self._to_numpy(v) for k, v in self.trace.items()}