# scripts/run_inference.py
"""
CLI inference with preprocessing parity.
Applies: resample → baseline (deg=2) → smooth (w=11,o=2) → normalize
unless explicitly disabled via flags.
Usage (examples):
python scripts/run_inference.py \
--input datasets/rdwp/sta-1.txt \
--arch figure2 \
--weights outputs/figure2_model.pth \
--target-len 500
# Disable smoothing only:
python scripts/run_inference.py --input ... --arch resnet --weights ... --disable-smooth
"""
import os
import sys
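
# Make the repo root importable so the utils/, models/, and scripts/ packages
# resolve when this file is run directly as a script.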
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import json
import logging
from pathlib import Path
from typing import cast

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from models.registry import build, choices
from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
from scripts.plot_spectrum import load_spectrum
from scripts.discover_raman_files import label_file


def parse_args():
    p = argparse.ArgumentParser(description="Raman spectrum inference (parity with CLI preprocessing).")
    p.add_argument("--input", required=True, help="Path to a single Raman .txt file (2 columns: x, y).")
    p.add_argument("--arch", required=True, choices=choices(), help="Model architecture key.")
    p.add_argument("--weights", required=True, help="Path to model weights (.pth).")
    p.add_argument("--target-len", type=int, default=TARGET_LENGTH, help=f"Resample length (default: {TARGET_LENGTH}).")
    # Preprocessing steps default to ON; use the --disable-* flags to turn them off explicitly.
    p.add_argument("--disable-baseline", action="store_true", help="Disable baseline correction.")
    p.add_argument("--disable-smooth", action="store_true", help="Disable Savitzky–Golay smoothing.")
    p.add_argument("--disable-normalize", action="store_true", help="Disable min-max normalization.")
    p.add_argument("--output", default=None, help="Optional output JSON path (defaults to outputs/inference/<name>.json).")
    p.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
    return p.parse_args()


def _load_state_dict_safe(path: str):
    """Load a state dict safely across torch versions & checkpoint formats."""
    try:
        obj = torch.load(path, map_location="cpu", weights_only=True)  # newer torch
    except TypeError:
        obj = torch.load(path, map_location="cpu")  # fallback for older torch
    # Accept either a plain state_dict or a checkpoint dict that contains one
    if isinstance(obj, dict):
        for k in ("state_dict", "model_state_dict", "model"):
            if k in obj and isinstance(obj[k], dict):
                obj = obj[k]
                break
    if not isinstance(obj, dict):
        raise ValueError(
            "Loaded object is not a state_dict or checkpoint with a state_dict. "
            f"Type={type(obj)} from file={path}"
        )
    # Strip DataParallel 'module.' prefixes if present
    if any(key.startswith("module.") for key in obj.keys()):
        obj = {key.replace("module.", "", 1): val for key, val in obj.items()}
    return obj
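
# For reference, all of the following checkpoint layouts load through
# _load_state_dict_safe (illustrative; not tied to this repo's training code):
#   torch.save(model.state_dict(), path)                      # plain state_dict
#   torch.save({"state_dict": model.state_dict()}, path)      # wrapped checkpoint
#   torch.save(nn.DataParallel(model).state_dict(), path)     # 'module.'-prefixed keys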


def main():
    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
    args = parse_args()

    in_path = Path(args.input)
    if not in_path.exists():
        raise FileNotFoundError(f"Input file not found: {in_path}")

    # --- Load raw spectrum
    x_raw, y_raw = load_spectrum(str(in_path))
    if len(x_raw) < 10:
        raise ValueError("Input spectrum has too few points (<10).")

    # --- Preprocess (single source of truth)
    _, y_proc = preprocess_spectrum(
        np.array(x_raw),
        np.array(y_raw),
        target_len=args.target_len,
        do_baseline=not args.disable_baseline,
        do_smooth=not args.disable_smooth,
        do_normalize=not args.disable_normalize,
        out_dtype="float32",
    )
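    # y_proc is the model input: float32, length == args.target_len, produced by
    # the same resample → baseline → smooth → normalize chain described above.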

    # --- Build model & load weights (tolerant of checkpoint format differences)
    device = torch.device(args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
    model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
    state = _load_state_dict_safe(args.weights)
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        logging.info("Loaded with non-strict keys. missing=%d unexpected=%d", len(missing), len(unexpected))
    model.eval()  # inference mode: disables dropout and freezes batch-norm statistics

    # Shape: (B, C, L) = (1, 1, target_len)
    x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
    with torch.no_grad():
        logits = model(x_tensor).float().cpu()  # shape (1, num_classes)
    probs = F.softmax(logits, dim=1)

    # Plain Python lists so the result dict below is JSON-serializable as-is.
    probs_np = probs.numpy().ravel().tolist()
    logits_np = logits.numpy().ravel().tolist()
    pred_label = int(np.argmax(probs_np))

    # Optional ground-truth label parsed from the filename (if encoded)
    true_label = label_file(str(in_path))

    # --- Prepare output (create the destination directory, including a custom --output parent)
    out_dir = Path("outputs") / "inference"
    out_path = Path(args.output) if args.output else (out_dir / f"{in_path.stem}_{args.arch}.json")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    result = {
        "input_file": str(in_path),
        "arch": args.arch,
        "weights": str(args.weights),
        "target_len": args.target_len,
        "preprocessing": {
            "baseline": not args.disable_baseline,
            "smooth": not args.disable_smooth,
            "normalize": not args.disable_normalize,
        },
        "predicted_label": pred_label,
        "true_label": true_label,
        "probs": probs_np,
        "logits": logits_np,
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    logging.info("Predicted Label: %d | True Label: %s", pred_label, true_label)
    logging.info("Raw Logits: %s", logits_np)
    logging.info("Result saved to %s", out_path)


if __name__ == "__main__":
    main()
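
# For reference, the emitted JSON has this shape (values illustrative):
# {
#   "input_file": "datasets/rdwp/sta-1.txt",
#   "arch": "figure2",
#   "weights": "outputs/figure2_model.pth",
#   "target_len": 500,
#   "preprocessing": {"baseline": true, "smooth": true, "normalize": true},
#   "predicted_label": 0,
#   "true_label": 0,
#   "probs": [0.93, 0.07],
#   "logits": [2.41, -0.22]
# }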