"""
Walnut Rancidity Predictor — Inference Script

Usage:
    from model.predict import predict_storage_risk
    result = predict_storage_risk(sequence)
"""
|
|
| import sys, os |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import joblib |
|
|
| |
| ROOT = Path(__file__).resolve().parent.parent |
| sys.path.insert(0, str(ROOT)) |
|
|
| MODEL_PATH = ROOT / "models" / "walnut_rancidity_lstm_attention.pt" |
| SCALER_PATH = ROOT / "models" / "feature_scaler.pkl" |
|
|
| FEATURE_COLS = [ |
| "temperature", "humidity", "moisture", "oxygen", |
| "peroxide_value", "free_fatty_acids", "hexanal_level", "oxidation_index", |
| ] |
|
|
| SEQ_LEN = 30 |
|
|
|
|
| |
| class Attention(nn.Module): |
| def __init__(self, hidden_size: int): |
| super().__init__() |
| self.attn = nn.Linear(hidden_size, 1) |
|
|
| def forward(self, lstm_out: torch.Tensor) -> torch.Tensor: |
| scores = self.attn(lstm_out).squeeze(-1) |
| weights = torch.softmax(scores, dim=-1) |
| context = (weights.unsqueeze(-1) * lstm_out).sum(dim=1) |
| return context |
|
|
|
|
| class WalnutLSTMAttention(nn.Module): |
| def __init__(self, n_features: int, hidden: int, n_layers: int, dropout: float): |
| super().__init__() |
| self.lstm = nn.LSTM( |
| input_size=n_features, |
| hidden_size=hidden, |
| num_layers=n_layers, |
| dropout=dropout if n_layers > 1 else 0.0, |
| batch_first=True, |
| ) |
| self.attn = Attention(hidden) |
| self.dropout = nn.Dropout(dropout) |
|
|
| self.head_rancidity = nn.Sequential( |
| nn.Linear(hidden, 32), nn.ReLU(), |
| nn.Linear(32, 1), nn.Sigmoid(), |
| ) |
| self.head_shelf_life = nn.Sequential( |
| nn.Linear(hidden, 32), nn.ReLU(), |
| nn.Linear(32, 1), |
| ) |
| self.head_decay = nn.Sequential( |
| nn.Linear(hidden, 32), nn.ReLU(), |
| nn.Linear(32, 1), nn.Sigmoid(), |
| ) |
|
|
| def forward(self, x: torch.Tensor): |
| lstm_out, _ = self.lstm(x) |
| context = self.attn(lstm_out) |
| context = self.dropout(context) |
| rp = self.head_rancidity(context).squeeze(-1) |
| sl = self.head_shelf_life(context).squeeze(-1) |
| dc = self.head_decay(context).squeeze(-1) |
| return rp, sl, dc |
|
|
|
|
| |
| _model = None |
| _scaler = None |
|
|
|
|
| def _load_artifacts(): |
| global _model, _scaler |
| if _model is not None: |
| return |
|
|
| ckpt = torch.load(MODEL_PATH, map_location="cpu") |
| cfg = ckpt["config"] |
|
|
| _model = WalnutLSTMAttention( |
| n_features=cfg["n_features"], |
| hidden=cfg["hidden"], |
| n_layers=cfg["n_layers"], |
| dropout=cfg["dropout"], |
| ) |
| _model.load_state_dict(ckpt["model_state"]) |
| _model.eval() |
|
|
| _scaler = joblib.load(SCALER_PATH) |
|
|
|
|
| |
| def predict_storage_risk(sequence: list | np.ndarray) -> dict: |
| """ |
| Predict walnut storage risk from a time-series sequence. |
| |
| Parameters |
| ---------- |
| sequence : array-like of shape (SEQ_LEN, 8) or (N, 8) |
| Each row contains the 8 features in order: |
| [temperature, humidity, moisture, oxygen, |
| peroxide_value, free_fatty_acids, hexanal_level, oxidation_index] |
| |
| If more than SEQ_LEN rows are provided, the last SEQ_LEN rows are used. |
| If fewer rows are provided, the sequence is zero-padded at the front. |
| |
| Returns |
| ------- |
| dict with keys: |
| rancidity_probability : float [0, 1] |
| shelf_life_remaining_days : float (days) |
| risk_level : "LOW" | "MEDIUM" | "HIGH" |
| """ |
| _load_artifacts() |
|
|
| seq = np.array(sequence, dtype=np.float32) |
| if seq.ndim == 1: |
| seq = seq.reshape(1, -1) |
|
|
| |
| if len(seq) > SEQ_LEN: |
| seq = seq[-SEQ_LEN:] |
| elif len(seq) < SEQ_LEN: |
| pad = np.zeros((SEQ_LEN - len(seq), seq.shape[1]), dtype=np.float32) |
| seq = np.vstack([pad, seq]) |
|
|
| |
| seq_scaled = _scaler.transform(seq) |
| x = torch.tensor(seq_scaled[np.newaxis], dtype=torch.float32) |
|
|
| with torch.no_grad(): |
| rp_pred, sl_pred, dc_pred = _model(x) |
|
|
| rancidity_prob = float(rp_pred.item()) |
| shelf_life = float(sl_pred.item()) * 180.0 |
|
|
| if rancidity_prob < 0.3: |
| risk_level = "LOW" |
| elif rancidity_prob <= 0.7: |
| risk_level = "MEDIUM" |
| else: |
| risk_level = "HIGH" |
|
|
| return { |
| "rancidity_probability": round(rancidity_prob, 4), |
| "shelf_life_remaining_days": round(max(shelf_life, 0.0), 2), |
| "risk_level": risk_level, |
| } |
|
|
|
|
| |
| if __name__ == "__main__": |
| print("Running demo inference β¦") |
|
|
| |
| cold_seq = np.column_stack([ |
| np.full(30, 5.0), |
| np.full(30, 50.0), |
| np.full(30, 4.0), |
| np.full(30, 0.20), |
| np.linspace(0.5, 1.2, 30), |
| np.linspace(0.05, 0.10, 30), |
| np.linspace(0.1, 0.3, 30), |
| np.linspace(0.2, 0.5, 30), |
| ]) |
| result = predict_storage_risk(cold_seq) |
| print(f"Cold storage β {result}") |
|
|
| |
| hot_seq = np.column_stack([ |
| np.full(30, 38.0), |
| np.full(30, 80.0), |
| np.full(30, 7.5), |
| np.full(30, 0.22), |
| np.linspace(2.0, 8.0, 30), |
| np.linspace(0.2, 0.6, 30), |
| np.linspace(0.8, 2.5, 30), |
| np.linspace(1.0, 3.5, 30), |
| ]) |
| result = predict_storage_risk(hot_seq) |
| print(f"Hot transport β {result}") |
|
|