File size: 5,875 Bytes
6373c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e484a46
6373c5a
e484a46
 
 
6373c5a
e484a46
6373c5a
 
 
e484a46
 
 
6373c5a
e484a46
6373c5a
 
 
 
ba24c6a
e484a46
6373c5a
 
 
 
 
 
e484a46
6373c5a
 
 
 
e484a46
6373c5a
 
 
e484a46
 
6373c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e484a46
 
6373c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e484a46
 
6373c5a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# scripts/run_inference.py
"""
CLI inference with preprocessing parity.
Applies: resample → baseline (deg=2) → smooth (w=11,o=2) → normalize
unless explicitly disabled via flags.

Usage (examples):
python scripts/run_inference.py \
    --input datasets/rdwp/sta-1.txt \
    --arch figure2 \
    --weights outputs/figure2_model.pth \
    --target-len 500

# Disable smoothing only:
python scripts/run_inference.py --input ... --arch resnet --weights ... --disable-smooth
"""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import json
import logging
from pathlib import Path
from typing import cast
from torch import nn

import numpy as np
import torch
import torch.nn.functional as F

from models.registry import build, choices
from utils.preprocessing import preprocess_spectrum, TARGET_LENGTH
from scripts.plot_spectrum import load_spectrum
from scripts.discover_raman_files import label_file


