VincentCroft
committed on
Commit
·
c4598a9
1
Parent(s):
948686f
Add PMU fault training pipeline and improve Gradio app
Browse files- app.py +377 -140
- fault_classification_pmu.py +358 -0
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,161 +1,398 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
- Port selection:
|
| 9 |
-
* If GRADIO_SERVER_PORT or PORT env var is set, try that.
|
| 10 |
-
* Otherwise find a free ephemeral port and use it.
|
| 11 |
-
* If binding fails, fall back to demo.launch() with no explicit port (Gradio picks).
|
| 12 |
-
- Reduces TF logging noise via TF_CPP_MIN_LOG_LEVEL (optional).
|
| 13 |
"""
|
|
|
|
|
|
|
|
|
|
| 14 |
import os
|
|
|
|
| 15 |
import socket
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import numpy as np
|
| 17 |
import pandas as pd
|
| 18 |
-
import gradio as gr
|
| 19 |
-
from tensorflow.keras.models import load_model
|
| 20 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 21 |
|
| 22 |
-
# Reduce TensorFlow log
|
| 23 |
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
-
print(f"Downloading {filename} from {
|
| 33 |
-
path = hf_hub_download(repo_id=
|
| 34 |
-
print("Downloaded
|
| 35 |
-
return path
|
| 36 |
-
except Exception as
|
| 37 |
-
print("Failed to download from
|
| 38 |
return None
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
try:
|
| 55 |
-
|
| 56 |
-
print("Loaded model
|
| 57 |
-
return
|
| 58 |
-
except Exception as
|
| 59 |
-
print("Failed to load model
|
| 60 |
return None
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
arr = np.array(arr)
|
| 67 |
-
if arr.ndim == 1:
|
| 68 |
-
if n_features is None:
|
| 69 |
-
# assume arr is flattened timesteps*features
|
| 70 |
-
return arr.reshape(1, n_timesteps, -1)
|
| 71 |
-
return arr.reshape(1, n_timesteps, int(n_features))
|
| 72 |
-
elif arr.ndim == 2:
|
| 73 |
-
# treat as (timesteps, features) -> add batch dim
|
| 74 |
-
if arr.shape[0] == 1:
|
| 75 |
-
return arr.reshape(1, arr.shape[1], -1)
|
| 76 |
-
return arr
|
| 77 |
-
else:
|
| 78 |
-
return arr
|
| 79 |
-
|
| 80 |
-
def predict_text(text, n_timesteps=1, n_features=None):
    """Classify a comma-separated feature string with the loaded model.

    Returns a human-readable result string: the predicted class index and
    its probability on success, or an error message when the model is not
    loaded or any step of parsing/prediction raises.
    """
    if MODEL is None:
        return "模型未加载。请上传 'lstm_cnn_model.h5' 到 Space 根目录,或设置 HUB_REPO/HUB_FILENAME。"
    try:
        values = np.fromstring(text, sep=',')
        steps = int(n_timesteps)
        # A falsy n_features (None, 0, "") means "infer from the data".
        feature_count = int(n_features) if n_features else None
        batch = prepare_input_array(values, n_timesteps=steps, n_features=feature_count)
        probs = MODEL.predict(batch)
        label = int(np.argmax(probs, axis=1)[0])
        return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"
    except Exception as e:
        return f"预测失败: {e}"
|
| 91 |
-
|
| 92 |
-
def predict_csv(file, n_timesteps=1, n_features=None):
|
| 93 |
-
if MODEL is None:
|
| 94 |
-
return {"error": "模型未加载。请上传 'lstm_cnn_model.h5' 到 Space 根目录,或设置 HUB_REPO/HUB_FILENAME。"}
|
| 95 |
try:
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
return
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
return find_free_port()
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
try:
|
| 154 |
-
port =
|
| 155 |
-
print(f"Launching
|
| 156 |
-
demo.launch(server_name=
|
| 157 |
-
except OSError as
|
| 158 |
-
print("Failed to
|
| 159 |
-
print("Falling back to default demo.launch() (no explicit port).")
|
| 160 |
-
# last fallback: let Gradio choose/handle
|
| 161 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio front-end for Fault_Classification_PMU_Data models.
|
| 2 |
+
|
| 3 |
+
The application loads a CNN-LSTM model (and accompanying scaler/metadata)
|
| 4 |
+
produced by ``fault_classification_pmu.py`` and exposes a streamlined
|
| 5 |
+
prediction interface optimised for Hugging Face Spaces deployment. It supports
|
| 6 |
+
raw PMU time-series CSV uploads as well as manual comma separated feature
|
| 7 |
+
vectors.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
import os
|
| 13 |
+
import re
|
| 14 |
import socket
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 17 |
+
|
| 18 |
+
import gradio as gr
|
| 19 |
+
import joblib
|
| 20 |
import numpy as np
|
| 21 |
import pandas as pd
|
|
|
|
|
|
|
| 22 |
from huggingface_hub import hf_hub_download
|
| 23 |
+
from tensorflow.keras.models import load_model
|
| 24 |
|
| 25 |
+
# Reduce TensorFlow log verbosity
|
| 26 |
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
|
| 27 |
|
| 28 |
+
# --------------------------------------------------------------------------------------
|
| 29 |
+
# Configuration
|
| 30 |
+
# --------------------------------------------------------------------------------------
|
| 31 |
+
# Feature columns used when the metadata file does not provide its own list.
DEFAULT_FEATURE_COLUMNS: List[str] = [
    "[325] UPMU_SUB22:FREQ",
    "[326] UPMU_SUB22:DFDT",
    "[327] UPMU_SUB22:FLAG",
    "[328] UPMU_SUB22-L1:MAG",
    "[329] UPMU_SUB22-L1:ANG",
    "[330] UPMU_SUB22-L2:MAG",
    "[331] UPMU_SUB22-L2:ANG",
    "[332] UPMU_SUB22-L3:MAG",
    "[333] UPMU_SUB22-L3:ANG",
    "[334] UPMU_SUB22-C1:MAG",
    "[335] UPMU_SUB22-C1:ANG",
    "[336] UPMU_SUB22-C2:MAG",
    "[337] UPMU_SUB22-C2:ANG",
    "[338] UPMU_SUB22-C3:MAG",
    "[339] UPMU_SUB22-C3:ANG",
]
# Windowing defaults used when the metadata file does not override them.
DEFAULT_SEQUENCE_LENGTH = 32
DEFAULT_STRIDE = 4

# Artifact file names looked up on local disk (overridable via environment).
LOCAL_MODEL_FILE = os.getenv("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
LOCAL_SCALER_FILE = os.getenv("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
LOCAL_METADATA_FILE = os.getenv("PMU_METADATA_FILE", "pmu_metadata.json")

# Hugging Face Hub repository and file names; an empty repo disables downloads.
HUB_REPO = os.getenv("PMU_HUB_REPO", "")
HUB_MODEL_FILENAME = os.getenv("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
HUB_SCALER_FILENAME = os.getenv("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
HUB_METADATA_FILENAME = os.getenv("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)

# Names of the environment variables that may hold explicit artifact paths.
ENV_MODEL_PATH = "PMU_MODEL_PATH"
ENV_SCALER_PATH = "PMU_SCALER_PATH"
ENV_METADATA_PATH = "PMU_METADATA_PATH"
|
| 63 |
+
|
| 64 |
+
# --------------------------------------------------------------------------------------
|
| 65 |
+
# Utility functions for loading artifacts
|
| 66 |
+
# --------------------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
def download_from_hub(filename: str) -> Optional[Path]:
    """Fetch ``filename`` from the configured Hugging Face Hub repository.

    Returns the downloaded file location as a ``Path``, or ``None`` when
    ``HUB_REPO`` or ``filename`` is empty or when the download raises.
    Progress and failures are reported with ``print``.
    """
    if not (HUB_REPO and filename):
        return None
    try:
        print(f"Downloading {filename} from {HUB_REPO} ...")
        local_copy = hf_hub_download(repo_id=HUB_REPO, filename=filename)
        print("Downloaded", local_copy)
        return Path(local_copy)
    except Exception as exc:  # pragma: no cover - logging convenience
        # Best effort: a failed download is reported and treated as "missing".
        print("Failed to download", filename, "from", HUB_REPO, ":", exc)
        return None
|
| 79 |
|
| 80 |
+
|
| 81 |
+
def resolve_artifact(local_name: str, env_var: str, hub_filename: str) -> Optional[Path]:
    """Locate an artifact on disk, falling back to a Hub download.

    Lookup order:

    1. the path stored in the environment variable ``env_var``;
    2. ``local_name`` (resolved against the current working directory);
    3. ``download_from_hub(hub_filename)``.

    The environment variable is consulted first because it is an explicit
    override: when both it and the default local file exist, the override
    must win, otherwise setting it would silently have no effect.

    Parameters
    ----------
    local_name:
        Default file name or path to look for; skipped when empty.
    env_var:
        Name of the environment variable that may hold an explicit path.
    hub_filename:
        File name passed to ``download_from_hub`` when nothing exists locally.

    Returns
    -------
    The first existing candidate path, otherwise whatever
    ``download_from_hub`` returns (possibly ``None``).
    """
    candidates: List[Path] = []
    env_value = os.environ.get(env_var)
    if env_value:
        candidates.append(Path(env_value))
    if local_name:
        candidates.append(Path(local_name))

    for candidate in candidates:
        if candidate.exists():
            return candidate
    return download_from_hub(hub_filename)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def load_metadata(path: Optional[Path]) -> Dict:
    """Read the metadata JSON file produced by the training script.

    Parameters
    ----------
    path:
        Location of the metadata file, or ``None`` when none was resolved.

    Returns
    -------
    The decoded JSON object, or an empty dict when ``path`` is missing,
    unreadable, not valid JSON, or does not contain a JSON object. Failures
    are reported with ``print`` rather than raised so the app can still
    start with its built-in defaults.
    """
    if not path or not path.exists():
        return {}
    try:
        # Decode explicitly as UTF-8 so label names with non-ASCII characters
        # do not depend on the platform's default encoding.
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:  # pragma: no cover - metadata parsing errors
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        print("Failed to read metadata", path, exc)
        return {}
    if not isinstance(payload, dict):
        # Callers use ``.get`` on the result, so anything but an object is unusable.
        print("Failed to read metadata", path, "expected a JSON object, got", type(payload).__name__)
        return {}
    return payload
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def try_load_model(path: Optional[Path]):
    """Load the Keras model stored at ``path``.

    Returns the loaded model, or ``None`` when ``path`` is falsy or
    ``load_model`` raises; the outcome is reported with ``print``.
    """
    if path:
        try:
            loaded = load_model(path)
            print("Loaded model from", path)
            return loaded
        except Exception as exc:  # pragma: no cover - runtime diagnostics
            print("Failed to load model", path, exc)
    return None
|
| 111 |
|
| 112 |
+
|
| 113 |
+
def try_load_scaler(path: Optional[Path]):
    """Load the feature scaler stored at ``path`` with ``joblib``.

    Returns the loaded object, or ``None`` when ``path`` is falsy or
    ``joblib.load`` raises; the outcome is reported with ``print``.
    """
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code; only point this at scaler files from a trusted source.
    if path:
        try:
            loaded = joblib.load(path)
            print("Loaded scaler from", path)
            return loaded
        except Exception as exc:
            print("Failed to load scaler", path, exc)
    return None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Resolve where each artifact lives (on disk or via the Hub).
MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
SCALER_PATH = resolve_artifact(LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME)
METADATA_PATH = resolve_artifact(LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME)

# Load the artifacts once at import time; each may end up None / empty.
MODEL = try_load_model(MODEL_PATH)
SCALER = try_load_scaler(SCALER_PATH)
METADATA = load_metadata(METADATA_PATH)

# Inference settings come from the metadata, with module defaults as fallback.
FEATURE_COLUMNS: List[str] = METADATA.get("feature_columns", DEFAULT_FEATURE_COLUMNS)
LABEL_CLASSES: List[str] = list(map(str, METADATA.get("label_classes", [])))
LABEL_COLUMN: str = METADATA.get("label_column", "Fault")
SEQUENCE_LENGTH: int = int(METADATA.get("sequence_length", DEFAULT_SEQUENCE_LENGTH))
DEFAULT_WINDOW_STRIDE: int = int(METADATA.get("stride", DEFAULT_STRIDE))

# Without label names, fall back to one numeric label per model output unit.
if not LABEL_CLASSES and MODEL is not None:
    LABEL_CLASSES = list(map(str, range(MODEL.output_shape[-1])))
|
| 141 |
+
|
| 142 |
+
# --------------------------------------------------------------------------------------
|
| 143 |
+
# Pre-processing helpers
|
| 144 |
+
# --------------------------------------------------------------------------------------
|
| 145 |
+
|
| 146 |
+
def ensure_ready():
    """Raise ``RuntimeError`` unless both ``MODEL`` and ``SCALER`` are loaded."""
    if MODEL is not None and SCALER is not None:
        return
    raise RuntimeError(
        "模型或特征缩放器未加载。请将 pmu_cnn_lstm_model.keras 和 pmu_feature_scaler.pkl "
        "上传到 Space,或设置相关的 Hugging Face Hub 配置。"
    )
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def parse_text_features(text: str) -> np.ndarray:
    """Parse a delimited string of numbers into a float32 vector.

    Values may be separated by commas, semicolons, or any whitespace
    (including newlines and tabs); repeated or trailing separators are
    ignored. Each token is converted with ``float`` so that an unparsable
    token is reported instead of silently truncating the input.

    Parameters
    ----------
    text:
        Raw user input, e.g. ``"49.97, 1.2E-38; 0"``.

    Returns
    -------
    1-D ``np.float32`` array with one entry per parsed value.

    Raises
    ------
    ValueError
        If no values are found or a token is not a valid number.
    """
    tokens = [tok for tok in re.split(r"[,;\s]+", text.strip()) if tok]
    if not tokens:
        raise ValueError("未解析到任何特征值,请输入以逗号分隔的数字。")

    values: List[float] = []
    for tok in tokens:
        try:
            values.append(float(tok))
        except ValueError:
            # Name the offending token so the user can locate the typo.
            raise ValueError(f"无法将 {tok!r} 解析为数字,请输入以逗号分隔的数字。") from None
    return np.asarray(values, dtype=np.float32)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def apply_scaler(sequences: np.ndarray) -> np.ndarray:
    """Run ``SCALER.transform`` over the last axis of ``sequences``.

    The array is flattened to two dimensions (keeping the last axis) for
    the transform and then reshaped back to its original shape. When no
    scaler is loaded the input is returned unchanged.
    """
    if SCALER is None:
        return sequences
    original_shape = sequences.shape
    rows = sequences.reshape(-1, original_shape[-1])
    return SCALER.transform(rows).reshape(original_shape)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def make_sliding_windows(data: np.ndarray, sequence_length: int, stride: int) -> np.ndarray:
    """Cut ``data`` into overlapping windows along its first axis.

    Parameters
    ----------
    data:
        Array whose first axis is sliced into windows.
    sequence_length:
        Number of consecutive rows per window; must be positive.
    stride:
        Offset in rows between the starts of consecutive windows; must be
        positive.

    Returns
    -------
    Array of shape ``(n_windows, sequence_length, *data.shape[1:])``.

    Raises
    ------
    ValueError
        If ``sequence_length`` or ``stride`` is not positive, or if ``data``
        has fewer rows than ``sequence_length``.
    """
    # Validate up front: a zero stride makes range() fail with an unrelated
    # message and a negative one yields no windows for np.stack.
    if sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")
    if stride <= 0:
        raise ValueError("stride must be > 0")
    n_rows = data.shape[0]
    if n_rows < sequence_length:
        raise ValueError(
            f"数据行数 ({n_rows}) 小于序列长度 ({sequence_length}),无法创建窗口。"
        )
    starts = range(0, n_rows - sequence_length + 1, stride)
    return np.stack([data[start : start + sequence_length] for start in starts])
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def dataframe_to_sequences(
    df: pd.DataFrame,
    *,
    sequence_length: int,
    stride: int,
    feature_columns: Sequence[str],
    drop_label: bool = True,
) -> np.ndarray:
    """Turn an uploaded table into model-ready windows.

    Two layouts are accepted:

    * every name in ``feature_columns`` is present: those columns are
      windowed with ``make_sliding_windows``;
    * otherwise the numeric columns are treated as already-flattened
      windows, one per row, and reshaped to
      ``(rows, sequence_length, n_features)``.

    The ``LABEL_COLUMN`` column is dropped first when ``drop_label`` is true,
    and rows are sorted by ``Timestamp`` when that column exists.

    Raises ``ValueError`` when neither layout matches.
    """
    frame = df.copy()
    if drop_label and LABEL_COLUMN in frame.columns:
        frame = frame.drop(columns=[LABEL_COLUMN])
    if "Timestamp" in frame.columns:
        frame = frame.sort_values("Timestamp")

    expected = len(feature_columns)
    present = [name for name in feature_columns if name in frame.columns]
    if present and len(present) == expected:
        values = frame[present].astype(np.float32).to_numpy()
        return make_sliding_windows(values, sequence_length, stride)

    # Fallback: each row is expected to hold one pre-flattened window.
    values = frame.select_dtypes(include=[np.number]).astype(np.float32).to_numpy()
    n_rows, width = values.shape[0], values.shape[1]
    if width == expected * sequence_length:
        return values.reshape(n_rows, sequence_length, expected)
    if sequence_length == 1 and width == expected:
        return values.reshape(n_rows, 1, expected)
    raise ValueError(
        "CSV 列与预期特征不匹配。请包含完整的 PMU 特征列,或提供整形后的窗口数据。"
    )
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def label_name(index: int) -> str:
    """Return the label at ``index`` in ``LABEL_CLASSES``, or ``class_<index>`` when out of range."""
    if index < 0 or index >= len(LABEL_CLASSES):
        return f"class_{index}"
    return str(LABEL_CLASSES[index])
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
    """Summarise per-window probabilities as a table.

    Each row holds the window index, the most probable label, its
    probability rounded to four decimals, and a ``" | "``-joined string of
    the three most probable labels with percentages.
    """
    # Class indices per row, most probable first.
    ranked = np.argsort(probabilities, axis=1)[:, ::-1]

    def _summarise(window: int, scores, ranking) -> Dict[str, object]:
        best = int(ranking[0])
        best_label = label_name(best)
        best_conf = float(scores[best])
        leaders = [f"{label_name(i)} ({scores[i]*100:.2f}%)" for i in ranking[:3]]
        return {
            "window": window,
            "predicted_label": best_label,
            "confidence": round(best_conf, 4),
            "top3": " | ".join(leaders),
        }

    records: List[Dict[str, object]] = [
        _summarise(window, scores, ranking)
        for window, (scores, ranking) in enumerate(zip(probabilities, ranked))
    ]
    return pd.DataFrame(records)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
    """Convert per-window probabilities into JSON-serialisable dicts.

    Produces one dict per row with the window index and a mapping from
    label name to probability (as a Python ``float``).
    """
    return [
        {
            "window": int(position),
            "probabilities": {
                label_name(cls): float(scores[cls]) for cls in range(scores.shape[0])
            },
        }
        for position, scores in enumerate(probabilities)
    ]
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def predict_sequences(sequences: np.ndarray) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Scale ``sequences``, run the model, and package the results.

    Returns a status message, the summary table from
    ``format_predictions``, and the payload from ``probabilities_to_json``.
    Raises ``RuntimeError`` (via ``ensure_ready``) when the model or scaler
    is missing.
    """
    ensure_ready()
    scaled = apply_scaler(sequences.astype(np.float32))
    probs = MODEL.predict(scaled, verbose=0)
    summary = format_predictions(probs)
    payload = probabilities_to_json(probs)
    message = f"共生成 {len(scaled)} 个窗口,模型输出维度 {probs.shape[1]}."
    return message, summary, payload
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def predict_from_text(text: str, sequence_length: int) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Predict a single window from pasted, delimited feature values.

    The parsed values must form exactly ``sequence_length`` timesteps of
    ``len(FEATURE_COLUMNS)`` features each; otherwise ``ValueError`` is
    raised. Returns the same triple as ``predict_sequences`` with a prefix
    added to the status message.
    """
    values = parse_text_features(text)
    width = len(FEATURE_COLUMNS)
    if values.size % width != 0:
        raise ValueError(
            f"输入特征数量 {values.size} 不是特征维度 {width} 的整数倍。请按照 {width} 个特征为一组输入。"
        )
    timesteps = values.size // width
    if timesteps != sequence_length:
        raise ValueError(
            f"检测到 {timesteps} 个时间步,与当前设置的序列长度 {sequence_length} 不一致。"
        )
    window = values.reshape(1, sequence_length, width)
    message, table, probs = predict_sequences(window)
    return f"单窗口预测完成。{message}", table, probs
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def predict_from_csv(file_obj, sequence_length: int, stride: int) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Predict every window that can be cut from an uploaded CSV file.

    Parameters
    ----------
    file_obj:
        The upload as delivered by the Gradio file component: either a
        file-path string / ``os.PathLike`` or an object exposing the path as
        its ``name`` attribute (tempfile-style wrapper).
    sequence_length:
        Number of timesteps per window.
    stride:
        Step between consecutive windows.

    Returns
    -------
    The same triple as ``predict_sequences`` with a prefix added to the
    status message.
    """
    # Gradio may hand over a plain path instead of a file wrapper, in which
    # case there is no ``.name`` attribute to read.
    if isinstance(file_obj, (str, os.PathLike)):
        csv_path = file_obj
    else:
        csv_path = file_obj.name
    df = pd.read_csv(csv_path)
    sequences = dataframe_to_sequences(
        df,
        sequence_length=sequence_length,
        stride=stride,
        feature_columns=FEATURE_COLUMNS,
    )
    status, table, probs = predict_sequences(sequences)
    status = f"CSV 处理完成,生成 {len(sequences)} 个窗口。{status}"
    return status, table, probs
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# --------------------------------------------------------------------------------------
|
| 290 |
+
# Gradio interface
|
| 291 |
+
# --------------------------------------------------------------------------------------
|
| 292 |
+
|
| 293 |
+
def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire the predict button.

    The page shows a readiness notice, a collapsible feature description,
    inputs for a CSV upload or pasted feature text, sliders for sequence
    length and stride, and three outputs (status text, result table,
    per-window probabilities). Returns the ``gr.Blocks`` instance without
    launching it.
    """
    artifacts_ready = MODEL is not None and SCALER is not None
    feature_bullets = "\n".join(f"- {name}" for name in FEATURE_COLUMNS)

    def _run_prediction(file_obj, text, sequence_length, stride):
        # Slider values are coerced to int before use.
        sequence_length = int(sequence_length)
        stride = int(stride)
        try:
            # An uploaded file takes priority over pasted text.
            if file_obj is not None:
                return predict_from_csv(file_obj, sequence_length, stride)
            if text and text.strip():
                return predict_from_text(text, sequence_length)
            return "请上传 CSV 或输入文本特征。", pd.DataFrame(), []
        except Exception as exc:
            # Surface any failure in the status box instead of crashing the UI.
            return f"预测失败: {exc}", pd.DataFrame(), []

    with gr.Blocks(title="Fault Classification - PMU Data") as demo:
        gr.Markdown("# Fault Classification (PMU 数据)")
        if artifacts_ready:
            gr.Markdown(
                "模型、特征缩放器与元数据均已加载。可以上传原始 PMU CSV 数据,或粘贴单个时间窗口的特征向量进行推理。"
            )
        else:
            gr.Markdown(
                "⚠️ **模型或缩放器未准备好。** 上传 `pmu_cnn_lstm_model.keras`、"
                "`pmu_feature_scaler.pkl` 与 `pmu_metadata.json` 至 Space 根目录,或配置环境变量以从 Hugging Face Hub 自动下载。"
            )

        with gr.Accordion("特征说明", open=False):
            gr.Markdown(
                f"输入窗口按以下特征顺序排列 (每个时间步共 {len(FEATURE_COLUMNS)} 个特征):\n"
                + feature_bullets
            )
            gr.Markdown(
                f"训练时使用的窗口长度默认为 **{SEQUENCE_LENGTH}**,滑动步长默认为 **{DEFAULT_WINDOW_STRIDE}**。"
            )

        with gr.Row():
            file_in = gr.File(label="上传 PMU CSV", file_types=[".csv"])
            text_in = gr.Textbox(
                lines=4,
                label="或粘贴单个窗口的逗号分隔特征",
                placeholder="49.97772,1.215825E-38,...",
            )

        with gr.Row():
            sequence_length_input = gr.Slider(
                minimum=1,
                maximum=max(1, SEQUENCE_LENGTH * 2),
                step=1,
                value=SEQUENCE_LENGTH,
                label="序列长度 (timesteps)",
            )
            stride_input = gr.Slider(
                minimum=1,
                maximum=max(1, SEQUENCE_LENGTH),
                step=1,
                value=max(1, DEFAULT_WINDOW_STRIDE),
                label="CSV 滑动窗口步长",
            )

        predict_btn = gr.Button("执行预测", variant="primary")
        status_out = gr.Textbox(label="状态", interactive=False)
        table_out = gr.Dataframe(
            headers=["window", "predicted_label", "confidence", "top3"],
            label="预测结果",
            interactive=False,
        )
        probs_out = gr.JSON(label="各窗口概率分布")

        predict_btn.click(
            _run_prediction,
            inputs=[file_in, text_in, sequence_length_input, stride_input],
            outputs=[status_out, table_out, probs_out],
        )

    return demo
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# --------------------------------------------------------------------------------------
|
| 366 |
+
# Launch helpers
|
| 367 |
+
# --------------------------------------------------------------------------------------
|
| 368 |
+
|
| 369 |
+
def find_free_port() -> int:
    """Bind a TCP socket to port 0 and return the port number it was given.

    The socket is closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        # getsockname() yields (host, port) for AF_INET sockets.
        return probe.getsockname()[1]
    finally:
        probe.close()
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def choose_port() -> Optional[int]:
    """Pick the port the Gradio server should listen on.

    ``GRADIO_SERVER_PORT`` is consulted first, then ``PORT``. A value that
    is not an integer in the range 1-65535 is reported and skipped rather
    than silently ignored. When neither variable yields a usable port, the
    result of ``find_free_port()`` is returned.
    """
    for env_var in ("GRADIO_SERVER_PORT", "PORT"):
        raw = os.environ.get(env_var)
        if not raw:
            continue
        try:
            port = int(raw)
        except ValueError:
            print(f"Ignoring {env_var}={raw!r}: not an integer port number")
            continue
        if 0 < port < 65536:
            return port
        print(f"Ignoring {env_var}={raw!r}: port must be between 1 and 65535")
    return find_free_port()
