"""preprocess_dataset.py
Canonical Raman preprocessing for dataset splits.
Uses the single source of truth in utils.preprocessing:
resample → baseline (deg=2) → smooth (w=11, o=2) → normalize.
"""
import os
import sys
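# Make the repo root importable so utils/ and scripts/ resolve when this file is run directly.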
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np
from utils.preprocessing import (
TARGET_LENGTH,
preprocess_spectrum
)
from scripts.discover_raman_files import list_txt_files, label_file
from scripts.plot_spectrum import load_spectrum
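
# For orientation, a minimal sketch of what utils.preprocessing.preprocess_spectrum
# is assumed to do, following the chain named in the module docstring (illustrative
# only; the real function in utils.preprocessing remains the single source of truth):
#
#     grid = np.linspace(x.min(), x.max(), target_len)           # common axis
#     y_rs = np.interp(grid, x, y)                               # resample
#     y_bl = y_rs - np.polyval(np.polyfit(grid, y_rs, 2), grid)  # baseline, deg=2
#     y_sm = scipy.signal.savgol_filter(y_bl, 11, 2)             # smooth, w=11 o=2
#     y_nm = (y_sm - y_sm.min()) / (np.ptp(y_sm) or 1.0)         # min-max normalize
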
def preprocess_dataset(
dataset_dir: str,
target_len: int = TARGET_LENGTH,
baseline_correction: bool = True,
apply_smoothing: bool = True,
normalize: bool = True,
out_dtype: str = "float32",
):
"""
Load, preprocess, and label Raman spectra in dataset_dir.
Returns
-------
X : np.ndarray, shape (N, target_len), dtype=out_dtype
Preprocessed spectra (resampled and transformed).
y : np.ndarray, shape (N,), dtype=int64
Integer labels (e.g., 0 = stable, 1 = weathered).
"""
txt_paths = list_txt_files(dataset_dir)
X, y_labels = [], []
for path in txt_paths:
label = label_file(path)
if label is None:
continue
x_raw, y_raw = load_spectrum(path)
if len(x_raw) < 10:
continue # Skip files with too few points
# === Single-source-of-truth path ===
_, y_processed = preprocess_spectrum(
np.asarray(x_raw),
np.asarray(y_raw),
target_len=target_len,
do_baseline=baseline_correction,
do_smooth=apply_smoothing,
do_normalize=normalize,
            out_dtype=out_dtype,  # a dtype string is fine (DTypeLike)
)
# === Collect ===
X.append(y_processed)
y_labels.append(int(label))
if not X:
# === No valid samples ===
        return (
            np.empty((0, target_len), dtype=out_dtype),
            np.empty((0,), dtype=np.int64),
        )
X_arr = np.asarray(X, dtype=np.dtype(out_dtype))
Y_arr = np.asarray(y_labels, dtype=np.int64)
return X_arr, Y_arr
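
# Example downstream use (illustrative sketch; scikit-learn is an assumed,
# not a confirmed, dependency of this repo):
#
#     from sklearn.model_selection import train_test_split
#     X, y = preprocess_dataset(os.path.join("datasets", "rdwp"))
#     X_tr, X_te, y_tr, y_te = train_test_split(
#         X, y, test_size=0.2, stratify=y, random_state=0
#     )
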
# === Optional: Run directly for quick smoke test ===
if __name__ == "__main__":
test_dataset_dir = os.path.join("datasets", "rdwp")
X, y = preprocess_dataset(test_dataset_dir)
print(f"X shape: {X.shape} dtype={X.dtype}")
print(f"y shape: {y.shape} dtype={y.dtype}")
if y.size:
try:
counts = np.bincount(y, minlength=2)
print(f"Label distribution: {counts} (stable, weathered)")
except Exception as e:
print(f"Could not compute label distribution {e}")