# ecflow/inference.py
# Author: Bing Yan
# Commit ad63797: Switch TPD to Summary-ECFlow model (21-dim hand-crafted features)
"""
ECFlow inference engine.
Loads trained EC and TPD models and runs end-to-end inference
from preprocessed arrays (dimensionless for CV, physical for TPD).
"""
import json
import sys
import os
from pathlib import Path
import numpy as np
import torch
from multi_mechanism_model import MultiMechanismFlow
from tpd_model import MultiMechanismFlowTPD
from flow_model import MECHANISM_LIST, MECHANISM_PARAMS, ActNorm
from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
def _fix_actnorm_initialized(model):
    """Flag every ActNorm layer in *model* as already initialized.

    Checkpoints saved before the ``_initialized`` buffer existed load with
    the flag still False; without this fix the first forward pass would
    re-derive data-dependent statistics and clobber the trained
    ``log_scale``/``bias`` parameters.
    """
    pending = (
        layer for layer in model.modules()
        if isinstance(layer, ActNorm) and not layer.initialized
    )
    for layer in pending:
        layer.initialized = True
class ECFlowPredictor:
    """Unified predictor for both EC (cyclic voltammetry) and TPD domains."""

    def __init__(self, ec_checkpoint=None, tpd_checkpoint=None, device=None):
        """Create a predictor, optionally loading one or both domain models.

        Args:
            ec_checkpoint: path to an EC flow-model checkpoint, or None.
            tpd_checkpoint: path to a TPD flow-model checkpoint, or None.
            device: torch device string; auto-selects CUDA when available.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.ec_model = None
        self.ec_norm_stats = None
        # Fix: theta-stats attributes were previously created only inside
        # _load_ec/_load_tpd, so reading them on a predictor without a loaded
        # checkpoint raised AttributeError. Always initialize them here.
        self.ec_theta_stats = None
        self.tpd_model = None
        self.tpd_norm_stats = None
        self.tpd_theta_stats = None
        if ec_checkpoint is not None:
            self._load_ec(ec_checkpoint)
        if tpd_checkpoint is not None:
            self._load_tpd(tpd_checkpoint)
def _load_ec(self, ckpt_path):
ckpt_path = Path(ckpt_path)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
args = checkpoint["args"]
self.ec_model = MultiMechanismFlow(
d_context=args.get("d_context", 128),
d_model=args.get("d_model", 128),
n_coupling_layers=args.get("n_coupling_layers", 6),
hidden_dim=args.get("hidden_dim", 96),
coupling_type=args.get("coupling_type", "spline"),
n_bins=args.get("n_bins", 8),
tail_bound=args.get("tail_bound", 5.0),
)
self.ec_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
_fix_actnorm_initialized(self.ec_model)
self.ec_model.to(self.device).eval()
# Search for norm_stats in multiple locations
ckpt_dir = ckpt_path.parent
stem = ckpt_path.stem.replace("best", "").rstrip("_")
prefix = stem + "_" if stem else ""
for search_dir in [ckpt_dir, ckpt_dir.parent]:
for name_pattern in [f"{prefix}norm_stats.json", "ec_norm_stats.json", "norm_stats.json"]:
p = search_dir / name_pattern
if p.exists():
with open(p) as f:
self.ec_norm_stats = json.load(f)
break
if self.ec_norm_stats is not None:
break
for search_dir in [ckpt_dir, ckpt_dir.parent]:
for name_pattern in [f"{prefix}theta_stats.json", "ec_theta_stats.json", "theta_stats.json"]:
p = search_dir / name_pattern
if p.exists():
with open(p) as f:
self.ec_theta_stats = json.load(f)
break
if hasattr(self, "ec_theta_stats") and self.ec_theta_stats is not None:
break
def _load_tpd(self, ckpt_path):
ckpt_path = Path(ckpt_path)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
args = checkpoint["args"]
self.tpd_use_summary = args.get("use_summary_features", False)
self.tpd_model = MultiMechanismFlowTPD(
d_context=args.get("d_context", 128),
d_model=args.get("d_model", 128),
n_coupling_layers=args.get("n_coupling_layers", 6),
hidden_dim=args.get("hidden_dim", 96),
coupling_type=args.get("coupling_type", "spline"),
n_bins=args.get("n_bins", 8),
tail_bound=args.get("tail_bound", 5.0),
use_summary_features=self.tpd_use_summary,
)
self.tpd_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
_fix_actnorm_initialized(self.tpd_model)
self.tpd_model.to(self.device).eval()
# Search for norm_stats in multiple locations
ckpt_dir = ckpt_path.parent
stem = ckpt_path.stem.replace("best", "").rstrip("_")
prefix = stem + "_" if stem else ""
for search_dir in [ckpt_dir, ckpt_dir.parent]:
for name_pattern in [f"{prefix}norm_stats.json", "tpd_norm_stats.json", "norm_stats.json"]:
p = search_dir / name_pattern
if p.exists():
with open(p) as f:
self.tpd_norm_stats = json.load(f)
break
if self.tpd_norm_stats is not None:
break
for search_dir in [ckpt_dir, ckpt_dir.parent]:
for name_pattern in [f"{prefix}theta_stats.json", "tpd_theta_stats.json", "theta_stats.json"]:
p = search_dir / name_pattern
if p.exists():
with open(p) as f:
self.tpd_theta_stats = json.load(f)
break
if hasattr(self, "tpd_theta_stats") and self.tpd_theta_stats is not None:
break
def _prepare_ec_tensor(self, potentials, fluxes, times, sigmas):
"""
Build model input tensor from preprocessed dimensionless CV data.
Args:
potentials: list of 1-D arrays (dimensionless theta)
fluxes: list of 1-D arrays (dimensionless flux)
times: list of 1-D arrays (dimensionless time) or None
sigmas: 1-D array of dimensionless scan rates
Returns:
dict of tensors ready for model.predict()
"""
from scipy.interpolate import interp1d
n_scans = len(potentials)
T_target = 672
pot_resampled = []
flux_resampled = []
time_resampled = []
flux_scales = []
for i in range(n_scans):
pot = np.asarray(potentials[i], dtype=np.float32)
flx = np.asarray(fluxes[i], dtype=np.float32)
if times is not None and times[i] is not None:
tim = np.asarray(times[i], dtype=np.float32)
else:
theta_range = pot.max() - pot.min()
sigma = sigmas[i]
total_time = 2.0 * theta_range / sigma
tim = np.linspace(0, total_time, len(pot), dtype=np.float32)
peak = np.max(np.abs(flx)) + 1e-30
flux_scales.append(np.log10(peak))
flx = flx / peak
t_uniform = np.linspace(tim[0], tim[-1], T_target)
pot_resampled.append(
interp1d(tim, pot, kind="linear", fill_value="extrapolate")(t_uniform)
)
flux_resampled.append(
interp1d(tim, flx, kind="linear", fill_value="extrapolate")(t_uniform)
)
time_resampled.append(t_uniform)
pot_arr = np.stack(pot_resampled).astype(np.float32)
flx_arr = np.stack(flux_resampled).astype(np.float32)
tim_arr = np.stack(time_resampled).astype(np.float32)
ns = self.ec_norm_stats
if ns:
pot_arr = (pot_arr - ns["potential"][0]) / ns["potential"][1]
flx_arr = (flx_arr - ns["flux"][0]) / ns["flux"][1]
tim_arr = (tim_arr - ns["time"][0]) / ns["time"][1]
# [1, N, 3, T]
waveforms = np.stack([pot_arr, flx_arr, tim_arr], axis=1)
x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device)
scan_mask = torch.ones(1, n_scans, T_target, dtype=torch.bool, device=self.device)
sigmas_t = torch.from_numpy(
np.log10(np.asarray(sigmas, dtype=np.float32))
).unsqueeze(0).to(self.device)
flux_scales_t = torch.from_numpy(
np.asarray(flux_scales, dtype=np.float32)
).unsqueeze(0).to(self.device)
return {
"input": x,
"scan_mask": scan_mask,
"sigmas": sigmas_t,
"flux_scales": flux_scales_t,
}
def _prepare_tpd_tensor(self, temperatures, rates, betas):
"""
Build model input tensor from TPD data.
Args:
temperatures: list of 1-D arrays (K)
rates: list of 1-D arrays (arb. units)
betas: 1-D array of heating rates (K/s)
Returns:
dict of tensors ready for model.predict()
"""
from scipy.interpolate import interp1d
n_rates = len(temperatures)
T_target = 500
temp_resampled = []
rate_resampled = []
for i in range(n_rates):
temp = np.asarray(temperatures[i], dtype=np.float32)
rate = np.asarray(rates[i], dtype=np.float32)
t_uniform = np.linspace(temp[0], temp[-1], T_target)
temp_resampled.append(t_uniform)
rate_resampled.append(
interp1d(temp, rate, kind="linear", fill_value="extrapolate")(t_uniform)
)
temp_arr = np.stack(temp_resampled).astype(np.float32)
rate_arr = np.stack(rate_resampled).astype(np.float32)
summary_t = None
if getattr(self, 'tpd_use_summary', False):
from preprocessing import extract_tpd_summary_stats
hr_arr = np.asarray(betas, dtype=np.float32)
lengths = np.full(n_rates, T_target, dtype=np.int32)
summary = extract_tpd_summary_stats(
temp_arr, rate_arr, lengths, hr_arr, n_rates)
summary_t = torch.from_numpy(summary).unsqueeze(0).to(self.device)
rate_scales = []
for i in range(n_rates):
peak = np.max(np.abs(rate_arr[i])) + 1e-30
rate_scales.append(np.log10(peak))
rate_arr[i] /= peak
ns = self.tpd_norm_stats
if ns:
temp_arr = (temp_arr - ns["temperature"][0]) / ns["temperature"][1]
rate_arr = (rate_arr - ns["rate"][0]) / ns["rate"][1]
# [1, N, 2, T]
waveforms = np.stack([temp_arr, rate_arr], axis=1)
x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device)
scan_mask = torch.ones(1, n_rates, T_target, dtype=torch.bool, device=self.device)
sigmas_t = torch.from_numpy(
np.log10(np.asarray(betas, dtype=np.float32))
).unsqueeze(0).to(self.device)
rate_scales_t = torch.from_numpy(
np.asarray(rate_scales, dtype=np.float32)
).unsqueeze(0).to(self.device)
result = {
"input": x,
"scan_mask": scan_mask,
"sigmas": sigmas_t,
"flux_scales": rate_scales_t,
}
if summary_t is not None:
result["summary"] = summary_t
return result
@torch.no_grad()
def predict_ec(self, potentials, fluxes, sigmas, times=None, n_samples=500, temperature=1.0):
"""
Run EC inference on dimensionless CV data.
Args:
potentials: list of 1-D arrays (dimensionless theta per scan rate)
fluxes: list of 1-D arrays (dimensionless flux per scan rate)
sigmas: list/array of dimensionless scan rates
times: optional list of 1-D time arrays
n_samples: posterior samples to draw
temperature: sampling temperature (>1 broadens posteriors)
Returns:
dict with mechanism_probs, mechanism_names, predicted_mechanism,
parameter_stats (per mechanism), posterior_samples (per mechanism)
"""
if self.ec_model is None:
raise RuntimeError("EC model not loaded")
tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas)
pred = self.ec_model.predict(
tensors["input"],
scan_mask=tensors["scan_mask"],
sigmas=tensors["sigmas"],
flux_scales=tensors["flux_scales"],
n_samples=n_samples,
temperature=temperature,
)
probs = pred["mechanism_probs"][0].cpu().numpy()
pred_idx = int(pred["mechanism_pred"][0].cpu().item())
pred_mech = MECHANISM_LIST[pred_idx]
param_stats = {}
samples_dict = {}
for mech in MECHANISM_LIST:
if pred["samples"][mech] is not None:
s = pred["samples"][mech][0].cpu().numpy() # [n_samples, D]
samples_dict[mech] = s
param_stats[mech] = {
"names": MECHANISM_PARAMS[mech]["names"],
"mean": s.mean(axis=0).tolist(),
"std": s.std(axis=0).tolist(),
"median": np.median(s, axis=0).tolist(),
"q05": np.quantile(s, 0.05, axis=0).tolist(),
"q95": np.quantile(s, 0.95, axis=0).tolist(),
}
return {
"domain": "ec",
"mechanism_probs": {m: float(probs[i]) for i, m in enumerate(MECHANISM_LIST)},
"mechanism_names": MECHANISM_LIST,
"predicted_mechanism": pred_mech,
"predicted_mechanism_idx": pred_idx,
"parameter_stats": param_stats,
"posterior_samples": samples_dict,
}
@torch.no_grad()
def predict_tpd(self, temperatures, rates, betas, n_samples=500, temperature=1.0):
"""
Run TPD inference.
Args:
temperatures: list of 1-D arrays (K per heating rate)
rates: list of 1-D arrays (signal per heating rate)
betas: list/array of heating rates (K/s)
n_samples: posterior samples to draw
temperature: sampling temperature
Returns:
dict with mechanism_probs, parameter_stats, posterior_samples
"""
if self.tpd_model is None:
raise RuntimeError("TPD model not loaded")
tensors = self._prepare_tpd_tensor(temperatures, rates, betas)
pred = self.tpd_model.predict(
tensors["input"],
scan_mask=tensors["scan_mask"],
sigmas=tensors["sigmas"],
flux_scales=tensors["flux_scales"],
n_samples=n_samples,
temperature=temperature,
summary=tensors.get("summary"),
)
probs = pred["mechanism_probs"][0].cpu().numpy()
pred_idx = int(pred["mechanism_pred"][0].cpu().item())
pred_mech = TPD_MECHANISM_LIST[pred_idx]
param_stats = {}
samples_dict = {}
for mech in TPD_MECHANISM_LIST:
if pred["samples"][mech] is not None:
s = pred["samples"][mech][0].cpu().numpy()
samples_dict[mech] = s
param_stats[mech] = {
"names": TPD_MECHANISM_PARAMS[mech]["names"],
"mean": s.mean(axis=0).tolist(),
"std": s.std(axis=0).tolist(),
"median": np.median(s, axis=0).tolist(),
"q05": np.quantile(s, 0.05, axis=0).tolist(),
"q95": np.quantile(s, 0.95, axis=0).tolist(),
}
return {
"domain": "tpd",
"mechanism_probs": {m: float(probs[i]) for i, m in enumerate(TPD_MECHANISM_LIST)},
"mechanism_names": TPD_MECHANISM_LIST,
"predicted_mechanism": pred_mech,
"predicted_mechanism_idx": pred_idx,
"parameter_stats": param_stats,
"posterior_samples": samples_dict,
}
# =====================================================================
# Signal Reconstruction
# =====================================================================
def reconstruct_ec(self, result, potentials, fluxes, sigmas,
base_params=None, mechanism=None):
"""
Reconstruct CV signals from inferred posterior median and compute metrics.
Args:
result: output dict from predict_ec()
potentials: list of 1-D arrays (original dimensionless theta)
fluxes: list of 1-D arrays (original dimensionless flux)
sigmas: list of dimensionless scan rates
base_params: dict of fixed simulation params; defaults used if None
mechanism: which mechanism to reconstruct (default: predicted)
Returns:
dict with 'observed', 'reconstructed' curve lists,
'nrmse', 'r2' per scan rate, and 'mean_nrmse', 'mean_r2'
"""
from evaluate_reconstruction import (
reconstruct_ec_signal, signal_nrmse, signal_r2,
)
mech = mechanism or result["predicted_mechanism"]
stats = result["parameter_stats"].get(mech)
if stats is None:
return None
theta_point = np.array(stats["median"])
if base_params is None:
pot0 = np.asarray(potentials[0])
base_params = {
"theta_i": float(pot0.max()),
"theta_v": float(pot0.min()),
"dA": 1.0,
"C_A_bulk": 1.0,
"C_B_bulk": 0.0,
"kinetics": mech,
}
try:
recon_results = reconstruct_ec_signal(
theta_point, mech, base_params, sigmas, n_spatial=64
)
except Exception:
return None
observed_curves = []
recon_curves = []
conc_curves = []
nrmses = []
r2s = []
for i, (pot, flx, sigma) in enumerate(zip(potentials, fluxes, sigmas)):
pot = np.asarray(pot)
flx = np.asarray(flx)
observed_curves.append({"x": pot, "y": flx})
if i < len(recon_results) and recon_results[i].get("success", False):
rec = recon_results[i]
rec_pot = np.asarray(rec["potential"])
rec_flx = np.asarray(rec["flux"])
n_obs = len(pot)
n_rec = len(rec_pot)
t_obs = np.linspace(0, 1, n_obs)
t_rec = np.linspace(0, 1, n_rec)
rec_flx_interp = np.interp(t_obs, t_rec, rec_flx)
recon_curves.append({"x": pot, "y": rec_flx_interp})
nrmse_val = signal_nrmse(flx, rec_flx_interp)
r2_val = signal_r2(flx, rec_flx_interp)
nrmses.append(nrmse_val)
r2s.append(r2_val)
if "c_ox_surface" in rec and "c_red_surface" in rec:
c_ox_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_ox_surface"]))
c_red_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_red_surface"]))
conc_curves.append({
"x": pot,
"c_ox": c_ox_interp,
"c_red": c_red_interp,
})
else:
conc_curves.append(None)
else:
recon_curves.append({"x": pot, "y": np.zeros_like(flx)})
nrmses.append(float("nan"))
r2s.append(float("nan"))
conc_curves.append(None)
valid_nrmse = [v for v in nrmses if np.isfinite(v)]
valid_r2 = [v for v in r2s if np.isfinite(v)]
return {
"observed": observed_curves,
"reconstructed": recon_curves,
"concentrations": conc_curves,
"nrmse": nrmses,
"r2": r2s,
"mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"),
"mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"),
}
def reconstruct_tpd(self, result, temperatures, rates, betas,
base_params=None, mechanism=None):
"""
Reconstruct TPD signals from inferred posterior median and compute metrics.
Args:
result: output dict from predict_tpd()
temperatures: list of 1-D arrays (K)
rates: list of 1-D arrays (signal)
betas: list of heating rates (K/s)
base_params: dict of fixed simulation params; defaults used if None
mechanism: which mechanism to reconstruct (default: predicted)
Returns:
dict with 'observed', 'reconstructed' curve lists,
'nrmse', 'r2' per heating rate, and 'mean_nrmse', 'mean_r2'
"""
from evaluate_reconstruction import (
reconstruct_tpd_signal, signal_nrmse, signal_r2,
)
mech = mechanism or result["predicted_mechanism"]
stats = result["parameter_stats"].get(mech)
if stats is None:
return None
theta_point = np.array(stats["median"])
if base_params is None:
temp0 = np.asarray(temperatures[0])
base_params = {
"mechanism": mech,
"T_start": float(temp0.min()),
"T_end": float(temp0.max()),
"n_points": 500,
}
try:
recon_results = reconstruct_tpd_signal(
theta_point, mech, base_params, betas
)
except Exception:
return None
observed_curves = []
recon_curves = []
nrmses = []
r2s = []
for i, (temp, rate, beta) in enumerate(zip(temperatures, rates, betas)):
temp = np.asarray(temp)
rate = np.asarray(rate)
observed_curves.append({"x": temp, "y": rate})
if i < len(recon_results) and recon_results[i].get("success", False):
rec = recon_results[i]
rec_temp = np.asarray(rec["temperature"])
rec_rate = np.asarray(rec["rate"])
rec_rate_interp = np.interp(temp, rec_temp, rec_rate)
recon_curves.append({"x": temp, "y": rec_rate_interp})
nrmse_val = signal_nrmse(rate, rec_rate_interp)
r2_val = signal_r2(rate, rec_rate_interp)
nrmses.append(nrmse_val)
r2s.append(r2_val)
else:
recon_curves.append({"x": temp, "y": np.zeros_like(rate)})
nrmses.append(float("nan"))
r2s.append(float("nan"))
valid_nrmse = [v for v in nrmses if np.isfinite(v)]
valid_r2 = [v for v in r2s if np.isfinite(v)]
return {
"observed": observed_curves,
"reconstructed": recon_curves,
"nrmse": nrmses,
"r2": r2s,
"mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"),
"mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"),
}