# (SYNC): bring parity backend (utils/ scripts/ models/ tests/) from feat/ui-parity-rebuild; no UI changes
# commit 6373c5a (devjas1)
# scripts/run_inference.py
"""
CLI inference with preprocessing parity.

Applies: resample → baseline (deg=2) → smooth (w=11, o=2) → normalize,
unless explicitly disabled via flags.

Usage (examples):
    python scripts/run_inference.py \
        --input datasets/rdwp/sta-1.txt \
        --arch figure2 \
        --weights outputs/figure2_model.pth \
        --target-len 500

    # Disable smoothing only:
    python scripts/run_inference.py --input ... --arch resnet --weights ... --disable-smooth
"""
import os | |
import sys | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
import argparse | |
import json | |
import logging | |
from pathlib import Path | |
from typing import cast | |
from torch import nn | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from models.registry import build, choices | |
from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH | |
from scripts.plot_spectrum import load_spectrum | |
from scripts.discover_raman_files import label_file | |
def parse_args():
    """Parse CLI arguments for single-spectrum inference.

    Preprocessing steps default to ON; each ``--disable-*`` flag turns one
    step off explicitly (parity with the CLI preprocessing pipeline).

    Returns:
        argparse.Namespace: parsed arguments.
    """
    p = argparse.ArgumentParser(description="Raman spectrum inference (parity with CLI preprocessing).")
    p.add_argument("--input", required=True, help="Path to a single Raman .txt file (2 columns: x, y).")
    p.add_argument("--arch", required=True, choices=choices(), help="Model architecture key.")
    p.add_argument("--weights", required=True, help="Path to model weights (.pth).")
    # Keep the help text in sync with the real default instead of hard-coding "500",
    # which would silently drift if TARGET_LENGTH ever changes.
    p.add_argument("--target-len", type=int, default=TARGET_LENGTH,
                   help=f"Resample length (default: {TARGET_LENGTH}).")
    # Default = ON; use disable- flags to turn steps off explicitly.
    p.add_argument("--disable-baseline", action="store_true", help="Disable baseline correction.")
    p.add_argument("--disable-smooth", action="store_true", help="Disable Savitzky–Golay smoothing.")
    p.add_argument("--disable-normalize", action="store_true", help="Disable min-max normalization.")
    p.add_argument("--output", default=None, help="Optional output JSON path (defaults to outputs/inference/<name>.json).")
    p.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
    return p.parse_args()
def _load_state_dict_safe(path: str): | |
"""Load a state dict safely across torch versions & checkpoint formats.""" | |
try: | |
obj = torch.load(path, map_location="cpu", weights_only=True) # newer torch | |
except TypeError: | |
obj = torch.load(path, map_location="cpu") # fallback for older torch | |
# Accept either a plain state_dict or a checkpoint dict that contains one | |
if isinstance(obj, dict): | |
for k in ("state_dict", "model_state_dict", "model"): | |
if k in obj and isinstance(obj[k], dict): | |
obj = obj[k] | |
break | |
if not isinstance(obj, dict): | |
raise ValueError( | |
"Loaded object is not a state_dict or checkpoint with a state_dict. " | |
f"Type={type(obj)} from file={path}" | |
) | |
# Strip DataParallel 'module.' prefixes if present | |
if any(key.startswith("module.") for key in obj.keys()): | |
obj = {key.replace("module.", "", 1): val for key, val in obj.items()} | |
return obj | |
def main():
    """Run preprocessing + inference on one spectrum file and save a JSON report.

    Pipeline: load raw (x, y) → preprocess (resample / baseline / smooth /
    normalize, each toggleable) → build model & load weights → softmax over
    logits → write a JSON result file and log a summary.

    Raises:
        FileNotFoundError: if the input spectrum file does not exist.
        ValueError: if the spectrum has fewer than 10 points.
    """
    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
    args = parse_args()
    in_path = Path(args.input)
    if not in_path.exists():
        raise FileNotFoundError(f"Input file not found: {in_path}")
    # --- Load raw spectrum
    x_raw, y_raw = load_spectrum(str(in_path))
    if len(x_raw) < 10:
        raise ValueError("Input spectrum has too few points (<10).")
    # --- Preprocess (single source of truth)
    _, y_proc = preprocess_spectrum(
        np.array(x_raw),
        np.array(y_raw),
        target_len=args.target_len,
        do_baseline=not args.disable_baseline,
        do_smooth=not args.disable_smooth,
        do_normalize=not args.disable_normalize,
        out_dtype="float32",
    )
    # --- Build model & load weights (safe)
    # Make the CPU fallback visible instead of silently ignoring --device cuda.
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.info("CUDA requested but not available; falling back to CPU.")
    device = torch.device(args.device if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
    model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
    state = _load_state_dict_safe(args.weights)
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        logging.info("Loaded with non-strict keys. missing=%d unexpected=%d", len(missing), len(unexpected))
    model.eval()
    # Shape: (B, C, L) = (1, 1, target_len)
    x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)
    with torch.no_grad():
        logits = model(x_tensor).float().cpu()  # shape (1, num_classes)
    probs = F.softmax(logits, dim=1)
    probs_np = probs.numpy().ravel().tolist()
    logits_np = logits.numpy().ravel().tolist()
    pred_label = int(np.argmax(probs_np))
    # Optional ground-truth from filename (if encoded)
    true_label = label_file(str(in_path))
    # --- Prepare output
    out_dir = Path("outputs") / "inference"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = Path(args.output) if args.output else (out_dir / f"{in_path.stem}_{args.arch}.json")
    # Fix: a user-supplied --output may point into a directory that does not
    # exist yet; create it so open() below does not raise FileNotFoundError.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    result = {
        "input_file": str(in_path),
        "arch": args.arch,
        "weights": str(args.weights),
        "target_len": args.target_len,
        "preprocessing": {
            "baseline": not args.disable_baseline,
            "smooth": not args.disable_smooth,
            "normalize": not args.disable_normalize,
        },
        "predicted_label": pred_label,
        "true_label": true_label,
        "probs": probs_np,
        "logits": logits_np,
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    logging.info("Predicted Label: %d True Label: %s", pred_label, true_label)
    logging.info("Raw Logits: %s", logits_np)
    logging.info("Result saved to %s", out_path)
if __name__ == "__main__":
    main()