|
|
"""Gradio front-end for Fault_Classification_PMU_Data models. |
|
|
|
|
|
The application loads a CNN-LSTM model (and accompanying scaler/metadata) |
|
|
produced by ``fault_classification_pmu.py`` and exposes a streamlined |
|
|
prediction interface optimised for Hugging Face Spaces deployment. It supports |
|
|
raw PMU time-series CSV uploads as well as manual comma separated feature |
|
|
vectors. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
import shutil |
|
|
|
|
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1") |
|
|
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") |
|
|
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0") |
|
|
|
|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
|
|
|
|
|
import gradio as gr |
|
|
import joblib |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import requests |
|
|
from huggingface_hub import hf_hub_download |
|
|
from tensorflow.keras.models import load_model |
|
|
|
|
|
from fault_classification_pmu import ( |
|
|
DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS, |
|
|
LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES, |
|
|
train_from_dataframe, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS) |
|
|
DEFAULT_SEQUENCE_LENGTH = 32 |
|
|
DEFAULT_STRIDE = 4 |
|
|
|
|
|
LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras") |
|
|
LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl") |
|
|
LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json") |
|
|
|
|
|
MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve() |
|
|
MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
HUB_REPO = os.environ.get("PMU_HUB_REPO", "") |
|
|
HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE) |
|
|
HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE) |
|
|
HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE) |
|
|
|
|
|
ENV_MODEL_PATH = "PMU_MODEL_PATH" |
|
|
ENV_SCALER_PATH = "PMU_SCALER_PATH" |
|
|
ENV_METADATA_PATH = "PMU_METADATA_PATH" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def download_from_hub(filename: str) -> Optional[Path]: |
|
|
if not HUB_REPO or not filename: |
|
|
return None |
|
|
try: |
|
|
print(f"Downloading {filename} from {HUB_REPO} ...") |
|
|
|
|
|
path = hf_hub_download(repo_id=HUB_REPO, filename=filename) |
|
|
print("Downloaded", path) |
|
|
return Path(path) |
|
|
except Exception as exc: |
|
|
print("Failed to download", filename, "from", HUB_REPO, ":", exc) |
|
|
print("Continuing without pre-trained model...") |
|
|
return None |
|
|
|
|
|
|
|
|
def resolve_artifact( |
|
|
local_name: str, env_var: str, hub_filename: str |
|
|
) -> Optional[Path]: |
|
|
print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}") |
|
|
candidates = [Path(local_name)] if local_name else [] |
|
|
if local_name: |
|
|
candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name) |
|
|
env_value = os.environ.get(env_var) |
|
|
if env_value: |
|
|
candidates.append(Path(env_value)) |
|
|
|
|
|
for candidate in candidates: |
|
|
if candidate and candidate.exists(): |
|
|
print(f"Found local artifact: {candidate}") |
|
|
return candidate |
|
|
|
|
|
print(f"No local artifacts found, checking hub...") |
|
|
|
|
|
if HUB_REPO: |
|
|
return download_from_hub(hub_filename) |
|
|
else: |
|
|
print("No HUB_REPO configured, skipping download") |
|
|
return None |
|
|
|
|
|
|
|
|
def load_metadata(path: Optional[Path]) -> Dict: |
|
|
if path and path.exists(): |
|
|
try: |
|
|
return json.loads(path.read_text()) |
|
|
except Exception as exc: |
|
|
print("Failed to read metadata", path, exc) |
|
|
return {} |
|
|
|
|
|
|
|
|
def try_load_scaler(path: Optional[Path]): |
|
|
if not path: |
|
|
return None |
|
|
try: |
|
|
scaler = joblib.load(path) |
|
|
print("Loaded scaler from", path) |
|
|
return scaler |
|
|
except Exception as exc: |
|
|
print("Failed to load scaler", path, exc) |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
print("Starting application initialization...") |
|
|
try: |
|
|
MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME) |
|
|
print(f"Model path resolved: {MODEL_PATH}") |
|
|
except Exception as e: |
|
|
print(f"Model path resolution failed: {e}") |
|
|
MODEL_PATH = None |
|
|
|
|
|
try: |
|
|
SCALER_PATH = resolve_artifact( |
|
|
LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME |
|
|
) |
|
|
print(f"Scaler path resolved: {SCALER_PATH}") |
|
|
except Exception as e: |
|
|
print(f"Scaler path resolution failed: {e}") |
|
|
SCALER_PATH = None |
|
|
|
|
|
try: |
|
|
METADATA_PATH = resolve_artifact( |
|
|
LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME |
|
|
) |
|
|
print(f"Metadata path resolved: {METADATA_PATH}") |
|
|
except Exception as e: |
|
|
print(f"Metadata path resolution failed: {e}") |
|
|
METADATA_PATH = None |
|
|
|
|
|
try: |
|
|
METADATA = load_metadata(METADATA_PATH) |
|
|
print(f"Metadata loaded: {len(METADATA)} entries") |
|
|
except Exception as e: |
|
|
print(f"Metadata loading failed: {e}") |
|
|
METADATA = {} |
|
|
|
|
|
|
|
|
QUEUE_MAX_SIZE = 32 |
|
|
|
|
|
|
|
|
EVENT_CONCURRENCY_LIMIT = 2 |
|
|
|
|
|
|
|
|
def try_load_model(path: Optional[Path], model_type: str, model_format: str): |
|
|
if not path: |
|
|
return None |
|
|
try: |
|
|
if model_type == "svm" or model_format == "joblib": |
|
|
model = joblib.load(path) |
|
|
else: |
|
|
model = load_model(path) |
|
|
print("Loaded model from", path) |
|
|
return model |
|
|
except Exception as exc: |
|
|
print("Failed to load model", path, exc) |
|
|
return None |
|
|
|
|
|
|
|
|
FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS) |
|
|
LABEL_CLASSES: List[str] = [] |
|
|
LABEL_COLUMN: str = "Fault" |
|
|
SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH |
|
|
DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE |
|
|
MODEL_TYPE: str = "cnn_lstm" |
|
|
MODEL_FORMAT: str = "keras" |
|
|
|
|
|
|
|
|
def _model_output_path(filename: str) -> str: |
|
|
return str(MODEL_OUTPUT_DIR / Path(filename).name) |
|
|
|
|
|
|
|
|
MODEL_FILENAME_BY_TYPE: Dict[str, str] = { |
|
|
"cnn_lstm": Path(LOCAL_MODEL_FILE).name, |
|
|
"tcn": "pmu_tcn_model.keras", |
|
|
"svm": "pmu_svm_model.joblib", |
|
|
} |
|
|
|
|
|
REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS) |
|
|
TRAINING_UPLOAD_DIR = Path( |
|
|
os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads") |
|
|
) |
|
|
TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
TRAINING_DATA_REPO = os.environ.get( |
|
|
"PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData" |
|
|
) |
|
|
TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main") |
|
|
TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset")) |
|
|
TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {} |
|
|
|
|
|
|
|
|
APP_CSS = """ |
|
|
#available-files-section { |
|
|
position: relative; |
|
|
display: flex; |
|
|
flex-direction: column; |
|
|
gap: 0.75rem; |
|
|
border-radius: 0.75rem; |
|
|
isolation: isolate; |
|
|
} |
|
|
|
|
|
#available-files-grid { |
|
|
position: relative; |
|
|
overflow: visible; |
|
|
} |
|
|
|
|
|
#available-files-grid .form { |
|
|
position: static; |
|
|
min-height: 16rem; |
|
|
} |
|
|
|
|
|
#available-files-grid .wrap { |
|
|
display: grid; |
|
|
grid-template-columns: repeat(4, minmax(0, 1fr)); |
|
|
gap: 0.5rem; |
|
|
max-height: 24rem; |
|
|
min-height: 16rem; |
|
|
overflow-y: auto; |
|
|
padding-right: 0.25rem; |
|
|
} |
|
|
|
|
|
#available-files-grid .wrap > div { |
|
|
min-width: 0; |
|
|
} |
|
|
|
|
|
#available-files-grid .wrap label { |
|
|
margin: 0; |
|
|
display: flex; |
|
|
align-items: center; |
|
|
padding: 0.45rem 0.65rem; |
|
|
border-radius: 0.65rem; |
|
|
background-color: rgba(255, 255, 255, 0.05); |
|
|
border: 1px solid rgba(255, 255, 255, 0.08); |
|
|
transition: background-color 0.2s ease, border-color 0.2s ease; |
|
|
min-height: 2.5rem; |
|
|
} |
|
|
|
|
|
#available-files-grid .wrap label:hover { |
|
|
background-color: rgba(90, 200, 250, 0.16); |
|
|
border-color: rgba(90, 200, 250, 0.4); |
|
|
} |
|
|
|
|
|
#available-files-grid .wrap label span { |
|
|
overflow: hidden; |
|
|
text-overflow: ellipsis; |
|
|
white-space: nowrap; |
|
|
} |
|
|
|
|
|
#available-files-section .gradio-loading, |
|
|
#available-files-grid .gradio-loading { |
|
|
position: absolute; |
|
|
top: 0; |
|
|
left: 0; |
|
|
right: 0; |
|
|
bottom: 0; |
|
|
width: 100%; |
|
|
height: 100%; |
|
|
display: flex; |
|
|
align-items: center; |
|
|
justify-content: center; |
|
|
background: rgba(10, 14, 23, 0.92); |
|
|
border-radius: 0.75rem; |
|
|
z-index: 999; |
|
|
padding: 1.5rem; |
|
|
pointer-events: auto; |
|
|
} |
|
|
|
|
|
#available-files-section .gradio-loading { |
|
|
position: absolute; |
|
|
inset: 0; |
|
|
width: 100%; |
|
|
height: 100%; |
|
|
display: flex; |
|
|
align-items: center; |
|
|
justify-content: center; |
|
|
background: rgba(10, 14, 23, 0.92); |
|
|
border-radius: 0.75rem; |
|
|
z-index: 999; |
|
|
padding: 1.5rem; |
|
|
pointer-events: auto; |
|
|
} |
|
|
|
|
|
#available-files-section .gradio-loading > * { |
|
|
width: 100%; |
|
|
} |
|
|
|
|
|
#available-files-section .gradio-loading progress, |
|
|
#available-files-section .gradio-loading .progress-bar, |
|
|
#available-files-section .gradio-loading .loading-progress, |
|
|
#available-files-section .gradio-loading [role="progressbar"], |
|
|
#available-files-section .gradio-loading .wrap, |
|
|
#available-files-section .gradio-loading .inner { |
|
|
width: 100% !important; |
|
|
max-width: none !important; |
|
|
} |
|
|
|
|
|
#available-files-section .gradio-loading .status, |
|
|
#available-files-section .gradio-loading .message, |
|
|
#available-files-section .gradio-loading .label { |
|
|
text-align: center; |
|
|
} |
|
|
|
|
|
#date-browser-row { |
|
|
gap: 0.75rem; |
|
|
} |
|
|
|
|
|
#date-browser-row .date-browser-column { |
|
|
flex: 1 1 0%; |
|
|
min-width: 0; |
|
|
} |
|
|
|
|
|
#date-browser-row .date-browser-column > .gradio-dropdown, |
|
|
#date-browser-row .date-browser-column > .gradio-button { |
|
|
width: 100%; |
|
|
} |
|
|
|
|
|
#date-browser-row .date-browser-column > .gradio-dropdown > div { |
|
|
width: 100%; |
|
|
} |
|
|
|
|
|
#date-browser-row .date-browser-column .gradio-button { |
|
|
justify-content: center; |
|
|
} |
|
|
|
|
|
#training-files-summary textarea { |
|
|
max-height: 12rem; |
|
|
overflow-y: auto; |
|
|
} |
|
|
|
|
|
#download-selected-button { |
|
|
width: 100%; |
|
|
position: relative; |
|
|
z-index: 0; |
|
|
} |
|
|
|
|
|
#download-selected-button .gradio-button { |
|
|
width: 100%; |
|
|
justify-content: center; |
|
|
} |
|
|
|
|
|
#artifact-download-row { |
|
|
gap: 0.75rem; |
|
|
} |
|
|
|
|
|
#artifact-download-row .artifact-download-button { |
|
|
flex: 1 1 0%; |
|
|
min-width: 0; |
|
|
} |
|
|
|
|
|
#artifact-download-row .artifact-download-button .gradio-button { |
|
|
width: 100%; |
|
|
justify-content: center; |
|
|
} |
|
|
""" |
|
|
|
|
|
|
|
|
def _github_cache_key(path: str) -> str: |
|
|
return path or "__root__" |
|
|
|
|
|
|
|
|
def _github_api_url(path: str) -> str: |
|
|
clean_path = path.strip("/") |
|
|
base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents" |
|
|
if clean_path: |
|
|
return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}" |
|
|
return f"{base}?ref={TRAINING_DATA_BRANCH}" |
|
|
|
|
|
|
|
|
def list_remote_directory( |
|
|
path: str = "", *, force_refresh: bool = False |
|
|
) -> List[Dict[str, Any]]: |
|
|
key = _github_cache_key(path) |
|
|
if not force_refresh and key in GITHUB_CONTENT_CACHE: |
|
|
return GITHUB_CONTENT_CACHE[key] |
|
|
|
|
|
url = _github_api_url(path) |
|
|
response = requests.get(url, timeout=30) |
|
|
if response.status_code != 200: |
|
|
raise RuntimeError( |
|
|
f"GitHub API request failed for `{path or '.'}` (status {response.status_code})." |
|
|
) |
|
|
|
|
|
payload = response.json() |
|
|
if not isinstance(payload, list): |
|
|
raise RuntimeError( |
|
|
"Unexpected GitHub API payload. Expected a directory listing." |
|
|
) |
|
|
|
|
|
GITHUB_CONTENT_CACHE[key] = payload |
|
|
return payload |
|
|
|
|
|
|
|
|
def list_remote_years(force_refresh: bool = False) -> List[str]: |
|
|
entries = list_remote_directory("", force_refresh=force_refresh) |
|
|
years = [item["name"] for item in entries if item.get("type") == "dir"] |
|
|
return sorted(years) |
|
|
|
|
|
|
|
|
def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]: |
|
|
if not year: |
|
|
return [] |
|
|
entries = list_remote_directory(year, force_refresh=force_refresh) |
|
|
months = [item["name"] for item in entries if item.get("type") == "dir"] |
|
|
return sorted(months) |
|
|
|
|
|
|
|
|
def list_remote_days( |
|
|
year: str, month: str, *, force_refresh: bool = False |
|
|
) -> List[str]: |
|
|
if not year or not month: |
|
|
return [] |
|
|
entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh) |
|
|
days = [item["name"] for item in entries if item.get("type") == "dir"] |
|
|
return sorted(days) |
|
|
|
|
|
|
|
|
def list_remote_files( |
|
|
year: str, month: str, day: str, *, force_refresh: bool = False |
|
|
) -> List[str]: |
|
|
if not year or not month or not day: |
|
|
return [] |
|
|
entries = list_remote_directory( |
|
|
f"{year}/{month}/{day}", force_refresh=force_refresh |
|
|
) |
|
|
files = [item["name"] for item in entries if item.get("type") == "file"] |
|
|
return sorted(files) |
|
|
|
|
|
|
|
|
def download_repository_file(year: str, month: str, day: str, filename: str) -> Path: |
|
|
if not filename: |
|
|
raise ValueError("Filename cannot be empty when downloading repository data.") |
|
|
|
|
|
relative_parts = [part for part in (year, month, day, filename) if part] |
|
|
if len(relative_parts) < 4: |
|
|
raise ValueError("Provide year, month, day, and filename to download a CSV.") |
|
|
|
|
|
relative_path = "/".join(relative_parts) |
|
|
raw_url = ( |
|
|
f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/" |
|
|
f"{TRAINING_DATA_BRANCH}/{relative_path}" |
|
|
) |
|
|
|
|
|
response = requests.get(raw_url, stream=True, timeout=120) |
|
|
if response.status_code != 200: |
|
|
raise RuntimeError( |
|
|
f"Failed to download `{relative_path}` (status {response.status_code})." |
|
|
) |
|
|
|
|
|
target_dir = TRAINING_DATA_DIR.joinpath(year, month, day) |
|
|
target_dir.mkdir(parents=True, exist_ok=True) |
|
|
target_path = target_dir / filename |
|
|
|
|
|
with open(target_path, "wb") as handle: |
|
|
for chunk in response.iter_content(chunk_size=1 << 20): |
|
|
if chunk: |
|
|
handle.write(chunk) |
|
|
|
|
|
return target_path |
|
|
|
|
|
|
|
|
def _normalise_header(name: str) -> str: |
|
|
return str(name).strip().lower() |
|
|
|
|
|
|
|
|
def guess_label_from_columns( |
|
|
columns: Sequence[str], preferred: Optional[str] = None |
|
|
) -> Optional[str]: |
|
|
if not columns: |
|
|
return preferred |
|
|
|
|
|
lookup = {_normalise_header(col): str(col) for col in columns} |
|
|
|
|
|
if preferred: |
|
|
preferred_stripped = preferred.strip() |
|
|
for col in columns: |
|
|
if str(col).strip() == preferred_stripped: |
|
|
return str(col) |
|
|
preferred_norm = _normalise_header(preferred) |
|
|
if preferred_norm in lookup: |
|
|
return lookup[preferred_norm] |
|
|
|
|
|
for guess in TRAINING_LABEL_GUESSES: |
|
|
guess_norm = _normalise_header(guess) |
|
|
if guess_norm in lookup: |
|
|
return lookup[guess_norm] |
|
|
|
|
|
for col in columns: |
|
|
if _normalise_header(col).startswith("fault"): |
|
|
return str(col) |
|
|
|
|
|
return str(columns[0]) |
|
|
|
|
|
|
|
|
def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str: |
|
|
lines = [Path(path).name for path in paths] |
|
|
lines.extend(notes) |
|
|
return "\n".join(lines) if lines else "No training files available." |
|
|
|
|
|
|
|
|
def read_training_status(status_file_path: str) -> str: |
|
|
"""Read the current training status from file.""" |
|
|
try: |
|
|
if Path(status_file_path).exists(): |
|
|
with open(status_file_path, "r") as f: |
|
|
return f.read().strip() |
|
|
except Exception: |
|
|
pass |
|
|
return "Training status unavailable" |
|
|
|
|
|
|
|
|
def _persist_uploaded_file(file_obj) -> Optional[Path]: |
|
|
if file_obj is None: |
|
|
return None |
|
|
|
|
|
if isinstance(file_obj, (str, Path)): |
|
|
source = Path(file_obj) |
|
|
original_name = source.name |
|
|
else: |
|
|
source = Path(getattr(file_obj, "name", "") or getattr(file_obj, "path", "")) |
|
|
original_name = getattr(file_obj, "orig_name", source.name) or source.name |
|
|
if not source or not source.exists(): |
|
|
return None |
|
|
|
|
|
original_name = Path(original_name).name or source.name |
|
|
|
|
|
base_path = Path(original_name) |
|
|
destination = TRAINING_UPLOAD_DIR / base_path.name |
|
|
counter = 1 |
|
|
while destination.exists(): |
|
|
suffix = base_path.suffix or ".csv" |
|
|
destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}" |
|
|
counter += 1 |
|
|
|
|
|
shutil.copy2(source, destination) |
|
|
return destination |
|
|
|
|
|
|
|
|
def prepare_training_paths( |
|
|
paths: Sequence[str], current_label: str, cleanup_missing: bool = False |
|
|
): |
|
|
valid_paths: List[str] = [] |
|
|
notes: List[str] = [] |
|
|
columns_map: Dict[str, str] = {} |
|
|
for path in paths: |
|
|
try: |
|
|
df = load_measurement_csv(path) |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ Skipped {Path(path).name}: {exc}") |
|
|
if cleanup_missing: |
|
|
try: |
|
|
Path(path).unlink(missing_ok=True) |
|
|
except Exception: |
|
|
pass |
|
|
continue |
|
|
valid_paths.append(str(path)) |
|
|
for col in df.columns: |
|
|
columns_map[_normalise_header(col)] = str(col) |
|
|
|
|
|
summary = summarise_training_files(valid_paths, notes) |
|
|
preferred = current_label or LABEL_COLUMN |
|
|
dropdown_choices = ( |
|
|
sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN] |
|
|
) |
|
|
guessed = guess_label_from_columns(dropdown_choices, preferred) |
|
|
dropdown_value = guessed or preferred or LABEL_COLUMN |
|
|
|
|
|
return ( |
|
|
valid_paths, |
|
|
summary, |
|
|
gr.update(choices=dropdown_choices, value=dropdown_value), |
|
|
) |
|
|
|
|
|
|
|
|
def append_training_files(new_files, existing_paths: Sequence[str], current_label: str): |
|
|
if isinstance(existing_paths, (str, Path)): |
|
|
paths: List[str] = [str(existing_paths)] |
|
|
elif existing_paths is None: |
|
|
paths = [] |
|
|
else: |
|
|
paths = list(existing_paths) |
|
|
if new_files: |
|
|
for file in new_files: |
|
|
persisted = _persist_uploaded_file(file) |
|
|
if persisted is None: |
|
|
continue |
|
|
path_str = str(persisted) |
|
|
if path_str not in paths: |
|
|
paths.append(path_str) |
|
|
|
|
|
return prepare_training_paths(paths, current_label, cleanup_missing=True) |
|
|
|
|
|
|
|
|
def load_repository_training_files(current_label: str, force_refresh: bool = False): |
|
|
if force_refresh: |
|
|
|
|
|
for cached in list(TRAINING_DATA_DIR.glob("*")): |
|
|
|
|
|
|
|
|
break |
|
|
|
|
|
csv_paths = sorted( |
|
|
str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file() |
|
|
) |
|
|
if not csv_paths: |
|
|
message = ( |
|
|
"No local database CSVs are available yet. Use the database browser " |
|
|
"below to download specific days before training." |
|
|
) |
|
|
default_label = current_label or LABEL_COLUMN or "Fault" |
|
|
return ( |
|
|
[], |
|
|
message, |
|
|
gr.update(choices=[default_label], value=default_label), |
|
|
message, |
|
|
) |
|
|
|
|
|
valid_paths, summary, label_update = prepare_training_paths( |
|
|
csv_paths, current_label, cleanup_missing=False |
|
|
) |
|
|
|
|
|
info = ( |
|
|
f"Ready with {len(valid_paths)} CSV file(s) cached locally under " |
|
|
f"the database cache `{TRAINING_DATA_DIR}`." |
|
|
) |
|
|
|
|
|
return valid_paths, summary, label_update, info |
|
|
|
|
|
|
|
|
def refresh_remote_browser(force_refresh: bool = False): |
|
|
if force_refresh: |
|
|
GITHUB_CONTENT_CACHE.clear() |
|
|
try: |
|
|
years = list_remote_years(force_refresh=force_refresh) |
|
|
if years: |
|
|
message = "Select a year, month, and day to list available CSV files." |
|
|
else: |
|
|
message = ( |
|
|
"⚠️ No directories were found in the database root. Verify the upstream " |
|
|
"structure." |
|
|
) |
|
|
return ( |
|
|
gr.update(choices=years, value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
message, |
|
|
) |
|
|
except Exception as exc: |
|
|
return ( |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
f"⚠️ Failed to query database: {exc}", |
|
|
) |
|
|
|
|
|
|
|
|
def on_year_change(year: Optional[str]): |
|
|
if not year: |
|
|
return ( |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
"Select a year to continue.", |
|
|
) |
|
|
try: |
|
|
months = list_remote_months(year) |
|
|
message = ( |
|
|
f"Year `{year}` selected. Choose a month to drill down." |
|
|
if months |
|
|
else f"⚠️ No months available under `{year}`." |
|
|
) |
|
|
return ( |
|
|
gr.update(choices=months, value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
message, |
|
|
) |
|
|
except Exception as exc: |
|
|
return ( |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
f"⚠️ Failed to list months: {exc}", |
|
|
) |
|
|
|
|
|
|
|
|
def on_month_change(year: Optional[str], month: Optional[str]): |
|
|
if not year or not month: |
|
|
return ( |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
"Select a month to continue.", |
|
|
) |
|
|
try: |
|
|
days = list_remote_days(year, month) |
|
|
message = ( |
|
|
f"Month `{year}/{month}` ready. Pick a day to view files." |
|
|
if days |
|
|
else f"⚠️ No day folders found under `{year}/{month}`." |
|
|
) |
|
|
return ( |
|
|
gr.update(choices=days, value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
message, |
|
|
) |
|
|
except Exception as exc: |
|
|
return ( |
|
|
gr.update(choices=[], value=None), |
|
|
gr.update(choices=[], value=[]), |
|
|
f"⚠️ Failed to list days: {exc}", |
|
|
) |
|
|
|
|
|
|
|
|
def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]): |
|
|
if not year or not month or not day: |
|
|
return ( |
|
|
gr.update(choices=[], value=[]), |
|
|
"Select a day to load file names.", |
|
|
) |
|
|
try: |
|
|
files = list_remote_files(year, month, day) |
|
|
message = ( |
|
|
f"{len(files)} file(s) available for `{year}/{month}/{day}`." |
|
|
if files |
|
|
else f"⚠️ No CSV files found under `{year}/{month}/{day}`." |
|
|
) |
|
|
return ( |
|
|
gr.update(choices=files, value=[]), |
|
|
message, |
|
|
) |
|
|
except Exception as exc: |
|
|
return ( |
|
|
gr.update(choices=[], value=[]), |
|
|
f"⚠️ Failed to list files: {exc}", |
|
|
) |
|
|
|
|
|
|
|
|
def download_selected_files( |
|
|
year: Optional[str], |
|
|
month: Optional[str], |
|
|
day: Optional[str], |
|
|
filenames: Sequence[str], |
|
|
current_label: str, |
|
|
): |
|
|
if not filenames: |
|
|
message = "Select at least one CSV before downloading." |
|
|
local = load_repository_training_files(current_label) |
|
|
return (*local, gr.update(), message) |
|
|
|
|
|
success: List[str] = [] |
|
|
notes: List[str] = [] |
|
|
for filename in filenames: |
|
|
try: |
|
|
path = download_repository_file( |
|
|
year or "", month or "", day or "", filename |
|
|
) |
|
|
success.append(str(path)) |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ {filename}: {exc}") |
|
|
|
|
|
local = load_repository_training_files(current_label) |
|
|
|
|
|
message_lines = [] |
|
|
if success: |
|
|
message_lines.append( |
|
|
f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`." |
|
|
) |
|
|
if notes: |
|
|
message_lines.extend(notes) |
|
|
if not message_lines: |
|
|
message_lines.append("No files were downloaded.") |
|
|
|
|
|
return (*local, gr.update(value=[]), "\n".join(message_lines)) |
|
|
|
|
|
|
|
|
def download_day_bundle( |
|
|
year: Optional[str], |
|
|
month: Optional[str], |
|
|
day: Optional[str], |
|
|
current_label: str, |
|
|
): |
|
|
if not (year and month and day): |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
"Select a year, month, and day before downloading an entire day.", |
|
|
) |
|
|
|
|
|
try: |
|
|
files = list_remote_files(year, month, day) |
|
|
except Exception as exc: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}", |
|
|
) |
|
|
|
|
|
if not files: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
f"No CSV files were found for `{year}/{month}/{day}`.", |
|
|
) |
|
|
|
|
|
result = list(download_selected_files(year, month, day, files, current_label)) |
|
|
result[-1] = ( |
|
|
f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n" |
|
|
f"{result[-1]}" |
|
|
) |
|
|
return tuple(result) |
|
|
|
|
|
|
|
|
def download_month_bundle( |
|
|
year: Optional[str], month: Optional[str], current_label: str |
|
|
): |
|
|
if not (year and month): |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
"Select a year and month before downloading an entire month.", |
|
|
) |
|
|
|
|
|
try: |
|
|
days = list_remote_days(year, month) |
|
|
except Exception as exc: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}", |
|
|
) |
|
|
|
|
|
if not days: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
f"No day folders were found for `{year}/{month}`.", |
|
|
) |
|
|
|
|
|
downloaded = 0 |
|
|
notes: List[str] = [] |
|
|
for day in days: |
|
|
try: |
|
|
files = list_remote_files(year, month, day) |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}") |
|
|
continue |
|
|
if not files: |
|
|
notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.") |
|
|
continue |
|
|
for filename in files: |
|
|
try: |
|
|
download_repository_file(year, month, day, filename) |
|
|
downloaded += 1 |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}") |
|
|
|
|
|
local = load_repository_training_files(current_label) |
|
|
message_lines = [] |
|
|
if downloaded: |
|
|
message_lines.append( |
|
|
f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the " |
|
|
f"database cache `{TRAINING_DATA_DIR}`." |
|
|
) |
|
|
message_lines.extend(notes) |
|
|
if not message_lines: |
|
|
message_lines.append("No files were downloaded.") |
|
|
|
|
|
return (*local, gr.update(value=[]), "\n".join(message_lines)) |
|
|
|
|
|
|
|
|
def download_year_bundle(year: Optional[str], current_label: str): |
|
|
if not year: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
"Select a year before downloading an entire year of CSVs.", |
|
|
) |
|
|
|
|
|
try: |
|
|
months = list_remote_months(year) |
|
|
except Exception as exc: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
f"⚠️ Failed to enumerate months for `{year}`: {exc}", |
|
|
) |
|
|
|
|
|
if not months: |
|
|
local = load_repository_training_files(current_label) |
|
|
return ( |
|
|
*local, |
|
|
gr.update(), |
|
|
f"No month folders were found for `{year}`.", |
|
|
) |
|
|
|
|
|
downloaded = 0 |
|
|
notes: List[str] = [] |
|
|
for month in months: |
|
|
try: |
|
|
days = list_remote_days(year, month) |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}") |
|
|
continue |
|
|
if not days: |
|
|
notes.append(f"⚠️ No day folders in `{year}/{month}`.") |
|
|
continue |
|
|
for day in days: |
|
|
try: |
|
|
files = list_remote_files(year, month, day) |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}") |
|
|
continue |
|
|
if not files: |
|
|
notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.") |
|
|
continue |
|
|
for filename in files: |
|
|
try: |
|
|
download_repository_file(year, month, day, filename) |
|
|
downloaded += 1 |
|
|
except Exception as exc: |
|
|
notes.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}") |
|
|
|
|
|
local = load_repository_training_files(current_label) |
|
|
message_lines = [] |
|
|
if downloaded: |
|
|
message_lines.append( |
|
|
f"Downloaded {downloaded} CSV file(s) for `{year}` into the " |
|
|
f"database cache `{TRAINING_DATA_DIR}`." |
|
|
) |
|
|
message_lines.extend(notes) |
|
|
if not message_lines: |
|
|
message_lines.append("No files were downloaded.") |
|
|
|
|
|
return (*local, gr.update(value=[]), "\n".join(message_lines)) |
|
|
|
|
|
|
|
|
def clear_downloaded_cache(current_label: str): |
|
|
status_message = "" |
|
|
try: |
|
|
if TRAINING_DATA_DIR.exists(): |
|
|
shutil.rmtree(TRAINING_DATA_DIR) |
|
|
TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True) |
|
|
status_message = ( |
|
|
f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`." |
|
|
) |
|
|
except Exception as exc: |
|
|
status_message = f"⚠️ Failed to clear database cache: {exc}" |
|
|
|
|
|
local = load_repository_training_files(current_label, force_refresh=True) |
|
|
remote = list(refresh_remote_browser(force_refresh=False)) |
|
|
if status_message: |
|
|
previous = remote[-1] |
|
|
if isinstance(previous, str) and previous: |
|
|
remote[-1] = f"{status_message}\n{previous}" |
|
|
else: |
|
|
remote[-1] = status_message |
|
|
|
|
|
return (*local, *remote) |
|
|
|
|
|
|
|
|
def normalise_output_directory(directory: Optional[str]) -> Path: |
|
|
base = Path(directory or MODEL_OUTPUT_DIR) |
|
|
base = base.expanduser() |
|
|
if not base.is_absolute(): |
|
|
base = (Path.cwd() / base).resolve() |
|
|
return base |
|
|
|
|
|
|
|
|
def resolve_output_path( |
|
|
directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str |
|
|
) -> Path: |
|
|
if isinstance(directory, Path): |
|
|
base = directory |
|
|
else: |
|
|
base = normalise_output_directory(directory) |
|
|
candidate = Path(filename or "").expanduser() |
|
|
if str(candidate): |
|
|
if candidate.is_absolute(): |
|
|
return candidate |
|
|
return (base / candidate).resolve() |
|
|
return (base / fallback).resolve() |
|
|
|
|
|
|
|
|
ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = ( |
|
|
".keras", |
|
|
".h5", |
|
|
".joblib", |
|
|
".pkl", |
|
|
".json", |
|
|
".onnx", |
|
|
".zip", |
|
|
".txt", |
|
|
) |
|
|
|
|
|
|
|
|
def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]: |
|
|
base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR)) |
|
|
candidates = {str(base)} |
|
|
try: |
|
|
for candidate in base.parent.iterdir(): |
|
|
if candidate.is_dir(): |
|
|
candidates.add(str(candidate.resolve())) |
|
|
except Exception: |
|
|
pass |
|
|
return sorted(candidates), str(base) |
|
|
|
|
|
|
|
|
def gather_artifact_choices( |
|
|
directory: Optional[str], selection: Optional[str] = None |
|
|
) -> Tuple[List[Tuple[str, str]], Optional[str]]: |
|
|
base = normalise_output_directory(directory) |
|
|
choices: List[Tuple[str, str]] = [] |
|
|
selected_value: Optional[str] = None |
|
|
if base.exists(): |
|
|
try: |
|
|
artifacts = sorted( |
|
|
[ |
|
|
path |
|
|
for path in base.iterdir() |
|
|
if path.is_file() |
|
|
and ( |
|
|
not ARTIFACT_FILE_EXTENSIONS |
|
|
or path.suffix.lower() in ARTIFACT_FILE_EXTENSIONS |
|
|
) |
|
|
], |
|
|
key=lambda path: path.name.lower(), |
|
|
) |
|
|
choices = [(artifact.name, str(artifact)) for artifact in artifacts] |
|
|
except Exception: |
|
|
choices = [] |
|
|
|
|
|
if selection and any(value == selection for _, value in choices): |
|
|
selected_value = selection |
|
|
elif choices: |
|
|
selected_value = choices[0][1] |
|
|
|
|
|
return choices, selected_value |
|
|
|
|
|
|
|
|
def download_button_state(path: Optional[Union[str, Path]]): |
|
|
if not path: |
|
|
return gr.update(value=None, visible=False) |
|
|
candidate = Path(path) |
|
|
if candidate.exists(): |
|
|
return gr.update(value=str(candidate), visible=True) |
|
|
return gr.update(value=None, visible=False) |
|
|
|
|
|
|
|
|
def clear_training_files(): |
|
|
default_label = LABEL_COLUMN or "Fault" |
|
|
for cached_file in TRAINING_UPLOAD_DIR.glob("*"): |
|
|
try: |
|
|
if cached_file.is_file(): |
|
|
cached_file.unlink(missing_ok=True) |
|
|
except Exception: |
|
|
pass |
|
|
return ( |
|
|
[], |
|
|
"No training files selected.", |
|
|
gr.update(choices=[default_label], value=default_label), |
|
|
gr.update(value=None), |
|
|
) |
|
|
|
|
|
|
|
|
PROJECT_OVERVIEW_MD = """ |
|
|
## Project Overview |
|
|
|
|
|
This project focuses on classifying faults in electrical transmission lines and |
|
|
grid-connected photovoltaic (PV) systems by combining ensemble learning |
|
|
techniques with deep neural architectures. |
|
|
|
|
|
## Datasets |
|
|
|
|
|
### Transmission Line Fault Dataset |
|
|
- 134,406 samples collected from Phasor Measurement Units (PMUs) |
|
|
- 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles |
|
|
- Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G |
|
|
- Time span: 0 to 5.7 seconds with high-frequency sampling |
|
|
|
|
|
### Grid-Connected PV System Fault Dataset |
|
|
- 2,163,480 samples from 16 experimental scenarios |
|
|
- 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf) |
|
|
- Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals |
|
|
|
|
|
## Data Format Quick Reference |
|
|
|
|
|
Each measurement file may be comma or tab separated and typically exposes the |
|
|
following ordered columns: |
|
|
|
|
|
1. `Timestamp` |
|
|
2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz) |
|
|
3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change |
|
|
4. `[327] UPMU_SUB22:FLAG` – PMU status flag |
|
|
5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude |
|
|
6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle |
|
|
7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude |
|
|
8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle |
|
|
9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude |
|
|
10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle |
|
|
11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude |
|
|
12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle |
|
|
13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude |
|
|
14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle |
|
|
15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude |
|
|
16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle |
|
|
|
|
|
The training tab automatically downloads the latest CSV exports from the |
|
|
`VincentCroft/ThesisModelData` repository and concatenates them before building |
|
|
sliding windows. |
|
|
|
|
|
## Models Developed |
|
|
|
|
|
1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV). |
|
|
2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy. |
|
|
3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV). |
|
|
|
|
|
## Results Summary |
|
|
|
|
|
- **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94% |
|
|
- **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91% |
|
|
|
|
|
Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to |
|
|
fine-tune or retrain any of the supported models directly within Hugging Face |
|
|
Spaces. The logs panel will surface TensorBoard archives whenever deep-learning |
|
|
models are trained. |
|
|
""" |
|
|
|
|
|
|
|
|
def load_measurement_csv(path: str) -> pd.DataFrame: |
|
|
"""Read a PMU/PV measurement file with flexible separators and column mapping.""" |
|
|
|
|
|
try: |
|
|
df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig") |
|
|
except Exception: |
|
|
df = None |
|
|
for separator in ("\t", ",", ";"): |
|
|
try: |
|
|
df = pd.read_csv( |
|
|
path, sep=separator, engine="python", encoding="utf-8-sig" |
|
|
) |
|
|
break |
|
|
except Exception: |
|
|
df = None |
|
|
if df is None: |
|
|
raise |
|
|
|
|
|
|
|
|
df.columns = [str(col).strip() for col in df.columns] |
|
|
|
|
|
print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns") |
|
|
print(f"Columns: {list(df.columns)}") |
|
|
print(f"Data shape: {df.shape}") |
|
|
|
|
|
|
|
|
if len(df) < 100: |
|
|
print( |
|
|
f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training." |
|
|
) |
|
|
|
|
|
|
|
|
has_label = any( |
|
|
col.lower() in ["fault", "label", "class", "target"] for col in df.columns |
|
|
) |
|
|
if not has_label: |
|
|
print( |
|
|
"Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples." |
|
|
) |
|
|
df["Fault"] = "Normal" |
|
|
|
|
|
|
|
|
column_mapping = {} |
|
|
expected_cols = list(REQUIRED_PMU_COLUMNS) |
|
|
|
|
|
|
|
|
if "Timestamp" in df.columns: |
|
|
numeric_cols = [col for col in df.columns if col != "Timestamp"] |
|
|
if len(numeric_cols) >= len(expected_cols): |
|
|
|
|
|
for i, expected_col in enumerate(expected_cols): |
|
|
if i < len(numeric_cols): |
|
|
column_mapping[numeric_cols[i]] = expected_col |
|
|
|
|
|
|
|
|
df = df.rename(columns=column_mapping) |
|
|
|
|
|
|
|
|
missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns] |
|
|
if missing: |
|
|
|
|
|
available_numeric = df.select_dtypes(include=[np.number]).columns.tolist() |
|
|
if len(available_numeric) >= len(expected_cols): |
|
|
|
|
|
for i, expected_col in enumerate(expected_cols): |
|
|
if i < len(available_numeric): |
|
|
if available_numeric[i] not in df.columns: |
|
|
continue |
|
|
df = df.rename(columns={available_numeric[i]: expected_col}) |
|
|
|
|
|
|
|
|
missing = [col for col in REQUIRED_PMU_COLUMNS if col not in df.columns] |
|
|
|
|
|
if missing: |
|
|
missing_str = ", ".join(missing) |
|
|
available_str = ", ".join(df.columns.tolist()) |
|
|
raise ValueError( |
|
|
f"Missing required PMU feature columns: {missing_str}. " |
|
|
f"Available columns: {available_str}. " |
|
|
"Please ensure your CSV has the correct format with Timestamp followed by PMU measurements." |
|
|
) |
|
|
|
|
|
return df |
|
|
|
|
|
|
|
|
def apply_metadata(metadata: Dict[str, Any]) -> None: |
|
|
global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT |
|
|
FEATURE_COLUMNS = [ |
|
|
str(col) for col in metadata.get("feature_columns", DEFAULT_FEATURE_COLUMNS) |
|
|
] |
|
|
LABEL_CLASSES = [str(label) for label in metadata.get("label_classes", [])] |
|
|
LABEL_COLUMN = str(metadata.get("label_column", "Fault")) |
|
|
SEQUENCE_LENGTH = int(metadata.get("sequence_length", DEFAULT_SEQUENCE_LENGTH)) |
|
|
DEFAULT_WINDOW_STRIDE = int(metadata.get("stride", DEFAULT_STRIDE)) |
|
|
MODEL_TYPE = str(metadata.get("model_type", "cnn_lstm")).lower() |
|
|
MODEL_FORMAT = str( |
|
|
metadata.get("model_format", "joblib" if MODEL_TYPE == "svm" else "keras") |
|
|
).lower() |
|
|
|
|
|
|
|
|
apply_metadata(METADATA) |
|
|
|
|
|
|
|
|
def sync_label_classes_from_model(model: Optional[object]) -> None: |
|
|
global LABEL_CLASSES |
|
|
if model is None: |
|
|
return |
|
|
if hasattr(model, "classes_"): |
|
|
LABEL_CLASSES = [str(label) for label in getattr(model, "classes_")] |
|
|
elif not LABEL_CLASSES and hasattr(model, "output_shape"): |
|
|
LABEL_CLASSES = [str(i) for i in range(int(model.output_shape[-1]))] |
|
|
|
|
|
|
|
|
|
|
|
print("Loading model and scaler...") |
|
|
try: |
|
|
MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT) |
|
|
print(f"Model loaded: {MODEL is not None}") |
|
|
except Exception as e: |
|
|
print(f"Model loading failed: {e}") |
|
|
MODEL = None |
|
|
|
|
|
try: |
|
|
SCALER = try_load_scaler(SCALER_PATH) |
|
|
print(f"Scaler loaded: {SCALER is not None}") |
|
|
except Exception as e: |
|
|
print(f"Scaler loading failed: {e}") |
|
|
SCALER = None |
|
|
|
|
|
try: |
|
|
sync_label_classes_from_model(MODEL) |
|
|
print("Label classes synchronized") |
|
|
except Exception as e: |
|
|
print(f"Label sync failed: {e}") |
|
|
|
|
|
print("Application initialization completed.") |
|
|
print( |
|
|
f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}" |
|
|
) |
|
|
|
|
|
|
|
|
def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None: |
|
|
global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA |
|
|
MODEL_PATH = model_path |
|
|
SCALER_PATH = scaler_path |
|
|
METADATA_PATH = metadata_path |
|
|
METADATA = load_metadata(metadata_path) |
|
|
apply_metadata(METADATA) |
|
|
MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT) |
|
|
SCALER = try_load_scaler(scaler_path) |
|
|
sync_label_classes_from_model(MODEL) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_ready(): |
|
|
if MODEL is None or SCALER is None: |
|
|
raise RuntimeError( |
|
|
"The model and feature scaler are not available. Upload the trained model " |
|
|
"(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), " |
|
|
"the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root " |
|
|
"or configure the Hugging Face Hub environment variables so the artifacts can be downloaded " |
|
|
"automatically." |
|
|
) |
|
|
|
|
|
|
|
|
def parse_text_features(text: str) -> np.ndarray: |
|
|
cleaned = re.sub(r"[;\n\t]+", ",", text.strip()) |
|
|
arr = np.fromstring(cleaned, sep=",") |
|
|
if arr.size == 0: |
|
|
raise ValueError( |
|
|
"No feature values were parsed. Please enter comma-separated numbers." |
|
|
) |
|
|
return arr.astype(np.float32) |
|
|
|
|
|
|
|
|
def apply_scaler(sequences: np.ndarray) -> np.ndarray: |
|
|
if SCALER is None: |
|
|
return sequences |
|
|
shape = sequences.shape |
|
|
flattened = sequences.reshape(-1, shape[-1]) |
|
|
scaled = SCALER.transform(flattened) |
|
|
return scaled.reshape(shape) |
|
|
|
|
|
|
|
|
def make_sliding_windows( |
|
|
data: np.ndarray, sequence_length: int, stride: int |
|
|
) -> np.ndarray: |
|
|
if data.shape[0] < sequence_length: |
|
|
raise ValueError( |
|
|
f"The dataset contains {data.shape[0]} rows which is less than the requested sequence " |
|
|
f"length {sequence_length}. Provide more samples or reduce the sequence length." |
|
|
) |
|
|
windows = [ |
|
|
data[start : start + sequence_length] |
|
|
for start in range(0, data.shape[0] - sequence_length + 1, stride) |
|
|
] |
|
|
return np.stack(windows) |
|
|
|
|
|
|
|
|
def dataframe_to_sequences( |
|
|
df: pd.DataFrame, |
|
|
*, |
|
|
sequence_length: int, |
|
|
stride: int, |
|
|
feature_columns: Sequence[str], |
|
|
drop_label: bool = True, |
|
|
) -> np.ndarray: |
|
|
work_df = df.copy() |
|
|
if drop_label and LABEL_COLUMN in work_df.columns: |
|
|
work_df = work_df.drop(columns=[LABEL_COLUMN]) |
|
|
if "Timestamp" in work_df.columns: |
|
|
work_df = work_df.sort_values("Timestamp") |
|
|
|
|
|
available_cols = [c for c in feature_columns if c in work_df.columns] |
|
|
n_features = len(feature_columns) |
|
|
if available_cols and len(available_cols) == n_features: |
|
|
array = work_df[available_cols].astype(np.float32).to_numpy() |
|
|
return make_sliding_windows(array, sequence_length, stride) |
|
|
|
|
|
numeric_df = work_df.select_dtypes(include=[np.number]) |
|
|
array = numeric_df.astype(np.float32).to_numpy() |
|
|
if array.shape[1] == n_features * sequence_length: |
|
|
return array.reshape(array.shape[0], sequence_length, n_features) |
|
|
if sequence_length == 1 and array.shape[1] == n_features: |
|
|
return array.reshape(array.shape[0], 1, n_features) |
|
|
raise ValueError( |
|
|
"CSV columns do not match the expected feature layout. Include the full PMU feature set " |
|
|
"or provide pre-shaped sliding window data." |
|
|
) |
|
|
|
|
|
|
|
|
def label_name(index: int) -> str: |
|
|
if 0 <= index < len(LABEL_CLASSES): |
|
|
return str(LABEL_CLASSES[index]) |
|
|
return f"class_{index}" |
|
|
|
|
|
|
|
|
def format_predictions(probabilities: np.ndarray) -> pd.DataFrame: |
|
|
rows: List[Dict[str, object]] = [] |
|
|
order = np.argsort(probabilities, axis=1)[:, ::-1] |
|
|
for idx, (prob_row, ranking) in enumerate(zip(probabilities, order)): |
|
|
top_idx = int(ranking[0]) |
|
|
top_label = label_name(top_idx) |
|
|
top_conf = float(prob_row[top_idx]) |
|
|
top3 = [f"{label_name(i)} ({prob_row[i]*100:.2f}%)" for i in ranking[:3]] |
|
|
rows.append( |
|
|
{ |
|
|
"window": idx, |
|
|
"predicted_label": top_label, |
|
|
"confidence": round(top_conf, 4), |
|
|
"top3": " | ".join(top3), |
|
|
} |
|
|
) |
|
|
return pd.DataFrame(rows) |
|
|
|
|
|
|
|
|
def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]: |
|
|
payload: List[Dict[str, object]] = [] |
|
|
for idx, prob_row in enumerate(probabilities): |
|
|
payload.append( |
|
|
{ |
|
|
"window": int(idx), |
|
|
"probabilities": { |
|
|
label_name(i): float(prob_row[i]) for i in range(prob_row.shape[0]) |
|
|
}, |
|
|
} |
|
|
) |
|
|
return payload |
|
|
|
|
|
|
|
|
def predict_sequences( |
|
|
sequences: np.ndarray, |
|
|
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]: |
|
|
ensure_ready() |
|
|
sequences = apply_scaler(sequences.astype(np.float32)) |
|
|
if MODEL_TYPE == "svm": |
|
|
flattened = sequences.reshape(sequences.shape[0], -1) |
|
|
if hasattr(MODEL, "predict_proba"): |
|
|
probs = MODEL.predict_proba(flattened) |
|
|
else: |
|
|
raise RuntimeError( |
|
|
"Loaded SVM model does not expose predict_proba. Retrain with probability=True." |
|
|
) |
|
|
else: |
|
|
probs = MODEL.predict(sequences, verbose=0) |
|
|
table = format_predictions(probs) |
|
|
json_probs = probabilities_to_json(probs) |
|
|
architecture = MODEL_TYPE.replace("_", "-").upper() |
|
|
status = f"Generated {len(sequences)} windows. {architecture} model output dimension: {probs.shape[1]}." |
|
|
return status, table, json_probs |
|
|
|
|
|
|
|
|
def predict_from_text( |
|
|
text: str, sequence_length: int |
|
|
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]: |
|
|
arr = parse_text_features(text) |
|
|
n_features = len(FEATURE_COLUMNS) |
|
|
if arr.size % n_features != 0: |
|
|
raise ValueError( |
|
|
f"The number of values ({arr.size}) is not a multiple of the feature dimension " |
|
|
f"({n_features}). Provide values in groups of {n_features}." |
|
|
) |
|
|
timesteps = arr.size // n_features |
|
|
if timesteps != sequence_length: |
|
|
raise ValueError( |
|
|
f"Detected {timesteps} timesteps which does not match the configured sequence length " |
|
|
f"({sequence_length})." |
|
|
) |
|
|
sequences = arr.reshape(1, sequence_length, n_features) |
|
|
status, table, probs = predict_sequences(sequences) |
|
|
status = f"Single window prediction complete. {status}" |
|
|
return status, table, probs |
|
|
|
|
|
|
|
|
def predict_from_csv( |
|
|
file_obj, sequence_length: int, stride: int |
|
|
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]: |
|
|
df = load_measurement_csv(file_obj.name) |
|
|
sequences = dataframe_to_sequences( |
|
|
df, |
|
|
sequence_length=sequence_length, |
|
|
stride=stride, |
|
|
feature_columns=FEATURE_COLUMNS, |
|
|
) |
|
|
status, table, probs = predict_sequences(sequences) |
|
|
status = f"CSV processed successfully. Generated {len(sequences)} windows. {status}" |
|
|
return status, table, probs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame: |
|
|
rows: List[Dict[str, Any]] = [] |
|
|
for label, metrics in report.items(): |
|
|
if isinstance(metrics, dict): |
|
|
row = {"label": label} |
|
|
for key, value in metrics.items(): |
|
|
if key == "support": |
|
|
row[key] = int(value) |
|
|
else: |
|
|
row[key] = round(float(value), 4) |
|
|
rows.append(row) |
|
|
else: |
|
|
rows.append({"label": label, "accuracy": round(float(metrics), 4)}) |
|
|
return pd.DataFrame(rows) |
|
|
|
|
|
|
|
|
def confusion_matrix_to_dataframe( |
|
|
confusion: Sequence[Sequence[float]], labels: Sequence[str] |
|
|
) -> pd.DataFrame: |
|
|
if not confusion: |
|
|
return pd.DataFrame() |
|
|
df = pd.DataFrame(confusion, index=list(labels), columns=list(labels)) |
|
|
df.index.name = "True Label" |
|
|
df.columns.name = "Predicted Label" |
|
|
return df |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_interface() -> gr.Blocks: |
|
|
theme = gr.themes.Soft( |
|
|
primary_hue="sky", secondary_hue="blue", neutral_hue="gray" |
|
|
).set( |
|
|
body_background_fill="#1f1f1f", |
|
|
body_text_color="#f5f5f5", |
|
|
block_background_fill="#262626", |
|
|
block_border_color="#333333", |
|
|
button_primary_background_fill="#5ac8fa", |
|
|
button_primary_background_fill_hover="#48b5eb", |
|
|
button_primary_border_color="#38bdf8", |
|
|
button_primary_text_color="#0f172a", |
|
|
button_secondary_background_fill="#3f3f46", |
|
|
button_secondary_text_color="#f5f5f5", |
|
|
) |
|
|
|
|
|
def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str: |
|
|
if value is None: |
|
|
return "" |
|
|
path = Path(value).expanduser() |
|
|
try: |
|
|
return str(path.resolve()) |
|
|
except Exception: |
|
|
return str(path) |
|
|
|
|
|
with gr.Blocks( |
|
|
title="Fault Classification - PMU Data", theme=theme, css=APP_CSS |
|
|
) as demo: |
|
|
gr.Markdown("# Fault Classification for PMU & PV Data") |
|
|
gr.Markdown( |
|
|
"🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers." |
|
|
) |
|
|
if MODEL is None or SCALER is None: |
|
|
gr.Markdown( |
|
|
"⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, " |
|
|
"`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, " |
|
|
"or configure the Hugging Face Hub environment variables so they can be downloaded." |
|
|
) |
|
|
else: |
|
|
class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown" |
|
|
gr.Markdown( |
|
|
f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with " |
|
|
f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and " |
|
|
f"{class_count} target classes. Use the tabs below to run inference or fine-tune " |
|
|
"the model with your own CSV files." |
|
|
) |
|
|
|
|
|
with gr.Accordion("Feature Reference", open=False): |
|
|
gr.Markdown( |
|
|
f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n" |
|
|
+ "\n".join(f"- {name}" for name in FEATURE_COLUMNS) |
|
|
) |
|
|
gr.Markdown( |
|
|
f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, " |
|
|
f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed." |
|
|
) |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.Tab("Overview"): |
|
|
gr.Markdown(PROJECT_OVERVIEW_MD) |
|
|
with gr.Tab("Inference"): |
|
|
gr.Markdown("## Run Inference") |
|
|
with gr.Row(): |
|
|
file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"]) |
|
|
text_in = gr.Textbox( |
|
|
lines=4, |
|
|
label="Or paste a single window (comma separated)", |
|
|
placeholder="49.97772,1.215825E-38,...", |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
sequence_length_input = gr.Slider( |
|
|
minimum=1, |
|
|
maximum=max(1, SEQUENCE_LENGTH * 2), |
|
|
step=1, |
|
|
value=SEQUENCE_LENGTH, |
|
|
label="Sequence length (timesteps)", |
|
|
) |
|
|
stride_input = gr.Slider( |
|
|
minimum=1, |
|
|
maximum=max(1, SEQUENCE_LENGTH), |
|
|
step=1, |
|
|
value=max(1, DEFAULT_WINDOW_STRIDE), |
|
|
label="CSV window stride", |
|
|
) |
|
|
|
|
|
predict_btn = gr.Button("🚀 Run Inference", variant="primary") |
|
|
status_out = gr.Textbox(label="Status", interactive=False) |
|
|
table_out = gr.Dataframe( |
|
|
headers=["window", "predicted_label", "confidence", "top3"], |
|
|
label="Predictions", |
|
|
interactive=False, |
|
|
) |
|
|
probs_out = gr.JSON(label="Per-window probabilities") |
|
|
|
|
|
def _run_prediction(file_obj, text, sequence_length, stride): |
|
|
sequence_length = int(sequence_length) |
|
|
stride = int(stride) |
|
|
try: |
|
|
if file_obj is not None: |
|
|
return predict_from_csv(file_obj, sequence_length, stride) |
|
|
if text and text.strip(): |
|
|
return predict_from_text(text, sequence_length) |
|
|
return ( |
|
|
"Please upload a CSV file or provide feature values.", |
|
|
pd.DataFrame(), |
|
|
[], |
|
|
) |
|
|
except Exception as exc: |
|
|
return f"Prediction failed: {exc}", pd.DataFrame(), [] |
|
|
|
|
|
predict_btn.click( |
|
|
_run_prediction, |
|
|
inputs=[file_in, text_in, sequence_length_input, stride_input], |
|
|
outputs=[status_out, table_out, probs_out], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
with gr.Tab("Training"): |
|
|
gr.Markdown("## Train or Fine-tune the Model") |
|
|
gr.Markdown( |
|
|
"Training data is automatically downloaded from the database. " |
|
|
"Refresh the cache if new files are added upstream." |
|
|
) |
|
|
|
|
|
training_files_state = gr.State([]) |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
training_files_summary = gr.Textbox( |
|
|
label="Database training CSVs", |
|
|
value="Training dataset not loaded yet.", |
|
|
lines=4, |
|
|
interactive=False, |
|
|
elem_id="training-files-summary", |
|
|
) |
|
|
with gr.Column(scale=2, min_width=240): |
|
|
dataset_info = gr.Markdown( |
|
|
"No local database CSVs downloaded yet.", |
|
|
) |
|
|
dataset_refresh = gr.Button( |
|
|
"🔄 Reload dataset from database", |
|
|
variant="secondary", |
|
|
) |
|
|
clear_cache_button = gr.Button( |
|
|
"🧹 Clear downloaded cache", |
|
|
variant="secondary", |
|
|
) |
|
|
|
|
|
with gr.Accordion("📂 DataBaseBrowser", open=False): |
|
|
gr.Markdown( |
|
|
"Browse the upstream database by date and download only the CSVs you need." |
|
|
) |
|
|
with gr.Row(elem_id="date-browser-row"): |
|
|
with gr.Column(scale=1, elem_classes=["date-browser-column"]): |
|
|
year_selector = gr.Dropdown(label="Year", choices=[]) |
|
|
year_download_button = gr.Button( |
|
|
"⬇️ Download year CSVs", variant="secondary" |
|
|
) |
|
|
with gr.Column(scale=1, elem_classes=["date-browser-column"]): |
|
|
month_selector = gr.Dropdown(label="Month", choices=[]) |
|
|
month_download_button = gr.Button( |
|
|
"⬇️ Download month CSVs", variant="secondary" |
|
|
) |
|
|
with gr.Column(scale=1, elem_classes=["date-browser-column"]): |
|
|
day_selector = gr.Dropdown(label="Day", choices=[]) |
|
|
day_download_button = gr.Button( |
|
|
"⬇️ Download day CSVs", variant="secondary" |
|
|
) |
|
|
with gr.Column(elem_id="available-files-section"): |
|
|
available_files = gr.CheckboxGroup( |
|
|
label="Available CSV files", |
|
|
choices=[], |
|
|
value=[], |
|
|
elem_id="available-files-grid", |
|
|
) |
|
|
download_button = gr.Button( |
|
|
"⬇️ Download selected CSVs", |
|
|
variant="secondary", |
|
|
elem_id="download-selected-button", |
|
|
) |
|
|
repo_status = gr.Markdown( |
|
|
"Click 'Reload dataset from database' to fetch the directory tree." |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
label_input = gr.Dropdown( |
|
|
value=LABEL_COLUMN, |
|
|
choices=[LABEL_COLUMN], |
|
|
allow_custom_value=True, |
|
|
label="Label column name", |
|
|
) |
|
|
model_selector = gr.Radio( |
|
|
choices=["CNN-LSTM", "TCN", "SVM"], |
|
|
value=( |
|
|
"TCN" |
|
|
if MODEL_TYPE == "tcn" |
|
|
else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM") |
|
|
), |
|
|
label="Model architecture", |
|
|
) |
|
|
sequence_length_train = gr.Slider( |
|
|
minimum=4, |
|
|
maximum=max(32, SEQUENCE_LENGTH * 2), |
|
|
step=1, |
|
|
value=SEQUENCE_LENGTH, |
|
|
label="Sequence length", |
|
|
) |
|
|
stride_train = gr.Slider( |
|
|
minimum=1, |
|
|
maximum=max(32, SEQUENCE_LENGTH * 2), |
|
|
step=1, |
|
|
value=max(1, DEFAULT_WINDOW_STRIDE), |
|
|
label="Stride", |
|
|
) |
|
|
|
|
|
model_default = MODEL_FILENAME_BY_TYPE.get( |
|
|
MODEL_TYPE, Path(LOCAL_MODEL_FILE).name |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
validation_train = gr.Slider( |
|
|
minimum=0.05, |
|
|
maximum=0.4, |
|
|
step=0.05, |
|
|
value=0.2, |
|
|
label="Validation split", |
|
|
) |
|
|
batch_train = gr.Slider( |
|
|
minimum=32, |
|
|
maximum=512, |
|
|
step=32, |
|
|
value=128, |
|
|
label="Batch size", |
|
|
) |
|
|
epochs_train = gr.Slider( |
|
|
minimum=5, |
|
|
maximum=100, |
|
|
step=5, |
|
|
value=50, |
|
|
label="Epochs", |
|
|
) |
|
|
|
|
|
directory_choices, directory_default = gather_directory_choices( |
|
|
str(MODEL_OUTPUT_DIR) |
|
|
) |
|
|
artifact_choices, default_artifact = gather_artifact_choices( |
|
|
directory_default |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
output_directory = gr.Dropdown( |
|
|
value=directory_default, |
|
|
label="Output directory", |
|
|
choices=directory_choices, |
|
|
allow_custom_value=True, |
|
|
) |
|
|
model_name = gr.Textbox( |
|
|
value=model_default, |
|
|
label="Model output filename", |
|
|
) |
|
|
scaler_name = gr.Textbox( |
|
|
value=Path(LOCAL_SCALER_FILE).name, |
|
|
label="Scaler output filename", |
|
|
) |
|
|
metadata_name = gr.Textbox( |
|
|
value=Path(LOCAL_METADATA_FILE).name, |
|
|
label="Metadata output filename", |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
artifact_browser = gr.Dropdown( |
|
|
label="Saved artifacts in directory", |
|
|
choices=artifact_choices, |
|
|
value=default_artifact, |
|
|
) |
|
|
artifact_download_button = gr.DownloadButton( |
|
|
"⬇️ Download selected artifact", |
|
|
value=default_artifact, |
|
|
visible=bool(default_artifact), |
|
|
variant="secondary", |
|
|
) |
|
|
|
|
|
def on_output_directory_change(selected_dir, current_selection): |
|
|
choices, normalised = gather_directory_choices(selected_dir) |
|
|
artifact_options, selected = gather_artifact_choices( |
|
|
normalised, current_selection |
|
|
) |
|
|
return ( |
|
|
gr.update(choices=choices, value=normalised), |
|
|
gr.update(choices=artifact_options, value=selected), |
|
|
download_button_state(selected), |
|
|
) |
|
|
|
|
|
def on_artifact_change(selected_path): |
|
|
return download_button_state(selected_path) |
|
|
|
|
|
output_directory.change( |
|
|
on_output_directory_change, |
|
|
inputs=[output_directory, artifact_browser], |
|
|
outputs=[ |
|
|
output_directory, |
|
|
artifact_browser, |
|
|
artifact_download_button, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
artifact_browser.change( |
|
|
on_artifact_change, |
|
|
inputs=[artifact_browser], |
|
|
outputs=[artifact_download_button], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
with gr.Row(elem_id="artifact-download-row"): |
|
|
model_download_button = gr.DownloadButton( |
|
|
"⬇️ Download model file", |
|
|
value=None, |
|
|
visible=False, |
|
|
elem_classes=["artifact-download-button"], |
|
|
) |
|
|
scaler_download_button = gr.DownloadButton( |
|
|
"⬇️ Download scaler file", |
|
|
value=None, |
|
|
visible=False, |
|
|
elem_classes=["artifact-download-button"], |
|
|
) |
|
|
metadata_download_button = gr.DownloadButton( |
|
|
"⬇️ Download metadata file", |
|
|
value=None, |
|
|
visible=False, |
|
|
elem_classes=["artifact-download-button"], |
|
|
) |
|
|
tensorboard_download_button = gr.DownloadButton( |
|
|
"⬇️ Download TensorBoard logs", |
|
|
value=None, |
|
|
visible=False, |
|
|
elem_classes=["artifact-download-button"], |
|
|
) |
|
|
|
|
|
model_download_button.file_name = Path(LOCAL_MODEL_FILE).name |
|
|
scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name |
|
|
metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name |
|
|
tensorboard_download_button.file_name = "tensorboard_logs.zip" |
|
|
|
|
|
tensorboard_toggle = gr.Checkbox( |
|
|
value=True, |
|
|
label="Enable TensorBoard logging (creates downloadable archive)", |
|
|
) |
|
|
|
|
|
def _suggest_model_filename(choice: str, current_value: str): |
|
|
choice_key = (choice or "cnn_lstm").lower().replace("-", "_") |
|
|
suggested = MODEL_FILENAME_BY_TYPE.get( |
|
|
choice_key, Path(LOCAL_MODEL_FILE).name |
|
|
) |
|
|
known_defaults = set(MODEL_FILENAME_BY_TYPE.values()) |
|
|
current_name = Path(current_value).name if current_value else "" |
|
|
if current_name and current_name not in known_defaults: |
|
|
return gr.update() |
|
|
return gr.update(value=suggested) |
|
|
|
|
|
model_selector.change( |
|
|
_suggest_model_filename, |
|
|
inputs=[model_selector, model_name], |
|
|
outputs=model_name, |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
train_button = gr.Button("🛠️ Start Training", variant="primary") |
|
|
progress_button = gr.Button( |
|
|
"📊 Check Progress", variant="secondary" |
|
|
) |
|
|
|
|
|
|
|
|
training_status = gr.Textbox(label="Training Status", interactive=False) |
|
|
report_output = gr.Dataframe( |
|
|
label="Classification report", interactive=False |
|
|
) |
|
|
history_output = gr.JSON(label="Training history") |
|
|
confusion_output = gr.Dataframe( |
|
|
label="Confusion matrix", interactive=False |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Accordion("📋 Progress Messages", open=True): |
|
|
progress_messages = gr.Textbox( |
|
|
label="Training Messages", |
|
|
lines=8, |
|
|
max_lines=20, |
|
|
interactive=False, |
|
|
autoscroll=True, |
|
|
placeholder="Click 'Check Progress' to see training updates...", |
|
|
) |
|
|
with gr.Row(): |
|
|
gr.Button("🗑️ Clear Messages", variant="secondary").click( |
|
|
lambda: "", outputs=[progress_messages] |
|
|
) |
|
|
|
|
|
def _run_training( |
|
|
file_paths, |
|
|
label_column, |
|
|
model_choice, |
|
|
sequence_length, |
|
|
stride, |
|
|
validation_split, |
|
|
batch_size, |
|
|
epochs, |
|
|
output_dir, |
|
|
model_filename, |
|
|
scaler_filename, |
|
|
metadata_filename, |
|
|
enable_tensorboard, |
|
|
): |
|
|
base_dir = normalise_output_directory(output_dir) |
|
|
try: |
|
|
base_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
model_path = resolve_output_path( |
|
|
base_dir, |
|
|
model_filename, |
|
|
Path(LOCAL_MODEL_FILE).name, |
|
|
) |
|
|
scaler_path = resolve_output_path( |
|
|
base_dir, |
|
|
scaler_filename, |
|
|
Path(LOCAL_SCALER_FILE).name, |
|
|
) |
|
|
metadata_path = resolve_output_path( |
|
|
base_dir, |
|
|
metadata_filename, |
|
|
Path(LOCAL_METADATA_FILE).name, |
|
|
) |
|
|
|
|
|
model_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
scaler_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
metadata_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
status_file = model_path.parent / "training_status.txt" |
|
|
|
|
|
|
|
|
with open(status_file, "w") as f: |
|
|
f.write("Starting training setup...") |
|
|
|
|
|
if not file_paths: |
|
|
raise ValueError( |
|
|
"No training CSVs were found in the database cache. " |
|
|
"Use 'Reload dataset from database' and try again." |
|
|
) |
|
|
|
|
|
with open(status_file, "w") as f: |
|
|
f.write("Loading and validating CSV files...") |
|
|
|
|
|
available_paths = [ |
|
|
path for path in file_paths if Path(path).exists() |
|
|
] |
|
|
missing_paths = [ |
|
|
Path(path).name |
|
|
for path in file_paths |
|
|
if not Path(path).exists() |
|
|
] |
|
|
if not available_paths: |
|
|
raise ValueError( |
|
|
"Database training dataset is unavailable. Reload the dataset and retry." |
|
|
) |
|
|
|
|
|
dfs = [load_measurement_csv(path) for path in available_paths] |
|
|
combined = pd.concat(dfs, ignore_index=True) |
|
|
|
|
|
|
|
|
total_samples = len(combined) |
|
|
if total_samples < 100: |
|
|
print( |
|
|
f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results." |
|
|
) |
|
|
print( |
|
|
"Automatically switching to SVM for small dataset compatibility." |
|
|
) |
|
|
if model_choice in ["cnn_lstm", "tcn"]: |
|
|
model_choice = "svm" |
|
|
print( |
|
|
f"Model type changed to SVM for better small dataset performance." |
|
|
) |
|
|
if total_samples < 10: |
|
|
raise ValueError( |
|
|
f"Insufficient data: {total_samples} samples. Need at least 10 samples for training." |
|
|
) |
|
|
|
|
|
label_column = (label_column or LABEL_COLUMN).strip() |
|
|
if not label_column: |
|
|
raise ValueError("Label column name cannot be empty.") |
|
|
|
|
|
model_choice = ( |
|
|
(model_choice or "CNN-LSTM").lower().replace("-", "_") |
|
|
) |
|
|
if model_choice not in {"cnn_lstm", "tcn", "svm"}: |
|
|
raise ValueError( |
|
|
"Select CNN-LSTM, TCN, or SVM for the model architecture." |
|
|
) |
|
|
|
|
|
with open(status_file, "w") as f: |
|
|
f.write( |
|
|
f"Starting {model_choice.upper()} training with {len(combined)} samples..." |
|
|
) |
|
|
|
|
|
|
|
|
result = train_from_dataframe( |
|
|
combined, |
|
|
label_column=label_column, |
|
|
feature_columns=None, |
|
|
sequence_length=int(sequence_length), |
|
|
stride=int(stride), |
|
|
validation_split=float(validation_split), |
|
|
batch_size=int(batch_size), |
|
|
epochs=int(epochs), |
|
|
model_type=model_choice, |
|
|
model_path=model_path, |
|
|
scaler_path=scaler_path, |
|
|
metadata_path=metadata_path, |
|
|
enable_tensorboard=bool(enable_tensorboard), |
|
|
) |
|
|
|
|
|
refresh_artifacts( |
|
|
Path(result["model_path"]), |
|
|
Path(result["scaler_path"]), |
|
|
Path(result["metadata_path"]), |
|
|
) |
|
|
|
|
|
report_df = classification_report_to_dataframe( |
|
|
result["classification_report"] |
|
|
) |
|
|
confusion_df = confusion_matrix_to_dataframe( |
|
|
result["confusion_matrix"], result["class_names"] |
|
|
) |
|
|
tensorboard_dir = result.get("tensorboard_log_dir") |
|
|
tensorboard_zip = result.get("tensorboard_zip_path") |
|
|
|
|
|
architecture = result["model_type"].replace("_", "-").upper() |
|
|
status = ( |
|
|
f"Training complete using a {architecture} architecture. " |
|
|
f"{result['num_sequences']} windows derived from " |
|
|
f"{result['num_samples']} rows across {len(available_paths)} file(s)." |
|
|
f" Artifacts saved to:" |
|
|
f"\n• Model: {result['model_path']}\n" |
|
|
f"• Scaler: {result['scaler_path']}\n" |
|
|
f"• Metadata: {result['metadata_path']}" |
|
|
) |
|
|
|
|
|
status += f"\nLabel column used: {result.get('label_column', label_column)}" |
|
|
|
|
|
if tensorboard_dir: |
|
|
status += ( |
|
|
f"\nTensorBoard logs directory: {tensorboard_dir}" |
|
|
f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.' |
|
|
"\nDownload the archive below to explore the run offline." |
|
|
) |
|
|
|
|
|
if missing_paths: |
|
|
skipped = ", ".join(missing_paths) |
|
|
status = f"⚠️ Skipped missing files: {skipped}\n" + status |
|
|
|
|
|
artifact_choices, selected_artifact = gather_artifact_choices( |
|
|
str(base_dir), result["model_path"] |
|
|
) |
|
|
|
|
|
return ( |
|
|
status, |
|
|
report_df, |
|
|
result["history"], |
|
|
confusion_df, |
|
|
download_button_state(result["model_path"]), |
|
|
download_button_state(result["scaler_path"]), |
|
|
download_button_state(result["metadata_path"]), |
|
|
download_button_state(tensorboard_zip), |
|
|
gr.update(value=result.get("label_column", label_column)), |
|
|
gr.update( |
|
|
choices=artifact_choices, value=selected_artifact |
|
|
), |
|
|
download_button_state(selected_artifact), |
|
|
) |
|
|
except Exception as exc: |
|
|
artifact_choices, selected_artifact = gather_artifact_choices( |
|
|
str(base_dir) |
|
|
) |
|
|
return ( |
|
|
f"Training failed: {exc}", |
|
|
pd.DataFrame(), |
|
|
{}, |
|
|
pd.DataFrame(), |
|
|
download_button_state(None), |
|
|
download_button_state(None), |
|
|
download_button_state(None), |
|
|
download_button_state(None), |
|
|
gr.update(), |
|
|
gr.update( |
|
|
choices=artifact_choices, value=selected_artifact |
|
|
), |
|
|
download_button_state(selected_artifact), |
|
|
) |
|
|
|
|
|
def _check_progress(output_dir, model_filename, current_messages): |
|
|
"""Check training progress by reading status file and accumulate messages.""" |
|
|
model_path = resolve_output_path( |
|
|
output_dir, model_filename, Path(LOCAL_MODEL_FILE).name |
|
|
) |
|
|
status_file = model_path.parent / "training_status.txt" |
|
|
status_message = read_training_status(str(status_file)) |
|
|
|
|
|
|
|
|
from datetime import datetime |
|
|
|
|
|
timestamp = datetime.now().strftime("%H:%M:%S") |
|
|
new_message = f"[{timestamp}] {status_message}" |
|
|
|
|
|
|
|
|
if current_messages: |
|
|
lines = current_messages.split("\n") |
|
|
lines.append(new_message) |
|
|
|
|
|
if len(lines) > 50: |
|
|
lines = lines[-50:] |
|
|
accumulated_messages = "\n".join(lines) |
|
|
else: |
|
|
accumulated_messages = new_message |
|
|
|
|
|
return accumulated_messages |
|
|
|
|
|
train_button.click( |
|
|
_run_training, |
|
|
inputs=[ |
|
|
training_files_state, |
|
|
label_input, |
|
|
model_selector, |
|
|
sequence_length_train, |
|
|
stride_train, |
|
|
validation_train, |
|
|
batch_train, |
|
|
epochs_train, |
|
|
output_directory, |
|
|
model_name, |
|
|
scaler_name, |
|
|
metadata_name, |
|
|
tensorboard_toggle, |
|
|
], |
|
|
outputs=[ |
|
|
training_status, |
|
|
report_output, |
|
|
history_output, |
|
|
confusion_output, |
|
|
model_download_button, |
|
|
scaler_download_button, |
|
|
metadata_download_button, |
|
|
tensorboard_download_button, |
|
|
label_input, |
|
|
artifact_browser, |
|
|
artifact_download_button, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
progress_button.click( |
|
|
_check_progress, |
|
|
inputs=[output_directory, model_name, progress_messages], |
|
|
outputs=[progress_messages], |
|
|
) |
|
|
|
|
|
year_selector.change( |
|
|
on_year_change, |
|
|
inputs=[year_selector], |
|
|
outputs=[ |
|
|
month_selector, |
|
|
day_selector, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
month_selector.change( |
|
|
on_month_change, |
|
|
inputs=[year_selector, month_selector], |
|
|
outputs=[day_selector, available_files, repo_status], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
day_selector.change( |
|
|
on_day_change, |
|
|
inputs=[year_selector, month_selector, day_selector], |
|
|
outputs=[available_files, repo_status], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
download_button.click( |
|
|
download_selected_files, |
|
|
inputs=[ |
|
|
year_selector, |
|
|
month_selector, |
|
|
day_selector, |
|
|
available_files, |
|
|
label_input, |
|
|
], |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
year_download_button.click( |
|
|
download_year_bundle, |
|
|
inputs=[year_selector, label_input], |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
month_download_button.click( |
|
|
download_month_bundle, |
|
|
inputs=[year_selector, month_selector, label_input], |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
day_download_button.click( |
|
|
download_day_bundle, |
|
|
inputs=[year_selector, month_selector, day_selector, label_input], |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
def _reload_dataset(current_label): |
|
|
local = load_repository_training_files( |
|
|
current_label, force_refresh=True |
|
|
) |
|
|
remote = refresh_remote_browser(force_refresh=True) |
|
|
return (*local, *remote) |
|
|
|
|
|
dataset_refresh.click( |
|
|
_reload_dataset, |
|
|
inputs=[label_input], |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
year_selector, |
|
|
month_selector, |
|
|
day_selector, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
clear_cache_button.click( |
|
|
clear_downloaded_cache, |
|
|
inputs=[label_input], |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
year_selector, |
|
|
month_selector, |
|
|
day_selector, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
concurrency_limit=EVENT_CONCURRENCY_LIMIT, |
|
|
) |
|
|
|
|
|
def _initialise_dataset(): |
|
|
local = load_repository_training_files( |
|
|
LABEL_COLUMN, force_refresh=False |
|
|
) |
|
|
remote = refresh_remote_browser(force_refresh=False) |
|
|
return (*local, *remote) |
|
|
|
|
|
demo.load( |
|
|
_initialise_dataset, |
|
|
inputs=None, |
|
|
outputs=[ |
|
|
training_files_state, |
|
|
training_files_summary, |
|
|
label_input, |
|
|
dataset_info, |
|
|
year_selector, |
|
|
month_selector, |
|
|
day_selector, |
|
|
available_files, |
|
|
repo_status, |
|
|
], |
|
|
queue=False, |
|
|
) |
|
|
|
|
|
return demo |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_server_port() -> int: |
|
|
for env_var in ("PORT", "GRADIO_SERVER_PORT"): |
|
|
value = os.environ.get(env_var) |
|
|
if value: |
|
|
try: |
|
|
return int(value) |
|
|
except ValueError: |
|
|
print(f"Ignoring invalid port value from {env_var}: {value}") |
|
|
return 7860 |
|
|
|
|
|
|
|
|
def main(): |
|
|
print("Building Gradio interface...") |
|
|
try: |
|
|
demo = build_interface() |
|
|
print("Interface built successfully") |
|
|
except Exception as e: |
|
|
print(f"Failed to build interface: {e}") |
|
|
import traceback |
|
|
|
|
|
traceback.print_exc() |
|
|
return |
|
|
|
|
|
print("Setting up queue...") |
|
|
try: |
|
|
demo.queue(max_size=QUEUE_MAX_SIZE) |
|
|
print("Queue configured") |
|
|
except Exception as e: |
|
|
print(f"Failed to configure queue: {e}") |
|
|
|
|
|
try: |
|
|
port = resolve_server_port() |
|
|
print(f"Launching Gradio app on port {port}") |
|
|
demo.launch(server_name="0.0.0.0", server_port=port, show_error=True) |
|
|
except OSError as exc: |
|
|
print("Failed to launch on requested port:", exc) |
|
|
try: |
|
|
demo.launch(server_name="0.0.0.0", show_error=True) |
|
|
except Exception as e: |
|
|
print(f"Failed to launch completely: {e}") |
|
|
except Exception as e: |
|
|
print(f"Unexpected launch error: {e}") |
|
|
import traceback |
|
|
|
|
|
traceback.print_exc() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
print("=" * 50) |
|
|
print("PMU Fault Classification App Starting") |
|
|
print(f"Python version: {os.sys.version}") |
|
|
print(f"Working directory: {os.getcwd()}") |
|
|
print(f"HUB_REPO: {HUB_REPO}") |
|
|
print(f"Model available: {MODEL is not None}") |
|
|
print(f"Scaler available: {SCALER is not None}") |
|
|
print("=" * 50) |
|
|
main() |
|
|
|