def parse_args(argv=None):
    """Parse the command-line options for single-spectrum inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with the attributes ``input``, ``arch``,
        ``weights``, ``target_len``, ``disable_baseline``, ``disable_smooth``,
        ``disable_normalize``, ``output`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description="Raman spectrum inference (parity with CLI preprocessing)."
    )
    parser.add_argument("--input", required=True, help="Path to a single Raman .txt file (2 columns: x, y).")
    parser.add_argument("--arch", required=True, choices=choices(), help="Model architecture key.")
    parser.add_argument("--weights", required=True, help="Path to model weights (.pth).")
    parser.add_argument(
        "--target-len",
        type=int,
        default=TARGET_LENGTH,
        # %(default)s is expanded by argparse, so the help text always shows
        # the real TARGET_LENGTH instead of a hard-coded number.
        help="Resample length (default: %(default)s).",
    )

    # Default = ON; use disable- flags to turn steps off explicitly.
    parser.add_argument("--disable-baseline", action="store_true", help="Disable baseline correction.")
    parser.add_argument("--disable-smooth", action="store_true", help="Disable Savitzky–Golay smoothing.")
    parser.add_argument("--disable-normalize", action="store_true", help="Disable min-max normalization.")

    parser.add_argument(
        "--output",
        default=None,
        help="Optional output JSON path (defaults to outputs/inference/<input-stem>_<arch>.json).",
    )
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"], help="Compute device (default: cpu).")
    return parser.parse_args(argv)


def _load_state_dict_safe(path: str) -> dict:
    """Load a model state dict from ``path``, tolerating common checkpoint layouts.

    The file is loaded onto the CPU. If the loaded object is a dict that wraps
    another dict under ``"state_dict"``, ``"model_state_dict"`` or ``"model"``
    (checked in that order), the wrapped dict is used. A leading ``"module."``
    prefix is removed from keys when at least one key carries it.

    Args:
        path: Filesystem path of the weights/checkpoint file.

    Returns:
        The state dict, with any leading ``"module."`` key prefix removed.

    Raises:
        ValueError: If the loaded object is not a dict.
    """
    try:
        loaded = torch.load(path, map_location="cpu", weights_only=True)  # newer torch
    except TypeError:
        # torch.load() rejected the `weights_only` keyword: fallback for older torch
        loaded = torch.load(path, map_location="cpu")

    # Accept either a plain state_dict or a checkpoint dict that contains one
    if isinstance(loaded, dict):
        for wrapper_key in ("state_dict", "model_state_dict", "model"):
            if wrapper_key in loaded and isinstance(loaded[wrapper_key], dict):
                loaded = loaded[wrapper_key]
                break

    if not isinstance(loaded, dict):
        raise ValueError(
            "Loaded object is not a state_dict or checkpoint with a state_dict. "
            f"Type={type(loaded)} from file={path}"
        )

    # Strip DataParallel 'module.' prefixes if present. Only a *leading*
    # prefix is removed: a blanket str.replace would also rewrite keys such as
    # "backbone.module.bias" that merely contain "module." further in.
    prefix = "module."
    if any(key.startswith(prefix) for key in loaded):
        loaded = {
            (key[len(prefix):] if key.startswith(prefix) else key): val
            for key, val in loaded.items()
        }

    return loaded


def main():
    """Run inference on one spectrum file and write the result as JSON.

    Steps: parse CLI arguments, load the spectrum, preprocess it with
    ``preprocess_spectrum`` (steps toggled by the ``--disable-*`` flags),
    build the requested architecture, load its weights non-strictly, run a
    forward pass, and dump the prediction, probabilities and logits to a JSON
    file. The output path is ``--output`` if given, otherwise
    ``outputs/inference/<input stem>_<arch>.json``.

    Raises:
        FileNotFoundError: If the input path does not exist.
        ValueError: If the loaded spectrum has fewer than 10 points.
    """
    logging.basicConfig(level=logging.INFO, format="INFO: %(message)s")
    args = parse_args()

    in_path = Path(args.input)
    if not in_path.exists():
        raise FileNotFoundError(f"Input file not found: {in_path}")

    # --- Load raw spectrum
    x_raw, y_raw = load_spectrum(str(in_path))
    if len(x_raw) < 10:
        raise ValueError("Input spectrum has too few points (<10).")

    # --- Preprocess (single source of truth)
    _, y_proc = preprocess_spectrum(
        np.array(x_raw),
        np.array(y_raw),
        target_len=args.target_len,
        do_baseline=not args.disable_baseline,
        do_smooth=not args.disable_smooth,
        do_normalize=not args.disable_normalize,
        out_dtype="float32",
    )

    # --- Build model & load weights (safe)
    use_cuda = args.device == "cuda" and torch.cuda.is_available()
    if args.device == "cuda" and not use_cuda:
        # Make the CPU fallback visible instead of silently ignoring --device.
        logging.warning("CUDA requested but not available; falling back to CPU.")
    device = torch.device("cuda" if use_cuda else "cpu")

    model = cast(nn.Module, build(args.arch, args.target_len)).to(device)
    state = _load_state_dict_safe(args.weights)
    # strict=False tolerates key mismatches; the counts are logged below.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        logging.info("Loaded with non-strict keys. missing=%d unexpected=%d", len(missing), len(unexpected))

    model.eval()

    # Two leading axes are added: (B, C, L) = (1, 1, target_len),
    # assuming preprocess_spectrum returns a 1-D array of that length.
    x_tensor = torch.from_numpy(y_proc[None, None, :]).to(device)

    with torch.no_grad():
        logits = model(x_tensor).float().cpu()  # expected shape (1, num_classes)
        probs = F.softmax(logits, dim=1)

    probs_np = probs.numpy().ravel().tolist()
    logits_np = logits.numpy().ravel().tolist()
    pred_label = int(np.argmax(probs_np))

    # Optional ground-truth from filename (if encoded)
    true_label = label_file(str(in_path))

    # --- Prepare output
    if args.output:
        out_path = Path(args.output)
    else:
        out_path = Path("outputs") / "inference" / f"{in_path.stem}_{args.arch}.json"
    # Create the directory that will actually receive the file: a custom
    # --output may live in a directory that does not exist yet, and the
    # default directory is not needed when --output points elsewhere.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    result = {
        "input_file": str(in_path),
        "arch": args.arch,
        "weights": str(args.weights),
        "target_len": args.target_len,
        "preprocessing": {
            "baseline": not args.disable_baseline,
            "smooth": not args.disable_smooth,
            "normalize": not args.disable_normalize,
        },
        "predicted_label": pred_label,
        "true_label": true_label,
        "probs": probs_np,
        "logits": logits_np,
    }

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    logging.info("Predicted Label: %d  True Label: %s", pred_label, true_label)
    logging.info("Raw Logits: %s", logits_np)
    logging.info("Result saved to %s", out_path)

# Execute the CLI only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()