|
| 384 |
|
| 385 |
+
|
| 386 |
+
def main():
    """Build the interface and launch it, retrying with Gradio defaults on ``OSError``."""
    app = build_interface()
    try:
        selected_port = choose_port()
        print(f"Launching Gradio app on port {selected_port}")
        app.launch(server_name="0.0.0.0", server_port=selected_port)
    except OSError as exc:
        print("Failed to launch on requested port:", exc)
        # Fall back to a launch with no explicit host or port.
        app.launch()


if __name__ == "__main__":
    main()
|
fault_classification_pmu.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fault classification training utilities for PMU data.
|
| 2 |
+
|
| 3 |
+
This module trains a CNN-LSTM model on high-frequency PMU measurements to
|
| 4 |
+
classify transmission line faults. It implements a full training pipeline
|
| 5 |
+
including preprocessing, sequence generation, model definition, evaluation,
|
| 6 |
+
and artifact export so the resulting model can be served via the Gradio app
|
| 7 |
+
in this repository or on Hugging Face Spaces.
|
| 8 |
+
|
| 9 |
+
Example
|
| 10 |
+
-------
|
| 11 |
+
python fault_classification_pmu.py \
|
| 12 |
+
--data-path data/Fault_Classification_PMU_Data.csv \
|
| 13 |
+
--label-column FaultType \
|
| 14 |
+
--model-out pmu_cnn_lstm_model.keras \
|
| 15 |
+
--scaler-out pmu_feature_scaler.pkl \
|
| 16 |
+
--metadata-out pmu_metadata.json
|
| 17 |
+
|
| 18 |
+
The script accepts CSV input where each row contains a timestamped PMU
|
| 19 |
+
measurement and a categorical fault label. Features default to the 15 PMU
|
| 20 |
+
channels used in the project documentation, but any subset can be provided
|
| 21 |
+
via the ``--feature-columns`` argument. Data is automatically standardised
|
| 22 |
+
and windowed to create temporal sequences that feed into the neural network.
|
| 23 |
+
|
| 24 |
+
The exported metadata JSON file contains the feature ordering, label names,
|
| 25 |
+
sequence length, and stride. The Gradio front-end consumes this file to
|
| 26 |
+
replicate the same preprocessing steps during inference.
|
| 27 |
+
"""
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import json
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import List, Sequence, Tuple
|
| 34 |
+
|
| 35 |
+
import joblib
|
| 36 |
+
import numpy as np
|
| 37 |
+
import pandas as pd
|
| 38 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 39 |
+
from sklearn.model_selection import train_test_split
|
| 40 |
+
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
| 41 |
+
from tensorflow.keras import callbacks, layers, models, optimizers
|
| 42 |
+
|
| 43 |
+
# Default PMU feature set as described in the user provided table. Timestamp is
|
| 44 |
+
# intentionally omitted because it is not a model input feature.
|
| 45 |
+
DEFAULT_FEATURE_COLUMNS: List[str] = [
|
| 46 |
+
"[325] UPMU_SUB22:FREQ",
|
| 47 |
+
"[326] UPMU_SUB22:DFDT",
|
| 48 |
+
"[327] UPMU_SUB22:FLAG",
|
| 49 |
+
"[328] UPMU_SUB22-L1:MAG",
|
| 50 |
+
"[329] UPMU_SUB22-L1:ANG",
|
| 51 |
+
"[330] UPMU_SUB22-L2:MAG",
|
| 52 |
+
"[331] UPMU_SUB22-L2:ANG",
|
| 53 |
+
"[332] UPMU_SUB22-L3:MAG",
|
| 54 |
+
"[333] UPMU_SUB22-L3:ANG",
|
| 55 |
+
"[334] UPMU_SUB22-C1:MAG",
|
| 56 |
+
"[335] UPMU_SUB22-C1:ANG",
|
| 57 |
+
"[336] UPMU_SUB22-C2:MAG",
|
| 58 |
+
"[337] UPMU_SUB22-C2:ANG",
|
| 59 |
+
"[338] UPMU_SUB22-C3:MAG",
|
| 60 |
+
"[339] UPMU_SUB22-C3:ANG",
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _resolve_features(df: pd.DataFrame, feature_columns: Sequence[str] | None, label_column: str) -> List[str]:
|
| 65 |
+
if feature_columns:
|
| 66 |
+
missing = [c for c in feature_columns if c not in df.columns]
|
| 67 |
+
if missing:
|
| 68 |
+
raise ValueError(f"Feature columns not present in CSV: {missing}")
|
| 69 |
+
return list(feature_columns)
|
| 70 |
+
|
| 71 |
+
# Prefer the documented PMU ordering when the columns exist, falling back to
|
| 72 |
+
# any remaining numeric columns.
|
| 73 |
+
preferred = [c for c in DEFAULT_FEATURE_COLUMNS if c in df.columns]
|
| 74 |
+
|
| 75 |
+
excluded = {label_column, label_column.lower(), "timestamp", "Timestamp"}
|
| 76 |
+
remainder = [c for c in df.columns if c not in preferred and c not in excluded]
|
| 77 |
+
ordered = preferred + remainder
|
| 78 |
+
if not ordered:
|
| 79 |
+
raise ValueError("No feature columns detected. Specify --feature-columns explicitly.")
|
| 80 |
+
return ordered
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_dataset(
    csv_path: Path,
    *,
    feature_columns: Sequence[str] | None,
    label_column: str,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Read the PMU CSV file and split it into features and labels.

    Parameters
    ----------
    csv_path:
        Path to the CSV file containing PMU measurements.
    feature_columns:
        Optional explicit ordering of feature columns.
    label_column:
        Name of the column containing the categorical fault label.

    Returns
    -------
    features: np.ndarray
        2-D float32 array of shape (n_samples, n_features).
    labels: np.ndarray
        1-D array of label strings.
    columns: list[str]
        Actual feature ordering used.

    Raises
    ------
    ValueError
        If ``label_column`` is not a column of the CSV file.
    """
    frame = pd.read_csv(csv_path)
    if label_column not in frame.columns:
        raise ValueError(f"Label column '{label_column}' not found in {csv_path}")

    columns = _resolve_features(frame, feature_columns, label_column)
    feature_matrix = frame[columns].astype(np.float32).values
    label_vector = frame[label_column].astype(str).values
    return feature_matrix, label_vector, columns
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def create_sequences(
    features: np.ndarray,
    labels: np.ndarray,
    *,
    sequence_length: int,
    stride: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Slice row-wise features into overlapping, labelled windows.

    Each window covers ``sequence_length`` consecutive rows and consecutive
    windows start ``stride`` rows apart. A window takes the label of its
    final row.

    Returns the stacked windows and an array with one label per window.
    Raises ``ValueError`` for a non-positive ``sequence_length`` or
    ``stride``, mismatched row counts, or too few rows for one window.
    """
    if sequence_length <= 0:
        raise ValueError("sequence_length must be > 0")
    if stride <= 0:
        raise ValueError("stride must be > 0")
    n_rows = features.shape[0]
    if n_rows != labels.shape[0]:
        raise ValueError("Features and labels must contain the same number of rows")
    if n_rows < sequence_length:
        raise ValueError("Not enough samples to create a single sequence")

    # Iterate over exclusive window end positions; the start is implied.
    window_ends = range(sequence_length, n_rows + 1, stride)
    windows = [features[end - sequence_length : end] for end in window_ends]
    window_labels = [labels[end - 1] for end in window_ends]
    return np.stack(windows), np.array(window_labels)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def build_cnn_lstm(
    input_shape: Tuple[int, int],
    num_classes: int,
    *,
    conv_filters: int = 128,
    kernel_size: int = 3,
    lstm_units: int = 128,
    dropout: float = 0.3,
) -> models.Model:
    """Build and compile the CNN-LSTM classifier.

    The network is a plain sequential stack applied to the input tensor:
    two same-padded ReLU ``Conv1D`` layers (the second with dilation 2), each
    followed by batch normalisation, then dropout, a single ``LSTM`` that
    returns only its final state, dropout again, and a softmax ``Dense`` head
    with ``num_classes`` units.

    Parameters
    ----------
    input_shape:
        Shape of one input sample, passed unchanged to ``layers.Input``.
    num_classes:
        Number of output units of the softmax layer.
    conv_filters, kernel_size:
        Filter count and kernel width shared by both convolutions.
    lstm_units:
        Size of the LSTM layer.
    dropout:
        Rate used by both dropout layers.

    Returns
    -------
    models.Model
        Model compiled with Adam (learning rate 1e-3), sparse categorical
        cross-entropy loss and an accuracy metric.
    """
    network_input = layers.Input(shape=input_shape)

    # Layers listed in the order they are applied.
    pipeline = (
        layers.Conv1D(conv_filters, kernel_size, padding="same", activation="relu"),
        layers.BatchNormalization(),
        layers.Conv1D(conv_filters, kernel_size, dilation_rate=2, padding="same", activation="relu"),
        layers.BatchNormalization(),
        layers.Dropout(dropout),
        layers.LSTM(lstm_units, return_sequences=False),
        layers.Dropout(dropout),
        layers.Dense(num_classes, activation="softmax"),
    )

    tensor = network_input
    for layer in pipeline:
        tensor = layer(tensor)

    classifier = models.Model(network_input, tensor)
    classifier.compile(
        optimizer=optimizers.Adam(learning_rate=1e-3),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def train_model(
    sequences: np.ndarray,
    labels: np.ndarray,
    *,
    validation_split: float,
    batch_size: int,
    epochs: int,
) -> Tuple[models.Model, LabelEncoder, dict]:
    """Fit the CNN-LSTM on windowed data and gather validation predictions.

    Labels are integer-encoded with a ``LabelEncoder``; the windows are then
    split into stratified train/validation subsets (``random_state=42``) and
    a model from ``build_cnn_lstm`` is trained on the training part.
    Training halves the learning rate after 5 epochs without ``val_loss``
    improvement (floor 1e-5) and stops after 10 such epochs, restoring the
    best weights.

    Parameters
    ----------
    sequences:
        Windowed feature array; ``sequences.shape[1:]`` is used as the model
        input shape.
    labels:
        One label per window.
    validation_split:
        Passed to ``train_test_split`` as ``test_size``.
    batch_size, epochs:
        Forwarded to ``model.fit``.

    Returns
    -------
    model, label_encoder, metrics
        ``metrics`` holds ``"history"`` (the Keras history dict) and
        ``"validation"`` with ``"y_true"``, ``"y_pred"`` (arg-max class ids)
        and ``"class_names"``.
    """
    encoder = LabelEncoder()
    encoded = encoder.fit_transform(labels)

    train_x, val_x, train_y, val_y = train_test_split(
        sequences,
        encoded,
        test_size=validation_split,
        stratify=encoded,
        random_state=42,
    )

    classifier = build_cnn_lstm(
        input_shape=sequences.shape[1:],
        num_classes=len(encoder.classes_),
    )

    lr_on_plateau = callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-5)
    early_stopping = callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)

    fit_result = classifier.fit(
        train_x,
        train_y,
        validation_data=(val_x, val_y),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[lr_on_plateau, early_stopping],
        verbose=2,
    )

    # Most probable class id for every validation window.
    predicted = classifier.predict(val_x, verbose=0).argmax(axis=1)

    validation_outputs = {
        "y_true": val_y,
        "y_pred": predicted,
        "class_names": encoder.classes_.tolist(),
    }
    return classifier, encoder, {"history": fit_result.history, "validation": validation_outputs}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def standardise_sequences(sequences: np.ndarray) -> Tuple[np.ndarray, StandardScaler]:
    """Standard-scale each feature using statistics over all windows and timesteps.

    The array is flattened to two dimensions (keeping the last axis as the
    feature axis), a fresh ``StandardScaler`` is fitted on it, and the scaled
    values are reshaped back to the original shape.

    Returns
    -------
    scaled, scaler
        The scaled array (same shape as ``sequences``) and the fitted scaler.
    """
    feature_count = sequences.shape[-1]
    fitted_scaler = StandardScaler()
    scaled_rows = fitted_scaler.fit_transform(sequences.reshape(-1, feature_count))
    return scaled_rows.reshape(sequences.shape), fitted_scaler
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def export_artifacts(
    *,
    model: models.Model,
    scaler: StandardScaler,
    label_encoder: LabelEncoder,
    feature_columns: Sequence[str],
    label_column: str,
    sequence_length: int,
    stride: int,
    model_path: Path,
    scaler_path: Path,
    metadata_path: Path,
    metrics: dict,
) -> None:
    """Write the trained model, the fitted scaler and a JSON metadata file.

    Parent directories of all three output paths are created if missing. The
    model is saved with ``model.save``, the scaler with ``joblib.dump`` and
    the metadata as indented JSON.

    Parameters
    ----------
    model, scaler, label_encoder:
        Trained objects to persist / describe.
    feature_columns, label_column, sequence_length, stride:
        Preprocessing settings recorded verbatim in the metadata.
    model_path, scaler_path, metadata_path:
        Output locations.
    metrics:
        Mapping with ``"history"`` (metric name -> per-epoch values) and
        ``"validation"`` containing ``"y_true"``, ``"y_pred"`` and optionally
        a precomputed ``"confusion_matrix"``.
    """
    for target in (model_path, scaler_path, metadata_path):
        target.parent.mkdir(parents=True, exist_ok=True)
    model.save(model_path)
    joblib.dump(scaler, scaler_path)

    validation = metrics["validation"]
    y_true = validation["y_true"]
    y_pred = validation["y_pred"]
    class_names = label_encoder.classes_.tolist()
    # Pin the label set explicitly so the report and the matrix always cover
    # every known class, even one that is absent from the validation data
    # (without this, classification_report rejects the target_names length).
    class_ids = list(range(len(class_names)))

    # History values may be NumPy scalars (e.g. the learning rate logged by
    # ReduceLROnPlateau under some Keras versions), which json cannot encode.
    history = {
        metric_name: [float(value) for value in values]
        for metric_name, values in metrics["history"].items()
    }

    report = classification_report(y_true, y_pred, labels=class_ids, target_names=class_names)

    # Reuse a caller-supplied matrix when present; otherwise compute it once.
    matrix = validation.get("confusion_matrix")
    if matrix is None:
        matrix = confusion_matrix(y_true, y_pred, labels=class_ids)

    metadata = {
        "feature_columns": list(feature_columns),
        "label_classes": class_names,
        "label_column": label_column,
        "sequence_length": sequence_length,
        "stride": stride,
        "model_path": str(model_path),
        "scaler_path": str(scaler_path),
        "training_history": history,
        "classification_report": report,
        # np.asarray(...).tolist() yields nested plain lists for both
        # ndarray and list inputs, keeping the value JSON-serialisable.
        "confusion_matrix": np.asarray(matrix).tolist(),
    }

    metadata_path.write_text(json.dumps(metadata, indent=2))
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def run_training(args: argparse.Namespace) -> None:
    """Run the whole pipeline: load, window, scale, train, export, report.

    Parameters
    ----------
    args:
        Parsed command-line options. The attributes read here are
        ``data_path``, ``model_out``, ``scaler_out``, ``metadata_out``,
        ``feature_columns``, ``label_column``, ``sequence_length``,
        ``stride``, ``validation_split``, ``batch_size`` and ``epochs``.
    """
    csv_path = Path(args.data_path)
    model_out = Path(args.model_out)
    scaler_out = Path(args.scaler_out)
    metadata_out = Path(args.metadata_out)

    features, labels, feature_columns = load_dataset(
        csv_path, feature_columns=args.feature_columns, label_column=args.label_column
    )

    sequences, seq_labels = create_sequences(
        features,
        labels,
        sequence_length=args.sequence_length,
        stride=args.stride,
    )

    # NOTE(review): the scaler is fitted on all windows here, before
    # train_model splits off the validation subset, so validation rows
    # contribute to the scaling statistics.
    sequences, scaler = standardise_sequences(sequences)
    model, label_encoder, metrics = train_model(
        sequences,
        seq_labels,
        validation_split=args.validation_split,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

    export_artifacts(
        model=model,
        scaler=scaler,
        label_encoder=label_encoder,
        feature_columns=feature_columns,
        label_column=args.label_column,
        sequence_length=args.sequence_length,
        stride=args.stride,
        model_path=model_out,
        scaler_path=scaler_out,
        metadata_path=metadata_out,
        metrics=metrics,
    )

    print("Training complete")
    print(f"Model saved to : {model_out}")
    print(f"Scaler saved to : {scaler_out}")
    print(f"Metadata saved to : {metadata_out}")
    print("Validation metrics:")
    validation = metrics["validation"]
    class_names = validation["class_names"]
    # Passing the full label set keeps the report from raising when a class
    # has no samples (true or predicted) in the validation subset.
    report = classification_report(
        validation["y_true"],
        validation["y_pred"],
        labels=list(range(len(class_names))),
        target_names=class_names,
    )
    print(report)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
|
| 328 |
+
parser = argparse.ArgumentParser(description="Train a CNN-LSTM model for PMU fault classification")
|
| 329 |
+
parser.add_argument("--data-path", required=True, help="Path to Fault_Classification_PMU_Data CSV")
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--label-column",
|
| 332 |
+
default="Fault",
|
| 333 |
+
help="Name of the target label column (default: Fault)",
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--feature-columns",
|
| 337 |
+
nargs="*",
|
| 338 |
+
default=None,
|
| 339 |
+
help="Optional explicit list of feature columns. Defaults to all non-label columns",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument("--sequence-length", type=int, default=32, help="Number of timesteps per training window")
|
| 342 |
+
parser.add_argument("--stride", type=int, default=4, help="Step size between consecutive windows")
|
| 343 |
+
parser.add_argument("--validation-split", type=float, default=0.2, help="Validation set fraction")
|
| 344 |
+
parser.add_argument("--batch-size", type=int, default=128, help="Training batch size")
|
| 345 |
+
parser.add_argument("--epochs", type=int, default=50, help="Maximum number of training epochs")
|
| 346 |
+
parser.add_argument("--model-out", default="pmu_cnn_lstm_model.keras", help="Path to save trained Keras model")
|
| 347 |
+
parser.add_argument("--scaler-out", default="pmu_feature_scaler.pkl", help="Path to save fitted StandardScaler")
|
| 348 |
+
parser.add_argument("--metadata-out", default="pmu_metadata.json", help="Path to save metadata JSON")
|
| 349 |
+
return parser.parse_args(argv)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def main(argv: Sequence[str] | None = None) -> None:
    """Command-line entry point: parse ``argv`` and start the training run."""
    run_training(parse_args(argv))


# Run the pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
requirements.txt
CHANGED
|
@@ -5,3 +5,4 @@ pandas
|
|
| 5 |
scikit-learn
|
| 6 |
huggingface_hub
|
| 7 |
matplotlib
|
|
|
|
|
|
| 5 |
scikit-learn
|
| 6 |
huggingface_hub
|
| 7 |
matplotlib
|
| 8 |
+
joblib
|