GuanHuaYu student
Update
97ea8a4
"""Gradio front-end for Fault_Classification_PMU_Data models.
The application loads a CNN-LSTM model (and accompanying scaler/metadata)
produced by ``fault_classification_pmu.py`` and exposes a streamlined
prediction interface optimised for Hugging Face Spaces deployment. It supports
raw PMU time-series CSV uploads as well as manual comma separated feature
vectors.
"""
from __future__ import annotations
import json
import os
import shutil
# These environment switches must be in place before TensorFlow is imported
# (via ``tensorflow.keras`` further down) or they have no effect:
#   CUDA_VISIBLE_DEVICES=-1   hide CUDA GPUs so everything runs on the CPU,
#   TF_CPP_MIN_LOG_LEVEL=2    silence TensorFlow's INFO/WARNING C++ logs,
#   TF_ENABLE_ONEDNN_OPTS=0   disable oneDNN custom operations.
# ``setdefault`` keeps any value already provided by the deployment environment.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import gradio as gr
import joblib
import numpy as np
import pandas as pd
import requests
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
from fault_classification_pmu import (
DEFAULT_FEATURE_COLUMNS as TRAINING_DEFAULT_FEATURE_COLUMNS,
LABEL_GUESS_CANDIDATES as TRAINING_LABEL_GUESSES,
train_from_dataframe,
)
# --------------------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------------------
# Default feature columns, mirrored from the training module so the inference app
# and the trainer agree on the input layout.
DEFAULT_FEATURE_COLUMNS: List[str] = list(TRAINING_DEFAULT_FEATURE_COLUMNS)
# Default window length and step. Presumably these are the sliding-window
# parameters used when slicing a time series into model inputs (confirm against
# the trainer).
DEFAULT_SEQUENCE_LENGTH = 32
DEFAULT_STRIDE = 4
# Artifact file names; each one can be overridden through the environment.
LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
# Absolute directory for model artifacts. It is created at import time.
MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Optional Hugging Face Hub repository to fall back to for artifacts. An empty
# value disables Hub downloads entirely. The Hub file names default to the local
# names above.
HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
HUB_METADATA_FILENAME = os.environ.get("PMU_HUB_METADATA_FILENAME", LOCAL_METADATA_FILE)
# Names of environment variables that may point directly at an artifact file.
# They are consulted by ``resolve_artifact`` after the local candidates.
ENV_MODEL_PATH = "PMU_MODEL_PATH"
ENV_SCALER_PATH = "PMU_SCALER_PATH"
ENV_METADATA_PATH = "PMU_METADATA_PATH"
# --------------------------------------------------------------------------------------
# Utility functions for loading artifacts
# --------------------------------------------------------------------------------------
def download_from_hub(filename: str) -> Optional[Path]:
    """Download ``filename`` from the configured Hugging Face Hub repository.

    Args:
        filename: Path of the file inside ``HUB_REPO``.

    Returns:
        The local path returned by ``hf_hub_download``. Returns ``None`` when:
        no repository is configured (``HUB_REPO`` is empty), ``filename`` is
        empty, or the download raises. Failures are logged and swallowed so
        the app can still start without pre-trained artifacts.
    """
    if not HUB_REPO or not filename:
        return None
    try:
        print(f"Downloading {filename} from {HUB_REPO} ...")
        # NOTE: no explicit timeout is passed here. Whatever network defaults
        # ``hf_hub_download`` uses internally apply.
        path = hf_hub_download(repo_id=HUB_REPO, filename=filename)
    except Exception as exc:  # pragma: no cover - logging convenience
        print("Failed to download", filename, "from", HUB_REPO, ":", exc)
        print("Continuing without pre-trained model...")
        return None
    print("Downloaded", path)
    return Path(path)
def resolve_artifact(
    local_name: str, env_var: str, hub_filename: str
) -> Optional[Path]:
    """Locate a model artifact on disk, falling back to the Hugging Face Hub.

    Candidates are probed in this order:

    1. ``local_name`` as given (relative to the working directory);
    2. the same base name inside ``MODEL_OUTPUT_DIR``;
    3. the path held in the environment variable ``env_var``.

    The first candidate that exists is returned. If none exists and ``HUB_REPO``
    is configured, ``hub_filename`` is downloaded instead. Otherwise the result
    is ``None``.
    """
    print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
    search_order: List[Path] = []
    if local_name:
        search_order.extend(
            (Path(local_name), MODEL_OUTPUT_DIR / Path(local_name).name)
        )
    override = os.environ.get(env_var)
    if override:
        search_order.append(Path(override))
    found = next((p for p in search_order if p and p.exists()), None)
    if found is not None:
        print(f"Found local artifact: {found}")
        return found
    print("No local artifacts found, checking hub...")
    # Only try to download if we have a hub repo configured
    if not HUB_REPO:
        print("No HUB_REPO configured, skipping download")
        return None
    return download_from_hub(hub_filename)
def load_metadata(path: Optional[Path]) -> Dict:
    """Parse the JSON metadata file at ``path`` into a dict.

    Args:
        path: Location of the metadata JSON file, or ``None``.

    Returns:
        The parsed JSON object. An empty dict is returned when:
        - ``path`` is ``None`` or missing;
        - the file cannot be read or parsed (the error is logged);
        - the top-level JSON value is not an object.

    The file is decoded as UTF-8 explicitly, because JSON is UTF-8 by
    specification and the platform default encoding may differ. ``utf-8-sig``
    additionally tolerates a leading byte-order mark, matching how measurement
    CSVs are read elsewhere in this app.
    """
    if not path or not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding="utf-8-sig"))
    except Exception as exc:  # pragma: no cover - metadata parsing errors
        print("Failed to read metadata", path, exc)
        return {}
    if not isinstance(data, dict):
        # Callers treat metadata as a mapping; anything else is unusable.
        print("Failed to read metadata", path, "expected a JSON object")
        return {}
    return data
def try_load_scaler(path: Optional[Path]):
    """Load a persisted feature scaler with ``joblib``.

    Returns the scaler object. Returns ``None`` when ``path`` is falsy or
    loading fails; failures are logged rather than raised.
    """
    # NOTE: joblib.load unpickles arbitrary objects; only point it at trusted
    # artifact files.
    if path:
        try:
            loaded = joblib.load(path)
            print("Loaded scaler from", path)
            return loaded
        except Exception as exc:
            print("Failed to load scaler", path, exc)
    return None
# Initialize paths with error handling
# Each artifact is resolved independently at import time: local files first,
# then the Hugging Face Hub when ``HUB_REPO`` is set. Any failure degrades to
# ``None`` (or ``{}`` for metadata), so the UI can still start without
# pre-trained artifacts.
print("Starting application initialization...")
try:
    MODEL_PATH = resolve_artifact(LOCAL_MODEL_FILE, ENV_MODEL_PATH, HUB_MODEL_FILENAME)
    print(f"Model path resolved: {MODEL_PATH}")
except Exception as e:
    print(f"Model path resolution failed: {e}")
    MODEL_PATH = None
try:
    SCALER_PATH = resolve_artifact(
        LOCAL_SCALER_FILE, ENV_SCALER_PATH, HUB_SCALER_FILENAME
    )
    print(f"Scaler path resolved: {SCALER_PATH}")
except Exception as e:
    print(f"Scaler path resolution failed: {e}")
    SCALER_PATH = None
try:
    METADATA_PATH = resolve_artifact(
        LOCAL_METADATA_FILE, ENV_METADATA_PATH, HUB_METADATA_FILENAME
    )
    print(f"Metadata path resolved: {METADATA_PATH}")
except Exception as e:
    print(f"Metadata path resolution failed: {e}")
    METADATA_PATH = None
try:
    # ``load_metadata`` already returns {} on read/parse errors. This guard
    # additionally covers anything else raised here (e.g. by ``len``).
    METADATA = load_metadata(METADATA_PATH)
    print(f"Metadata loaded: {len(METADATA)} entries")
except Exception as e:
    print(f"Metadata loading failed: {e}")
    METADATA = {}
# Queuing configuration
# Maximum number of pending requests; presumably passed to Gradio's queue setup
# further down in this file (verify against the UI wiring).
QUEUE_MAX_SIZE = 32
# Apply a small per-event concurrency limit to avoid relying on the deprecated
# ``concurrency_count`` parameter when enabling Gradio's request queue.
EVENT_CONCURRENCY_LIMIT = 2
def try_load_model(path: Optional[Path], model_type: str, model_format: str):
    """Load a trained classifier from ``path``.

    Which loader is used:
    - ``joblib.load`` for SVM models or anything saved in the ``joblib`` format;
    - the Keras ``load_model`` for every other combination.

    Returns the loaded model. Returns ``None`` when ``path`` is falsy or loading
    fails; failures are logged, not raised.
    """
    if not path:
        return None
    try:
        # Loader selection stays inside the ``try`` so any failure, including
        # an import-time problem, is reported the same way.
        use_joblib = model_type == "svm" or model_format == "joblib"
        loader = joblib.load if use_joblib else load_model
        model = loader(path)
        print("Loaded model from", path)
        return model
    except Exception as exc:  # pragma: no cover - runtime diagnostics
        print("Failed to load model", path, exc)
        return None
# Runtime model configuration, initialised from the defaults above.
# NOTE(review): presumably overwritten from the loaded metadata elsewhere in this
# file; confirm against the metadata-handling code.
FEATURE_COLUMNS: List[str] = list(DEFAULT_FEATURE_COLUMNS)
LABEL_CLASSES: List[str] = []
# Default label/target column name. It is also the fallback choice for the
# training label dropdown.
LABEL_COLUMN: str = "Fault"
SEQUENCE_LENGTH: int = DEFAULT_SEQUENCE_LENGTH
DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
# One of the keys of ``MODEL_FILENAME_BY_TYPE`` ("cnn_lstm", "tcn", "svm").
MODEL_TYPE: str = "cnn_lstm"
# Serialisation format of the model file; "joblib" selects joblib loading.
MODEL_FORMAT: str = "keras"
def _model_output_path(filename: str) -> str:
    """Place ``filename``'s base name inside ``MODEL_OUTPUT_DIR``.

    Any directory components in ``filename`` are discarded. The result is
    returned as a string.
    """
    bare_name = Path(filename).name
    target = MODEL_OUTPUT_DIR.joinpath(bare_name)
    return str(target)
# Default artifact file name for each supported model type.
MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
    "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
    "tcn": "pmu_tcn_model.keras",
    "svm": "pmu_svm_model.joblib",
}
# Columns a measurement CSV must provide after column mapping; checked by
# ``load_measurement_csv``.
REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
# Directory where user-uploaded training files are copied (created at import).
TRAINING_UPLOAD_DIR = Path(
    os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads")
)
TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# GitHub repository ("owner/name") and branch holding the measurement CSVs,
# organised as year/month/day folders, that the database browser lists and
# downloads.
TRAINING_DATA_REPO = os.environ.get(
    "PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData"
)
TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
# Local cache for files downloaded from TRAINING_DATA_REPO (created at import).
TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
# Memoised GitHub "contents" API listings, keyed by ``_github_cache_key(path)``.
GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
# Custom CSS for the Gradio UI. The loading overlay for the file browser is
# declared once, through a selector list covering both the section and the grid.
# A second, identical rule for ``#available-files-section .gradio-loading``
# (using ``inset: 0``) was redundant and has been removed.
APP_CSS = """
#available-files-section {
position: relative;
display: flex;
flex-direction: column;
gap: 0.75rem;
border-radius: 0.75rem;
isolation: isolate;
}
#available-files-grid {
position: relative;
overflow: visible;
}
#available-files-grid .form {
position: static;
min-height: 16rem;
}
#available-files-grid .wrap {
display: grid;
grid-template-columns: repeat(4, minmax(0, 1fr));
gap: 0.5rem;
max-height: 24rem;
min-height: 16rem;
overflow-y: auto;
padding-right: 0.25rem;
}
#available-files-grid .wrap > div {
min-width: 0;
}
#available-files-grid .wrap label {
margin: 0;
display: flex;
align-items: center;
padding: 0.45rem 0.65rem;
border-radius: 0.65rem;
background-color: rgba(255, 255, 255, 0.05);
border: 1px solid rgba(255, 255, 255, 0.08);
transition: background-color 0.2s ease, border-color 0.2s ease;
min-height: 2.5rem;
}
#available-files-grid .wrap label:hover {
background-color: rgba(90, 200, 250, 0.16);
border-color: rgba(90, 200, 250, 0.4);
}
#available-files-grid .wrap label span {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
#available-files-section .gradio-loading,
#available-files-grid .gradio-loading {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
width: 100%;
height: 100%;
display: flex;
align-items: center;
justify-content: center;
background: rgba(10, 14, 23, 0.92);
border-radius: 0.75rem;
z-index: 999;
padding: 1.5rem;
pointer-events: auto;
}
#available-files-section .gradio-loading > * {
width: 100%;
}
#available-files-section .gradio-loading progress,
#available-files-section .gradio-loading .progress-bar,
#available-files-section .gradio-loading .loading-progress,
#available-files-section .gradio-loading [role="progressbar"],
#available-files-section .gradio-loading .wrap,
#available-files-section .gradio-loading .inner {
width: 100% !important;
max-width: none !important;
}
#available-files-section .gradio-loading .status,
#available-files-section .gradio-loading .message,
#available-files-section .gradio-loading .label {
text-align: center;
}
#date-browser-row {
gap: 0.75rem;
}
#date-browser-row .date-browser-column {
flex: 1 1 0%;
min-width: 0;
}
#date-browser-row .date-browser-column > .gradio-dropdown,
#date-browser-row .date-browser-column > .gradio-button {
width: 100%;
}
#date-browser-row .date-browser-column > .gradio-dropdown > div {
width: 100%;
}
#date-browser-row .date-browser-column .gradio-button {
justify-content: center;
}
#training-files-summary textarea {
max-height: 12rem;
overflow-y: auto;
}
#download-selected-button {
width: 100%;
position: relative;
z-index: 0;
}
#download-selected-button .gradio-button {
width: 100%;
justify-content: center;
}
#artifact-download-row {
gap: 0.75rem;
}
#artifact-download-row .artifact-download-button {
flex: 1 1 0%;
min-width: 0;
}
#artifact-download-row .artifact-download-button .gradio-button {
width: 100%;
justify-content: center;
}
"""
def _github_cache_key(path: str) -> str:
return path or "__root__"
def _github_api_url(path: str) -> str:
    """Build the GitHub "repository contents" API URL for ``path``.

    Leading and trailing slashes are stripped; an empty path addresses the
    repository root. The configured branch is pinned through the ``ref`` query
    parameter.
    """
    endpoint = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
    trimmed = path.strip("/")
    segment = f"/{trimmed}" if trimmed else ""
    return f"{endpoint}{segment}?ref={TRAINING_DATA_BRANCH}"
def list_remote_directory(
    path: str = "", *, force_refresh: bool = False
) -> List[Dict[str, Any]]:
    """Return GitHub's JSON listing of ``path`` in the training-data repository.

    Results are memoised in ``GITHUB_CONTENT_CACHE``. With
    ``force_refresh=True`` the cache is bypassed, and the fresh listing
    replaces the cached one. Requests carry no authentication headers, so
    GitHub's anonymous rate limits apply.

    Raises:
        RuntimeError: if the API does not answer with HTTP 200, or if the
            payload is not a JSON list (i.e. ``path`` is not a directory).
    """
    cache_key = _github_cache_key(path)
    if cache_key in GITHUB_CONTENT_CACHE and not force_refresh:
        return GITHUB_CONTENT_CACHE[cache_key]
    response = requests.get(_github_api_url(path), timeout=30)
    status = response.status_code
    if status != 200:
        raise RuntimeError(
            f"GitHub API request failed for `{path or '.'}` (status {status})."
        )
    listing = response.json()
    if isinstance(listing, list):
        GITHUB_CONTENT_CACHE[cache_key] = listing
        return listing
    raise RuntimeError("Unexpected GitHub API payload. Expected a directory listing.")
def list_remote_years(force_refresh: bool = False) -> List[str]:
    """Return the sorted names of the top-level (year) folders in the data repo."""
    root_listing = list_remote_directory("", force_refresh=force_refresh)
    return sorted(
        entry["name"] for entry in root_listing if entry.get("type") == "dir"
    )
def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
    """Return the sorted month folder names under ``year``.

    An empty ``year`` yields ``[]`` without querying GitHub.
    """
    if year:
        listing = list_remote_directory(year, force_refresh=force_refresh)
        return sorted(entry["name"] for entry in listing if entry.get("type") == "dir")
    return []
def list_remote_days(
    year: str, month: str, *, force_refresh: bool = False
) -> List[str]:
    """Return the sorted day folder names under ``year/month``.

    If either component is empty, ``[]`` is returned without querying GitHub.
    """
    if not (year and month):
        return []
    listing = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
    return sorted(entry["name"] for entry in listing if entry.get("type") == "dir")
def list_remote_files(
    year: str, month: str, day: str, *, force_refresh: bool = False
) -> List[str]:
    """Return the sorted names of all file entries under ``year/month/day``.

    Every entry whose GitHub type is ``"file"`` is included, not only CSVs. If
    any component is empty, ``[]`` is returned without querying GitHub.
    """
    if not all((year, month, day)):
        return []
    listing = list_remote_directory(
        f"{year}/{month}/{day}", force_refresh=force_refresh
    )
    return sorted(entry["name"] for entry in listing if entry.get("type") == "file")
def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
    """Download ``year/month/day/filename`` from the training-data repository.

    The file is fetched from ``raw.githubusercontent.com`` on
    ``TRAINING_DATA_BRANCH``. It is stored as
    ``TRAINING_DATA_DIR/year/month/day/filename``, replacing any earlier copy.

    The body is streamed into a sibling ``<filename>.part`` file, which is
    renamed into place only after the transfer completes. An interrupted
    download therefore never leaves a truncated ``.csv`` that the cache scan
    would later pick up.

    Returns:
        The path of the stored file.

    Raises:
        ValueError: if any component is empty, or is not a plain path segment
            (contains ``/`` or ``\\``, or is ``.``/``..``). Such a segment would
            let the write escape the cache directory.
        RuntimeError: if the server does not answer with HTTP 200.
    """
    if not filename:
        raise ValueError("Filename cannot be empty when downloading repository data.")
    relative_parts = [part for part in (year, month, day, filename) if part]
    if len(relative_parts) < 4:
        raise ValueError("Provide year, month, day, and filename to download a CSV.")
    for part in relative_parts:
        if part in (".", "..") or "/" in part or "\\" in part:
            raise ValueError(
                f"Invalid path component `{part}` in repository download."
            )
    relative_path = "/".join(relative_parts)
    raw_url = (
        f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
        f"{TRAINING_DATA_BRANCH}/{relative_path}"
    )
    target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
    target_path = target_dir / filename
    partial_path = target_dir / f"{filename}.part"
    # ``with`` releases the streamed connection even if writing fails midway.
    with requests.get(raw_url, stream=True, timeout=120) as response:
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to download `{relative_path}` (status {response.status_code})."
            )
        target_dir.mkdir(parents=True, exist_ok=True)
        try:
            with open(partial_path, "wb") as handle:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:
                        handle.write(chunk)
            partial_path.replace(target_path)
        except BaseException:
            # Never leave a half-written temporary file behind.
            partial_path.unlink(missing_ok=True)
            raise
    return target_path
def _normalise_header(name: str) -> str:
return str(name).strip().lower()
def guess_label_from_columns(
    columns: Sequence[str], preferred: Optional[str] = None
) -> Optional[str]:
    """Pick the most plausible label column from ``columns``.

    Resolution order:

    1. ``preferred``, matched exactly after stripping surrounding whitespace;
    2. ``preferred``, matched case-insensitively;
    3. the first entry of ``TRAINING_LABEL_GUESSES`` that is present
       (case-insensitive);
    4. the first column whose normalised name starts with ``"fault"``;
    5. the first column.

    If ``columns`` is empty, ``preferred`` is returned unchanged (possibly
    ``None``).
    """
    if not columns:
        return preferred
    by_normalised = {_normalise_header(col): str(col) for col in columns}
    if preferred:
        wanted = preferred.strip()
        exact = next((str(col) for col in columns if str(col).strip() == wanted), None)
        if exact is not None:
            return exact
        relaxed = by_normalised.get(_normalise_header(preferred))
        if relaxed is not None:
            return relaxed
    for candidate in TRAINING_LABEL_GUESSES:
        hit = by_normalised.get(_normalise_header(candidate))
        if hit is not None:
            return hit
    fault_like = next(
        (str(col) for col in columns if _normalise_header(col).startswith("fault")),
        None,
    )
    return fault_like if fault_like is not None else str(columns[0])
def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
    """Render one line per file (base name only), followed by any notes.

    Returns a placeholder message when there is nothing to list.
    """
    entries = [*(Path(item).name for item in paths), *notes]
    if not entries:
        return "No training files available."
    return "\n".join(entries)
def read_training_status(status_file_path: str) -> str:
    """Read the current training status from file.

    Returns the file contents with surrounding whitespace stripped. The fallback
    ``"Training status unavailable"`` is returned when the path does not name a
    regular file or the file cannot be read.

    The file is decoded as UTF-8 explicitly, with undecodable bytes replaced.
    This way, status text containing non-ASCII characters (such as the emoji
    used in this app's messages) does not depend on the platform's default
    encoding.
    """
    try:
        status_path = Path(status_file_path)
        # ``is_file`` rather than ``exists``: a directory (including
        # ``Path("")``, which is ".") cannot be read as a status file.
        if status_path.is_file():
            with open(status_path, "r", encoding="utf-8", errors="replace") as handle:
                return handle.read().strip()
    except Exception:
        # Best effort: this is polled by the UI and must never raise.
        pass
    return "Training status unavailable"
def _persist_uploaded_file(file_obj) -> Optional[Path]:
if file_obj is None:
return None
if isinstance(file_obj, (str, Path)):
source = Path(file_obj)
original_name = source.name
else:
source = Path(getattr(file_obj, "name", "") or getattr(file_obj, "path", ""))
original_name = getattr(file_obj, "orig_name", source.name) or source.name
if not source or not source.exists():
return None
original_name = Path(original_name).name or source.name
base_path = Path(original_name)
destination = TRAINING_UPLOAD_DIR / base_path.name
counter = 1
while destination.exists():
suffix = base_path.suffix or ".csv"
destination = TRAINING_UPLOAD_DIR / f"{base_path.stem}_{counter}{suffix}"
counter += 1
shutil.copy2(source, destination)
return destination
def prepare_training_paths(
    paths: Sequence[str], current_label: str, cleanup_missing: bool = False
):
    """Validate candidate training CSVs and build the label-column dropdown.

    Each path is parsed with ``load_measurement_csv``. Unreadable files are
    reported in the summary. When ``cleanup_missing`` is true they are also
    deleted from disk (best effort). Column names from all readable files are
    pooled, deduplicated case- and whitespace-insensitively, to populate the
    dropdown; the selected value comes from ``guess_label_from_columns``.

    Returns:
        ``(valid_paths, summary_text, dropdown_update)``.
    """
    accepted: List[str] = []
    warnings: List[str] = []
    headers: Dict[str, str] = {}
    for candidate in paths:
        try:
            frame = load_measurement_csv(candidate)
        except Exception as exc:  # pragma: no cover - user file diagnostics
            warnings.append(f"⚠️ Skipped {Path(candidate).name}: {exc}")
            if cleanup_missing:
                try:
                    Path(candidate).unlink(missing_ok=True)
                except Exception:
                    pass
            continue
        accepted.append(str(candidate))
        headers.update({_normalise_header(col): str(col) for col in frame.columns})
    summary = summarise_training_files(accepted, warnings)
    preferred = current_label or LABEL_COLUMN
    if headers:
        choices = sorted(headers.values())
    else:
        choices = [preferred or LABEL_COLUMN]
    selected = guess_label_from_columns(choices, preferred) or preferred or LABEL_COLUMN
    return accepted, summary, gr.update(choices=choices, value=selected)
def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
    """Persist newly uploaded files and re-validate the whole training set.

    ``existing_paths`` may be ``None``, a single path, or a sequence of paths.
    Each upload in ``new_files`` is copied into the upload cache, and its copy
    is appended unless that path is already tracked. The combined list is
    validated by ``prepare_training_paths`` with ``cleanup_missing=True``, so
    unreadable files are deleted.
    """
    if existing_paths is None:
        paths: List[str] = []
    elif isinstance(existing_paths, (str, Path)):
        paths = [str(existing_paths)]
    else:
        paths = list(existing_paths)
    for upload in new_files or []:
        stored = _persist_uploaded_file(upload)
        if stored is not None and str(stored) not in paths:
            paths.append(str(stored))
    return prepare_training_paths(paths, current_label, cleanup_missing=True)
def load_repository_training_files(current_label: str, force_refresh: bool = False):
    """Collect every CSV already downloaded into the local database cache.

    Downloads happen on demand through the database browser. This function only
    scans ``TRAINING_DATA_DIR`` recursively for ``*.csv`` files and validates
    them with ``prepare_training_paths``. Invalid files are reported, not
    deleted.

    ``force_refresh`` is kept for compatibility with existing callers. The scan
    always reflects the current disk contents, so the flag has no additional
    effect.

    Returns:
        ``(valid_paths, summary_text, dropdown_update, info_message)``.
    """
    csv_paths = sorted(
        str(path) for path in TRAINING_DATA_DIR.rglob("*.csv") if path.is_file()
    )
    if not csv_paths:
        message = (
            "No local database CSVs are available yet. Use the database browser "
            "below to download specific days before training."
        )
        default_label = current_label or LABEL_COLUMN or "Fault"
        return (
            [],
            message,
            gr.update(choices=[default_label], value=default_label),
            message,
        )
    valid_paths, summary, label_update = prepare_training_paths(
        csv_paths, current_label, cleanup_missing=False
    )
    info = (
        f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
        f"the database cache `{TRAINING_DATA_DIR}`."
    )
    return valid_paths, summary, label_update, info
def refresh_remote_browser(force_refresh: bool = False):
    """Reset the database browser and repopulate the year dropdown.

    With ``force_refresh`` the whole GitHub listing cache is dropped first.

    Returns:
        Updates for (year, month, day, file selection) plus a status message.
        The month, day, and file widgets are always emptied. Errors are
        reported in the message instead of being raised.
    """
    if force_refresh:
        GITHUB_CONTENT_CACHE.clear()
    try:
        years = list_remote_years(force_refresh=force_refresh)
    except Exception as exc:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to query database: {exc}",
        )
    if years:
        status = "Select a year, month, and day to list available CSV files."
    else:
        status = (
            "⚠️ No directories were found in the database root. Verify the upstream "
            "structure."
        )
    return (
        gr.update(choices=years, value=None),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=[]),
        status,
    )
def on_year_change(year: Optional[str]):
    """Populate the month dropdown after a year is picked.

    Returns:
        Updates for (month, day, file selection) plus a status message. The
        day and file widgets are always cleared. Errors are reported in the
        message.
    """
    if not year:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            "Select a year to continue.",
        )
    try:
        months = list_remote_months(year)
    except Exception as exc:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to list months: {exc}",
        )
    if months:
        note = f"Year `{year}` selected. Choose a month to drill down."
    else:
        note = f"⚠️ No months available under `{year}`."
    return (
        gr.update(choices=months, value=None),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=[]),
        note,
    )
def on_month_change(year: Optional[str], month: Optional[str]):
    """Populate the day dropdown after a month is picked.

    Returns:
        Updates for (day, file selection) plus a status message. The file
        selection is always cleared. Errors are reported in the message.
    """
    if not (year and month):
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            "Select a month to continue.",
        )
    try:
        days = list_remote_days(year, month)
    except Exception as exc:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            f"⚠️ Failed to list days: {exc}",
        )
    if days:
        note = f"Month `{year}/{month}` ready. Pick a day to view files."
    else:
        note = f"⚠️ No day folders found under `{year}/{month}`."
    return gr.update(choices=days, value=None), gr.update(choices=[], value=[]), note
def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
    """List the downloadable files once a day is picked.

    Returns:
        An update for the file selection (choices set, nothing pre-selected)
        plus a status message. Errors are reported in the message.
    """
    if not all((year, month, day)):
        return gr.update(choices=[], value=[]), "Select a day to load file names."
    try:
        files = list_remote_files(year, month, day)
    except Exception as exc:
        return gr.update(choices=[], value=[]), f"⚠️ Failed to list files: {exc}"
    if files:
        note = f"{len(files)} file(s) available for `{year}/{month}/{day}`."
    else:
        note = f"⚠️ No CSV files found under `{year}/{month}/{day}`."
    return gr.update(choices=files, value=[]), note
def download_selected_files(
    year: Optional[str],
    month: Optional[str],
    day: Optional[str],
    filenames: Sequence[str],
    current_label: str,
):
    """Download the chosen files for one day into the local database cache.

    Each file is attempted independently; a failure becomes a ``⚠️`` note
    instead of aborting the batch.

    Returns:
        The four outputs of ``load_repository_training_files``, then an update
        for the file-selection widget, then a status message. The widget is
        cleared after an attempted download and left untouched when nothing
        was selected.
    """
    if not filenames:
        note = "Select at least one CSV before downloading."
        return (*load_repository_training_files(current_label), gr.update(), note)
    downloaded: List[str] = []
    failures: List[str] = []
    for name in filenames:
        try:
            stored = download_repository_file(year or "", month or "", day or "", name)
        except Exception as exc:
            failures.append(f"⚠️ {name}: {exc}")
        else:
            downloaded.append(str(stored))
    local = load_repository_training_files(current_label)
    report: List[str] = []
    if downloaded:
        report.append(
            f"Downloaded {len(downloaded)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
        )
    report.extend(failures)
    status = "\n".join(report) if report else "No files were downloaded."
    return (*local, gr.update(value=[]), status)
def download_day_bundle(
    year: Optional[str],
    month: Optional[str],
    day: Optional[str],
    current_label: str,
):
    """Download every file listed for ``year/month/day``.

    Delegates to ``download_selected_files``, which reports how many files
    actually succeeded and lists each failure. This function only prefixes
    that report with how many files were requested; it no longer claims that
    all of them were downloaded, which was wrong whenever some failed.

    Returns:
        The four outputs of ``load_repository_training_files``, then an update
        for the file-selection widget, then a status message.
    """
    if not (year and month and day):
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            "Select a year, month, and day before downloading an entire day.",
        )
    try:
        files = list_remote_files(year, month, day)
    except Exception as exc:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
        )
    if not files:
        local = load_repository_training_files(current_label)
        return (
            *local,
            gr.update(),
            f"No CSV files were found for `{year}/{month}/{day}`.",
        )
    result = list(download_selected_files(year, month, day, files, current_label))
    result[-1] = (
        f"Requested all {len(files)} file(s) listed for `{year}/{month}/{day}`.\n"
        f"{result[-1]}"
    )
    return tuple(result)
def download_month_bundle(
    year: Optional[str], month: Optional[str], current_label: str
):
    """Download every file of every day folder under ``year/month``.

    Listing or download problems for individual days or files are collected as
    ``⚠️`` notes and do not stop the bundle.

    Returns:
        The four outputs of ``load_repository_training_files``, then an update
        for the file-selection widget, then a status message.
    """
    def _early_exit(message: str):
        return (*load_repository_training_files(current_label), gr.update(), message)

    if not (year and month):
        return _early_exit("Select a year and month before downloading an entire month.")
    try:
        days = list_remote_days(year, month)
    except Exception as exc:
        return _early_exit(f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}")
    if not days:
        return _early_exit(f"No day folders were found for `{year}/{month}`.")

    fetched = 0
    problems: List[str] = []
    for day in days:
        try:
            files = list_remote_files(year, month, day)
        except Exception as exc:
            problems.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
            continue
        if not files:
            problems.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
            continue
        for filename in files:
            try:
                download_repository_file(year, month, day, filename)
            except Exception as exc:
                problems.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
            else:
                fetched += 1

    local = load_repository_training_files(current_label)
    report: List[str] = []
    if fetched:
        report.append(
            f"Downloaded {fetched} CSV file(s) for `{year}/{month}` into the "
            f"database cache `{TRAINING_DATA_DIR}`."
        )
    report.extend(problems)
    status = "\n".join(report) if report else "No files were downloaded."
    return (*local, gr.update(value=[]), status)
def download_year_bundle(year: Optional[str], current_label: str):
    """Download every file of every day of every month under ``year``.

    Listing or download problems for individual months, days, or files are
    collected as ``⚠️`` notes and do not stop the bundle.

    Returns:
        The four outputs of ``load_repository_training_files``, then an update
        for the file-selection widget, then a status message.
    """
    def _early_exit(message: str):
        return (*load_repository_training_files(current_label), gr.update(), message)

    if not year:
        return _early_exit("Select a year before downloading an entire year of CSVs.")
    try:
        months = list_remote_months(year)
    except Exception as exc:
        return _early_exit(f"⚠️ Failed to enumerate months for `{year}`: {exc}")
    if not months:
        return _early_exit(f"No month folders were found for `{year}`.")

    fetched = 0
    problems: List[str] = []
    for month in months:
        try:
            days = list_remote_days(year, month)
        except Exception as exc:
            problems.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
            continue
        if not days:
            problems.append(f"⚠️ No day folders in `{year}/{month}`.")
            continue
        for day in days:
            try:
                files = list_remote_files(year, month, day)
            except Exception as exc:
                problems.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
                continue
            if not files:
                problems.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
                continue
            for filename in files:
                try:
                    download_repository_file(year, month, day, filename)
                except Exception as exc:
                    problems.append(f"⚠️ {year}/{month}/{day}/{filename}: {exc}")
                else:
                    fetched += 1

    local = load_repository_training_files(current_label)
    report: List[str] = []
    if fetched:
        report.append(
            f"Downloaded {fetched} CSV file(s) for `{year}` into the "
            f"database cache `{TRAINING_DATA_DIR}`."
        )
    report.extend(problems)
    status = "\n".join(report) if report else "No files were downloaded."
    return (*local, gr.update(value=[]), status)
def clear_downloaded_cache(current_label: str):
    """Delete every downloaded file and reset the file list and the browser.

    The GitHub listing cache is left intact (the browser is refreshed with
    ``force_refresh=False``).

    Returns:
        The outputs of ``load_repository_training_files`` followed by those of
        ``refresh_remote_browser``. The browser's status message is prefixed
        with the outcome of the clear operation.
    """
    try:
        if TRAINING_DATA_DIR.exists():
            shutil.rmtree(TRAINING_DATA_DIR)
            TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
        outcome = (
            f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
        )
    except Exception as exc:
        outcome = f"⚠️ Failed to clear database cache: {exc}"
    local = load_repository_training_files(current_label, force_refresh=True)
    remote = list(refresh_remote_browser(force_refresh=False))
    # ``outcome`` is always non-empty here, so it is always reported.
    tail = remote[-1]
    has_tail = isinstance(tail, str) and bool(tail)
    remote[-1] = f"{outcome}\n{tail}" if has_tail else outcome
    return (*local, *remote)
def normalise_output_directory(directory: Optional[str]) -> Path:
    """Turn a user-supplied output directory into an absolute, normalised path.

    ``None`` or an empty value falls back to ``MODEL_OUTPUT_DIR``. ``~`` is
    expanded, and relative paths are interpreted against the current working
    directory. The result is always passed through ``Path.resolve``, so ``..``
    segments and symlinks are normalised for absolute inputs as well.
    Previously only relative inputs were resolved. That made values
    incomparable with the resolved sibling paths listed by
    ``gather_directory_choices``, producing duplicate entries.
    """
    base = Path(directory or MODEL_OUTPUT_DIR).expanduser()
    if not base.is_absolute():
        base = Path.cwd() / base
    return base.resolve()
def resolve_output_path(
    directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
) -> Path:
    """Decide where to write an output artifact.

    Rules:
    - An absolute ``filename`` is returned as-is (after ``~`` expansion).
    - A relative ``filename`` is placed under the output directory.
    - A missing or blank ``filename`` falls back to ``fallback`` under the
      output directory.

    If ``directory`` is already a ``Path`` it is used verbatim. Otherwise it is
    normalised with ``normalise_output_directory`` (``None``/empty selects the
    default model directory).
    """
    if isinstance(directory, Path):
        base = directory
    else:
        base = normalise_output_directory(directory)
    # Check the raw input for blankness: ``Path("")`` stringifies to ``"."``,
    # which is always truthy. Testing the constructed Path would mean the
    # fallback is never used and the directory itself is returned.
    if filename is None or not str(filename).strip():
        return (base / fallback).resolve()
    candidate = Path(filename).expanduser()
    if candidate.is_absolute():
        return candidate
    return (base / candidate).resolve()
# File suffixes shown in the artifact browser. ``gather_artifact_choices``
# compares them against ``path.suffix.lower()``, so entries must be lowercase
# and include the leading dot. An empty tuple disables the filter (every regular
# file is listed).
ARTIFACT_FILE_EXTENSIONS: Tuple[str, ...] = (
    ".keras",
    ".h5",
    ".joblib",
    ".pkl",
    ".json",
    ".onnx",
    ".zip",
    ".txt",
)
def gather_directory_choices(current: Optional[str]) -> Tuple[List[str], str]:
    """List selectable output directories: ``current`` plus its sibling folders.

    ``current`` (or ``MODEL_OUTPUT_DIR`` when empty) is normalised first. Every
    directory next to it is then added as a resolved path. Errors while scanning
    the parent are ignored.

    Returns:
        ``(sorted_choices, normalised_current)``.
    """
    base = normalise_output_directory(current or str(MODEL_OUTPUT_DIR))
    found = {str(base)}
    try:
        for sibling in base.parent.iterdir():
            if not sibling.is_dir():
                continue
            found.add(str(sibling.resolve()))
    except Exception:
        pass
    return sorted(found), str(base)
def gather_artifact_choices(
    directory: Optional[str], selection: Optional[str] = None
) -> Tuple[List[Tuple[str, str]], Optional[str]]:
    """List artifact files in ``directory`` as ``(label, path)`` dropdown choices.

    Only regular files whose lowercase suffix is in ``ARTIFACT_FILE_EXTENSIONS``
    are listed, sorted case-insensitively by name; an empty extension tuple
    lists every file. Scanning errors yield no choices.

    Returns:
        ``(choices, selected_value)``. ``selected_value`` is ``selection`` if it
        is one of the listed paths, otherwise the first choice, or ``None``
        when there are no choices.
    """
    base = normalise_output_directory(directory)
    choices: List[Tuple[str, str]] = []
    if base.exists():
        def _is_artifact(entry: Path) -> bool:
            if not entry.is_file():
                return False
            if not ARTIFACT_FILE_EXTENSIONS:
                return True
            return entry.suffix.lower() in ARTIFACT_FILE_EXTENSIONS

        try:
            matches = sorted(
                filter(_is_artifact, base.iterdir()),
                key=lambda entry: entry.name.lower(),
            )
            choices = [(entry.name, str(entry)) for entry in matches]
        except Exception:
            choices = []
    paths = [value for _, value in choices]
    if selection and selection in paths:
        return choices, selection
    return choices, (paths[0] if paths else None)
def download_button_state(path: Optional[Union[str, Path]]):
    """Point a download button at ``path`` if it exists on disk; hide it otherwise."""
    target = Path(path) if path else None
    if target is not None and target.exists():
        return gr.update(value=str(target), visible=True)
    return gr.update(value=None, visible=False)
def clear_training_files():
    """Delete uploaded training files and reset the training widgets.

    Regular files directly inside ``TRAINING_UPLOAD_DIR`` are removed best
    effort; sub-directories and deletion errors are ignored.

    Returns:
        ``(paths, summary, label_dropdown_update, upload_widget_update)``.
    """
    fallback_label = LABEL_COLUMN or "Fault"
    for entry in TRAINING_UPLOAD_DIR.glob("*"):
        try:
            if not entry.is_file():
                continue
            entry.unlink(missing_ok=True)
        except Exception:
            pass
    label_update = gr.update(choices=[fallback_label], value=fallback_label)
    return [], "No training files selected.", label_update, gr.update(value=None)
# Markdown shown on the project-overview page. The data-format paragraph
# describes the separators the CSV loader accepts, and the training paragraph
# reflects that repository CSVs are downloaded on demand, not automatically.
PROJECT_OVERVIEW_MD = """
## Project Overview
This project focuses on classifying faults in electrical transmission lines and
grid-connected photovoltaic (PV) systems by combining ensemble learning
techniques with deep neural architectures.
## Datasets
### Transmission Line Fault Dataset
- 134,406 samples collected from Phasor Measurement Units (PMUs)
- 14 monitored channels covering currents, voltages, magnitudes, frequency, and phase angles
- Labels span symmetrical and asymmetrical faults: NF, L-G, LL, LL-G, LLL, and LLL-G
- Time span: 0 to 5.7 seconds with high-frequency sampling
### Grid-Connected PV System Fault Dataset
- 2,163,480 samples from 16 experimental scenarios
- 14 features including PV array measurements (Ipv, Vpv, Vdc), three-phase currents/voltages, aggregate magnitudes (Iabc, Vabc), and frequency indicators (If, Vf)
- Captures array, inverter, grid anomaly, feedback sensor, and MPPT controller faults at 9.9989 μs sampling intervals
## Data Format Quick Reference
Each measurement file may be comma-, tab-, or semicolon-separated and typically
exposes the following ordered columns:
1. `Timestamp`
2. `[325] UPMU_SUB22:FREQ` – system frequency (Hz)
3. `[326] UPMU_SUB22:DFDT` – frequency rate-of-change
4. `[327] UPMU_SUB22:FLAG` – PMU status flag
5. `[328] UPMU_SUB22-L1:MAG` – phase A voltage magnitude
6. `[329] UPMU_SUB22-L1:ANG` – phase A voltage angle
7. `[330] UPMU_SUB22-L2:MAG` – phase B voltage magnitude
8. `[331] UPMU_SUB22-L2:ANG` – phase B voltage angle
9. `[332] UPMU_SUB22-L3:MAG` – phase C voltage magnitude
10. `[333] UPMU_SUB22-L3:ANG` – phase C voltage angle
11. `[334] UPMU_SUB22-C1:MAG` – phase A current magnitude
12. `[335] UPMU_SUB22-C1:ANG` – phase A current angle
13. `[336] UPMU_SUB22-C2:MAG` – phase B current magnitude
14. `[337] UPMU_SUB22-C2:ANG` – phase B current angle
15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
The training tab lets you browse the `VincentCroft/ThesisModelData` repository
and download specific days, months, or whole years of CSV exports on demand; the
downloaded files are then concatenated before building sliding windows.
## Models Developed
1. **Support Vector Machine (SVM)** – provides the classical machine learning baseline with balanced accuracy across both datasets (85% PMU / 83% PV).
2. **CNN-LSTM** – couples convolutional feature extraction with temporal memory, achieving 92% PMU / 89% PV accuracy.
3. **Temporal Convolutional Network (TCN)** – leverages dilated convolutions for long-range context and delivers the best trade-off between accuracy and training time (94% PMU / 91% PV).
## Results Summary
- **Transmission Line Fault Classification**: SVM 85%, CNN-LSTM 92%, TCN 94%
- **PV System Fault Classification**: SVM 83%, CNN-LSTM 89%, TCN 91%
Use the **Inference** tab to score new PMU/PV windows and the **Training** tab to
fine-tune or retrain any of the supported models directly within Hugging Face
Spaces. The logs panel will surface TensorBoard archives whenever deep-learning
models are trained.
"""
def load_measurement_csv(path: str) -> pd.DataFrame:
    """Read a PMU/PV measurement file with flexible separators and column mapping.

    The separator is sniffed first. If that read fails, tab, comma and
    semicolon are tried in turn. Column names are stripped of surrounding
    whitespace.

    When no label-like column ("fault", "label", "class" or "target",
    case-insensitive) exists, a dummy ``Fault`` column with the value
    ``"Normal"`` is appended.

    Renaming happens only when some ``REQUIRED_PMU_COLUMNS`` are missing.
    The file's feature columns are then mapped onto the expected names by
    position. ``Timestamp`` and label-like columns are never renamed.

    Args:
        path: Filesystem path of the measurement file.

    Returns:
        The loaded DataFrame, containing every column of ``REQUIRED_PMU_COLUMNS``.

    Raises:
        Exception: The error from the sniffed ``pandas.read_csv`` call when no
            explicit separator could be parsed either.
        ValueError: If required PMU feature columns are still missing after
            the mapping attempts.
    """
    # Lower-cased column names that denote a class label rather than a feature.
    label_names = {"fault", "label", "class", "target"}

    try:
        df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        df = None
        for separator in ("\t", ",", ";"):
            try:
                df = pd.read_csv(
                    path, sep=separator, engine="python", encoding="utf-8-sig"
                )
                break
            except Exception:
                df = None
        if df is None:
            # Bare ``raise`` re-raises the error from the sniffed read above.
            raise

    # Clean column names
    df.columns = [str(col).strip() for col in df.columns]
    print(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
    print(f"Columns: {list(df.columns)}")
    print(f"Data shape: {df.shape}")

    # Check if we have enough data for training
    if len(df) < 100:
        print(
            f"Warning: Only {len(df)} rows of data. Recommend at least 1000 rows for effective training."
        )

    # Check for label column
    has_label = any(col.lower() in label_names for col in df.columns)
    if not has_label:
        print(
            "Warning: No label column found. Adding dummy 'Fault' column with value 'Normal' for all samples."
        )
        df["Fault"] = "Normal"  # Add dummy label for training

    expected_cols = list(REQUIRED_PMU_COLUMNS)

    def _is_feature_candidate(col: Any) -> bool:
        # Timestamp and label columns must never be renamed onto PMU features.
        return col != "Timestamp" and str(col).lower() not in label_names

    missing = [col for col in expected_cols if col not in df.columns]

    # Positional mapping (columns after Timestamp, in file order). Only applied
    # when the expected names are not already present. Renaming correctly named
    # columns by position would shuffle or duplicate them.
    if missing and "Timestamp" in df.columns:
        candidates = [col for col in df.columns if _is_feature_candidate(col)]
        if len(candidates) >= len(expected_cols):
            df = df.rename(columns=dict(zip(candidates, expected_cols)))
        missing = [col for col in expected_cols if col not in df.columns]

    if missing:
        # Fallback: map the first N numeric feature columns. One rename call
        # does the whole mapping at once, so intermediate duplicate names never
        # arise.
        numeric_candidates = [
            col
            for col in df.select_dtypes(include=[np.number]).columns
            if _is_feature_candidate(col)
        ]
        if len(numeric_candidates) >= len(expected_cols):
            df = df.rename(columns=dict(zip(numeric_candidates, expected_cols)))
        missing = [col for col in expected_cols if col not in df.columns]

    if missing:
        missing_str = ", ".join(missing)
        available_str = ", ".join(df.columns.tolist())
        raise ValueError(
            f"Missing required PMU feature columns: {missing_str}. "
            f"Available columns: {available_str}. "
            "Please ensure your CSV has the correct format with Timestamp followed by PMU measurements."
        )
    return df
def apply_metadata(metadata: Dict[str, Any]) -> None:
    """Publish the model configuration stored in *metadata* as module globals.

    Any key absent from *metadata* falls back to the module defaults
    (``DEFAULT_FEATURE_COLUMNS``, ``DEFAULT_SEQUENCE_LENGTH``, ``DEFAULT_STRIDE``)
    or to the literal defaults below. ``MODEL_FORMAT`` defaults to
    ``"joblib"`` for SVM models and ``"keras"`` otherwise.

    Args:
        metadata: Mapping read from the metadata JSON.
    """
    global FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, DEFAULT_WINDOW_STRIDE, MODEL_TYPE, MODEL_FORMAT
    lookup = metadata.get

    FEATURE_COLUMNS = list(map(str, lookup("feature_columns", DEFAULT_FEATURE_COLUMNS)))
    LABEL_CLASSES = list(map(str, lookup("label_classes", [])))
    LABEL_COLUMN = str(lookup("label_column", "Fault"))
    SEQUENCE_LENGTH = int(lookup("sequence_length", DEFAULT_SEQUENCE_LENGTH))
    DEFAULT_WINDOW_STRIDE = int(lookup("stride", DEFAULT_STRIDE))
    MODEL_TYPE = str(lookup("model_type", "cnn_lstm")).lower()

    # The serialisation format follows the architecture unless stated explicitly.
    fallback_format = "joblib" if MODEL_TYPE == "svm" else "keras"
    MODEL_FORMAT = str(lookup("model_format", fallback_format)).lower()
# Populate FEATURE_COLUMNS, LABEL_CLASSES, LABEL_COLUMN, SEQUENCE_LENGTH, etc.
# from the module-level METADATA dict (defined earlier in this file) at import time.
apply_metadata(METADATA)
def sync_label_classes_from_model(model: Optional[object]) -> None:
    """Align the global ``LABEL_CLASSES`` with what the loaded *model* reports.

    The model is used in this order:

    * ``None``: leave ``LABEL_CLASSES`` untouched.
    * It exposes ``classes_`` (scikit-learn style): use those, stringified.
    * ``LABEL_CLASSES`` is still empty and the model exposes
      ``output_shape``: use placeholder names ``"0"`` to ``"k-1"``, where
      ``k`` is the last output dimension.
    """
    global LABEL_CLASSES
    if model is not None:
        if hasattr(model, "classes_"):
            LABEL_CLASSES = [str(label) for label in model.classes_]
        elif not LABEL_CLASSES and hasattr(model, "output_shape"):
            width = int(model.output_shape[-1])
            LABEL_CLASSES = list(map(str, range(width)))
# Load model and scaler with error handling.
# Each step is wrapped independently, so a missing or corrupt artifact leaves
# the corresponding global as None rather than aborting module import. The UI
# (build_interface) and ensure_ready() check MODEL/SCALER for None and report
# the missing artifacts to the user.
print("Loading model and scaler...")
try:
    MODEL = try_load_model(MODEL_PATH, MODEL_TYPE, MODEL_FORMAT)
    print(f"Model loaded: {MODEL is not None}")
except Exception as e:
    print(f"Model loading failed: {e}")
    MODEL = None
try:
    SCALER = try_load_scaler(SCALER_PATH)
    print(f"Scaler loaded: {SCALER is not None}")
except Exception as e:
    print(f"Scaler loading failed: {e}")
    SCALER = None
try:
    # Must run after MODEL is assigned; prefers the model's own class list.
    sync_label_classes_from_model(MODEL)
    print("Label classes synchronized")
except Exception as e:
    print(f"Label sync failed: {e}")
print("Application initialization completed.")
print(
    f"Ready to start Gradio interface. Model available: {MODEL is not None}, Scaler available: {SCALER is not None}"
)
def refresh_artifacts(model_path: Path, scaler_path: Path, metadata_path: Path) -> None:
    """Point the app at a new set of artifacts and reload them.

    The steps run in this order:

    1. Record the three paths in the module globals.
    2. Reload the metadata and apply it, which may change ``MODEL_TYPE`` and
       ``MODEL_FORMAT``.
    3. Load the model and scaler using the updated type/format.
    4. Re-sync the label classes from the freshly loaded model.
    """
    global MODEL_PATH, SCALER_PATH, METADATA_PATH, MODEL, SCALER, METADATA
    MODEL_PATH, SCALER_PATH, METADATA_PATH = model_path, scaler_path, metadata_path

    METADATA = load_metadata(metadata_path)
    apply_metadata(METADATA)

    MODEL = try_load_model(model_path, MODEL_TYPE, MODEL_FORMAT)
    SCALER = try_load_scaler(scaler_path)
    sync_label_classes_from_model(MODEL)
# --------------------------------------------------------------------------------------
# Pre-processing helpers
# --------------------------------------------------------------------------------------
def ensure_ready():
    """Guard inference: raise ``RuntimeError`` unless both MODEL and SCALER are loaded."""
    artifacts_loaded = MODEL is not None and SCALER is not None
    if artifacts_loaded:
        return
    raise RuntimeError(
        "The model and feature scaler are not available. Upload the trained model "
        "(for example `pmu_cnn_lstm_model.keras`, `pmu_tcn_model.keras`, or `pmu_svm_model.joblib`), "
        "the feature scaler (`pmu_feature_scaler.pkl`), and the metadata JSON (`pmu_metadata.json`) to the Space root "
        "or configure the Hugging Face Hub environment variables so the artifacts can be downloaded "
        "automatically."
    )
def parse_text_features(text: str) -> np.ndarray:
    """Parse a pasted feature window into a flat ``float32`` array.

    Values may be separated by commas, semicolons or any whitespace (spaces,
    tabs, newlines). Empty fields are ignored, so trailing separators and
    "1,\\n2" are accepted. Scientific notation such as ``1.2E-38`` is
    supported.

    Args:
        text: The raw user input.

    Returns:
        A 1-D ``float32`` array of the parsed values, in input order.

    Raises:
        ValueError: If no values are present, or a token is not a valid number.
            ``np.fromstring`` is deliberately not used: it silently stops at the
            first invalid token and returns a truncated array.
    """
    tokens = [token for token in re.split(r"[,;\s]+", (text or "").strip()) if token]
    if not tokens:
        raise ValueError(
            "No feature values were parsed. Please enter comma-separated numbers."
        )
    values: List[float] = []
    for token in tokens:
        try:
            values.append(float(token))
        except ValueError as err:
            raise ValueError(
                f"Could not parse feature value {token!r}. "
                "Please enter comma-separated numbers."
            ) from err
    return np.asarray(values, dtype=np.float32)
def apply_scaler(sequences: np.ndarray) -> np.ndarray:
    """Scale *sequences* feature-wise with the global ``SCALER``.

    All leading axes are flattened so the scaler sees a 2-D
    ``(rows, n_features)`` matrix. The result is reshaped back to the input
    shape. With no scaler loaded, the input array is returned as-is.
    """
    if SCALER is not None:
        original_shape = sequences.shape
        n_features = original_shape[-1]
        as_rows = sequences.reshape(-1, n_features)
        return SCALER.transform(as_rows).reshape(original_shape)
    return sequences
def make_sliding_windows(
    data: np.ndarray, sequence_length: int, stride: int
) -> np.ndarray:
    """Cut *data* into overlapping windows along its first axis.

    Args:
        data: Array of shape ``(rows, ...)``.
        sequence_length: Number of consecutive rows per window. Must be >= 1.
        stride: Step between window start rows. Must be >= 1.

    Returns:
        An array of shape ``(n_windows, sequence_length, ...)``. Windows start
        at rows ``0, stride, 2*stride, ...``. Trailing rows that cannot fill
        a whole window are dropped.

    Raises:
        ValueError: If *sequence_length* or *stride* is not positive, or if
            *data* has fewer rows than *sequence_length*.
    """
    # Validate up front. A zero length would silently yield empty windows, and
    # a bad stride would only surface as a cryptic range()/np.stack error.
    if sequence_length < 1:
        raise ValueError(
            f"Sequence length must be a positive integer, got {sequence_length}."
        )
    if stride < 1:
        raise ValueError(f"Stride must be a positive integer, got {stride}.")
    n_rows = data.shape[0]
    if n_rows < sequence_length:
        raise ValueError(
            f"The dataset contains {n_rows} rows which is less than the requested sequence "
            f"length {sequence_length}. Provide more samples or reduce the sequence length."
        )
    starts = range(0, n_rows - sequence_length + 1, stride)
    return np.stack([data[start : start + sequence_length] for start in starts])
def dataframe_to_sequences(
    df: pd.DataFrame,
    *,
    sequence_length: int,
    stride: int,
    feature_columns: Sequence[str],
    drop_label: bool = True,
) -> np.ndarray:
    """Convert a measurement DataFrame into model-ready windows.

    The global ``LABEL_COLUMN`` is dropped first (when *drop_label* is set).
    Rows are then sorted by ``Timestamp`` when that column exists.

    * If every column in *feature_columns* is present, those columns (in that
      order) are cut into sliding windows by ``make_sliding_windows``.
    * Otherwise each row is treated as one pre-flattened window. The numeric
      columns, excluding ``Timestamp``, must number exactly
      ``len(feature_columns) * sequence_length``. They are reshaped to
      ``(rows, sequence_length, n_features)``.

    Args:
        df: Source data. It is not modified.
        sequence_length: Timesteps per window.
        stride: Window stride, used only in the sliding-window path.
        feature_columns: Expected feature column names.
        drop_label: Whether to drop ``LABEL_COLUMN`` before processing.

    Returns:
        A ``float32`` array of shape ``(n_windows, sequence_length, n_features)``.

    Raises:
        ValueError: If the columns match neither layout.
    """
    work_df = df.copy()
    if drop_label and LABEL_COLUMN in work_df.columns:
        work_df = work_df.drop(columns=[LABEL_COLUMN])
    if "Timestamp" in work_df.columns:
        work_df = work_df.sort_values("Timestamp")

    n_features = len(feature_columns)
    available_cols = [c for c in feature_columns if c in work_df.columns]
    if available_cols and len(available_cols) == n_features:
        array = work_df[available_cols].astype(np.float32).to_numpy()
        return make_sliding_windows(array, sequence_length, stride)

    # Pre-shaped fallback. Timestamp is the sort key, not window data, so a
    # numeric Timestamp must not count towards the expected column total.
    numeric_df = work_df.select_dtypes(include=[np.number]).drop(
        columns=["Timestamp"], errors="ignore"
    )
    array = numeric_df.astype(np.float32).to_numpy()
    if n_features and array.shape[1] == n_features * sequence_length:
        return array.reshape(array.shape[0], sequence_length, n_features)
    raise ValueError(
        "CSV columns do not match the expected feature layout. Include the full PMU feature set "
        "or provide pre-shaped sliding window data."
    )
def label_name(index: int) -> str:
    """Map a class index to its label, or ``class_<index>`` if it is out of range."""
    known = 0 <= index < len(LABEL_CLASSES)
    return str(LABEL_CLASSES[index]) if known else f"class_{index}"
def format_predictions(probabilities: np.ndarray) -> pd.DataFrame:
    """Summarise per-window class probabilities as a table.

    Args:
        probabilities: 2-D array with one row per window and one column per class.

    Returns:
        A DataFrame with these columns:

        * ``window``: the row index.
        * ``predicted_label``: the highest-probability label.
        * ``confidence``: that label's probability, rounded to 4 places.
        * ``top3``: up to three ``label (pct%)`` entries joined by `` | ``.
    """
    ranked = np.argsort(probabilities, axis=1)[:, ::-1]
    records: List[Dict[str, object]] = []
    for window, (row, order) in enumerate(zip(probabilities, ranked)):
        best = int(order[0])
        leaders = " | ".join(
            f"{label_name(i)} ({row[i]*100:.2f}%)" for i in order[:3]
        )
        records.append(
            {
                "window": window,
                "predicted_label": label_name(best),
                "confidence": round(float(row[best]), 4),
                "top3": leaders,
            }
        )
    return pd.DataFrame(records)
def probabilities_to_json(probabilities: np.ndarray) -> List[Dict[str, object]]:
    """Convert a ``(windows, classes)`` probability array into JSON-friendly dicts.

    Each entry looks like
    ``{"window": i, "probabilities": {label: p, ...}}``, with plain Python
    ``int``/``float`` values.
    """
    return [
        {
            "window": int(window),
            "probabilities": {
                label_name(k): float(row[k]) for k in range(row.shape[0])
            },
        }
        for window, row in enumerate(probabilities)
    ]
def predict_sequences(
    sequences: np.ndarray,
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Score windows with the loaded model.

    The windows are cast to ``float32`` and scaled. For ``MODEL_TYPE == "svm"``
    each window is flattened and scored with ``predict_proba``. Any other
    model type gets the 3-D array via ``MODEL.predict(..., verbose=0)``.

    Returns:
        A ``(status message, prediction table, per-window probability dicts)``
        tuple.

    Raises:
        RuntimeError: If artifacts are missing, or the SVM lacks ``predict_proba``.
    """
    ensure_ready()
    scaled = apply_scaler(sequences.astype(np.float32))

    if MODEL_TYPE != "svm":
        probs = MODEL.predict(scaled, verbose=0)
    elif hasattr(MODEL, "predict_proba"):
        probs = MODEL.predict_proba(scaled.reshape(scaled.shape[0], -1))
    else:
        raise RuntimeError(
            "Loaded SVM model does not expose predict_proba. Retrain with probability=True."
        )

    table = format_predictions(probs)
    json_probs = probabilities_to_json(probs)
    label = MODEL_TYPE.replace("_", "-").upper()
    message = f"Generated {len(scaled)} windows. {label} model output dimension: {probs.shape[1]}."
    return message, table, json_probs
def predict_from_text(
    text: str, sequence_length: int
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Predict a single window pasted as text.

    The value count must be exactly ``sequence_length * len(FEATURE_COLUMNS)``.
    Values are read timestep-major: all features of timestep 0, then
    timestep 1, and so on.

    Raises:
        ValueError: If the count is not a multiple of the feature dimension, or
            implies a different number of timesteps than *sequence_length*.
    """
    values = parse_text_features(text)
    width = len(FEATURE_COLUMNS)
    timesteps, remainder = divmod(values.size, width)
    if remainder:
        raise ValueError(
            f"The number of values ({values.size}) is not a multiple of the feature dimension "
            f"({width}). Provide values in groups of {width}."
        )
    if timesteps != sequence_length:
        raise ValueError(
            f"Detected {timesteps} timesteps which does not match the configured sequence length "
            f"({sequence_length})."
        )
    window = values.reshape(1, sequence_length, width)
    status, table, probs = predict_sequences(window)
    return f"Single window prediction complete. {status}", table, probs
def predict_from_csv(
    file_obj, sequence_length: int, stride: int
) -> Tuple[str, pd.DataFrame, List[Dict[str, object]]]:
    """Predict every window of an uploaded measurement CSV.

    Args:
        file_obj: The upload. This is either an object exposing the file path
            as ``.name`` (Gradio file wrapper), or a plain ``str``/``os.PathLike``
            path.
        sequence_length: Timesteps per window.
        stride: Step between consecutive windows.

    Returns:
        A ``(status message, prediction table, per-window probability dicts)``
        tuple.
    """
    # str/PathLike must be checked first: ``Path.name`` is only the basename,
    # and Gradio's str-subclass uploads are already full paths.
    if isinstance(file_obj, (str, os.PathLike)):
        path = os.fspath(file_obj)
    else:
        path = file_obj.name
    df = load_measurement_csv(path)
    sequences = dataframe_to_sequences(
        df,
        sequence_length=sequence_length,
        stride=stride,
        feature_columns=FEATURE_COLUMNS,
    )
    status, table, probs = predict_sequences(sequences)
    # predict_sequences already reports the window count; don't repeat it.
    status = f"CSV processed successfully. {status}"
    return status, table, probs
# --------------------------------------------------------------------------------------
# Training helpers
# --------------------------------------------------------------------------------------
def classification_report_to_dataframe(report: Dict[str, Any]) -> pd.DataFrame:
    """Flatten a classification-report dict into one row per entry.

    Entries whose value is a dict become a row of metrics. ``support`` is cast
    to ``int``; every other metric is rounded to 4 decimal places. Scalar
    entries (e.g. ``"accuracy"``) become ``{"label": ..., "accuracy": value}``.
    """

    def _to_row(label: Any, metrics: Any) -> Dict[str, Any]:
        if not isinstance(metrics, dict):
            return {"label": label, "accuracy": round(float(metrics), 4)}
        row: Dict[str, Any] = {"label": label}
        row.update(
            {
                key: int(value) if key == "support" else round(float(value), 4)
                for key, value in metrics.items()
            }
        )
        return row

    return pd.DataFrame([_to_row(label, metrics) for label, metrics in report.items()])
def confusion_matrix_to_dataframe(
    confusion: Sequence[Sequence[float]], labels: Sequence[str]
) -> pd.DataFrame:
    """Wrap a confusion matrix in a labelled DataFrame.

    Args:
        confusion: Square matrix (nested lists or a NumPy array). Rows are
            true labels and columns are predicted labels.
        labels: Class names used for both axes.

    Returns:
        A DataFrame with index name ``"True Label"`` and column name
        ``"Predicted Label"``. An empty DataFrame is returned when *confusion*
        is ``None`` or has no rows.
    """
    # len() instead of truthiness: ``not ndarray`` raises for multi-element arrays.
    if confusion is None or len(confusion) == 0:
        return pd.DataFrame()
    df = pd.DataFrame(confusion, index=list(labels), columns=list(labels))
    df.index.name = "True Label"
    df.columns.name = "Predicted Label"
    return df
# --------------------------------------------------------------------------------------
# Gradio interface
# --------------------------------------------------------------------------------------
def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks application.

    The app has three tabs: overview, inference and training. The returned
    ``demo`` is not launched here; ``main`` configures the queue and starts
    the server. Event handlers defined below are closures over the components
    created in this function and over the module-level model state (``MODEL``,
    ``SCALER``, ``FEATURE_COLUMNS``, ...).

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    theme = gr.themes.Soft(
        primary_hue="sky", secondary_hue="blue", neutral_hue="gray"
    ).set(
        body_background_fill="#1f1f1f",
        body_text_color="#f5f5f5",
        block_background_fill="#262626",
        block_border_color="#333333",
        button_primary_background_fill="#5ac8fa",
        button_primary_background_fill_hover="#48b5eb",
        button_primary_border_color="#38bdf8",
        button_primary_text_color="#0f172a",
        button_secondary_background_fill="#3f3f46",
        button_secondary_text_color="#f5f5f5",
    )

    with gr.Blocks(
        title="Fault Classification - PMU Data", theme=theme, css=APP_CSS
    ) as demo:
        gr.Markdown("# Fault Classification for PMU & PV Data")
        gr.Markdown(
            "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
        )
        if MODEL is None or SCALER is None:
            gr.Markdown(
                "⚠️ **Artifacts Missing** — Upload `pmu_cnn_lstm_model.keras`, "
                "`pmu_feature_scaler.pkl`, and `pmu_metadata.json` to enable inference, "
                "or configure the Hugging Face Hub environment variables so they can be downloaded."
            )
        else:
            class_count = len(LABEL_CLASSES) if LABEL_CLASSES else "unknown"
            gr.Markdown(
                f"Loaded a **{MODEL_TYPE.upper()}** model ({MODEL_FORMAT.upper()}) with "
                f"{len(FEATURE_COLUMNS)} features, sequence length **{SEQUENCE_LENGTH}**, and "
                f"{class_count} target classes. Use the tabs below to run inference or fine-tune "
                "the model with your own CSV files."
            )

        with gr.Accordion("Feature Reference", open=False):
            gr.Markdown(
                f"Each time window expects **{len(FEATURE_COLUMNS)} features** ordered as follows:\n"
                + "\n".join(f"- {name}" for name in FEATURE_COLUMNS)
            )
            gr.Markdown(
                f"Default training parameters: **sequence length = {SEQUENCE_LENGTH}**, "
                f"**stride = {DEFAULT_WINDOW_STRIDE}**. Adjust them in the tabs as needed."
            )

        with gr.Tabs():
            with gr.Tab("Overview"):
                gr.Markdown(PROJECT_OVERVIEW_MD)

            with gr.Tab("Inference"):
                gr.Markdown("## Run Inference")
                with gr.Row():
                    file_in = gr.File(label="Upload PMU CSV", file_types=[".csv"])
                    text_in = gr.Textbox(
                        lines=4,
                        label="Or paste a single window (comma separated)",
                        placeholder="49.97772,1.215825E-38,...",
                    )
                with gr.Row():
                    sequence_length_input = gr.Slider(
                        minimum=1,
                        maximum=max(1, SEQUENCE_LENGTH * 2),
                        step=1,
                        value=SEQUENCE_LENGTH,
                        label="Sequence length (timesteps)",
                    )
                    stride_input = gr.Slider(
                        minimum=1,
                        maximum=max(1, SEQUENCE_LENGTH),
                        step=1,
                        value=max(1, DEFAULT_WINDOW_STRIDE),
                        label="CSV window stride",
                    )
                predict_btn = gr.Button("🚀 Run Inference", variant="primary")
                status_out = gr.Textbox(label="Status", interactive=False)
                table_out = gr.Dataframe(
                    headers=["window", "predicted_label", "confidence", "top3"],
                    label="Predictions",
                    interactive=False,
                )
                probs_out = gr.JSON(label="Per-window probabilities")

                def _run_prediction(file_obj, text, sequence_length, stride):
                    """Score the uploaded CSV, or else the pasted window.

                    Errors are reported in the status box, not raised.
                    """
                    sequence_length = int(sequence_length)
                    stride = int(stride)
                    try:
                        if file_obj is not None:
                            return predict_from_csv(file_obj, sequence_length, stride)
                        if text and text.strip():
                            return predict_from_text(text, sequence_length)
                        return (
                            "Please upload a CSV file or provide feature values.",
                            pd.DataFrame(),
                            [],
                        )
                    except Exception as exc:
                        return f"Prediction failed: {exc}", pd.DataFrame(), []

                predict_btn.click(
                    _run_prediction,
                    inputs=[file_in, text_in, sequence_length_input, stride_input],
                    outputs=[status_out, table_out, probs_out],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

            with gr.Tab("Training"):
                gr.Markdown("## Train or Fine-tune the Model")
                gr.Markdown(
                    "Training data is automatically downloaded from the database. "
                    "Refresh the cache if new files are added upstream."
                )
                training_files_state = gr.State([])
                with gr.Row():
                    with gr.Column(scale=3):
                        training_files_summary = gr.Textbox(
                            label="Database training CSVs",
                            value="Training dataset not loaded yet.",
                            lines=4,
                            interactive=False,
                            elem_id="training-files-summary",
                        )
                    with gr.Column(scale=2, min_width=240):
                        dataset_info = gr.Markdown(
                            "No local database CSVs downloaded yet.",
                        )
                        dataset_refresh = gr.Button(
                            "🔄 Reload dataset from database",
                            variant="secondary",
                        )
                        clear_cache_button = gr.Button(
                            "🧹 Clear downloaded cache",
                            variant="secondary",
                        )

                with gr.Accordion("📂 DataBaseBrowser", open=False):
                    gr.Markdown(
                        "Browse the upstream database by date and download only the CSVs you need."
                    )
                    with gr.Row(elem_id="date-browser-row"):
                        with gr.Column(scale=1, elem_classes=["date-browser-column"]):
                            year_selector = gr.Dropdown(label="Year", choices=[])
                            year_download_button = gr.Button(
                                "⬇️ Download year CSVs", variant="secondary"
                            )
                        with gr.Column(scale=1, elem_classes=["date-browser-column"]):
                            month_selector = gr.Dropdown(label="Month", choices=[])
                            month_download_button = gr.Button(
                                "⬇️ Download month CSVs", variant="secondary"
                            )
                        with gr.Column(scale=1, elem_classes=["date-browser-column"]):
                            day_selector = gr.Dropdown(label="Day", choices=[])
                            day_download_button = gr.Button(
                                "⬇️ Download day CSVs", variant="secondary"
                            )
                    with gr.Column(elem_id="available-files-section"):
                        available_files = gr.CheckboxGroup(
                            label="Available CSV files",
                            choices=[],
                            value=[],
                            elem_id="available-files-grid",
                        )
                        download_button = gr.Button(
                            "⬇️ Download selected CSVs",
                            variant="secondary",
                            elem_id="download-selected-button",
                        )
                    repo_status = gr.Markdown(
                        "Click 'Reload dataset from database' to fetch the directory tree."
                    )

                with gr.Row():
                    label_input = gr.Dropdown(
                        value=LABEL_COLUMN,
                        choices=[LABEL_COLUMN],
                        allow_custom_value=True,
                        label="Label column name",
                    )
                    model_selector = gr.Radio(
                        choices=["CNN-LSTM", "TCN", "SVM"],
                        value=(
                            "TCN"
                            if MODEL_TYPE == "tcn"
                            else ("SVM" if MODEL_TYPE == "svm" else "CNN-LSTM")
                        ),
                        label="Model architecture",
                    )
                    sequence_length_train = gr.Slider(
                        minimum=4,
                        maximum=max(32, SEQUENCE_LENGTH * 2),
                        step=1,
                        value=SEQUENCE_LENGTH,
                        label="Sequence length",
                    )
                    stride_train = gr.Slider(
                        minimum=1,
                        maximum=max(32, SEQUENCE_LENGTH * 2),
                        step=1,
                        value=max(1, DEFAULT_WINDOW_STRIDE),
                        label="Stride",
                    )
                model_default = MODEL_FILENAME_BY_TYPE.get(
                    MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
                )
                with gr.Row():
                    validation_train = gr.Slider(
                        minimum=0.05,
                        maximum=0.4,
                        step=0.05,
                        value=0.2,
                        label="Validation split",
                    )
                    batch_train = gr.Slider(
                        minimum=32,
                        maximum=512,
                        step=32,
                        value=128,
                        label="Batch size",
                    )
                    epochs_train = gr.Slider(
                        minimum=5,
                        maximum=100,
                        step=5,
                        value=50,
                        label="Epochs",
                    )
                directory_choices, directory_default = gather_directory_choices(
                    str(MODEL_OUTPUT_DIR)
                )
                artifact_choices, default_artifact = gather_artifact_choices(
                    directory_default
                )
                with gr.Row():
                    output_directory = gr.Dropdown(
                        value=directory_default,
                        label="Output directory",
                        choices=directory_choices,
                        allow_custom_value=True,
                    )
                    model_name = gr.Textbox(
                        value=model_default,
                        label="Model output filename",
                    )
                    scaler_name = gr.Textbox(
                        value=Path(LOCAL_SCALER_FILE).name,
                        label="Scaler output filename",
                    )
                    metadata_name = gr.Textbox(
                        value=Path(LOCAL_METADATA_FILE).name,
                        label="Metadata output filename",
                    )
                with gr.Row():
                    artifact_browser = gr.Dropdown(
                        label="Saved artifacts in directory",
                        choices=artifact_choices,
                        value=default_artifact,
                    )
                    artifact_download_button = gr.DownloadButton(
                        "⬇️ Download selected artifact",
                        value=default_artifact,
                        visible=bool(default_artifact),
                        variant="secondary",
                    )

                def on_output_directory_change(selected_dir, current_selection):
                    """Refresh the directory and artifact dropdowns after a directory edit."""
                    choices, normalised = gather_directory_choices(selected_dir)
                    artifact_options, selected = gather_artifact_choices(
                        normalised, current_selection
                    )
                    return (
                        gr.update(choices=choices, value=normalised),
                        gr.update(choices=artifact_options, value=selected),
                        download_button_state(selected),
                    )

                def on_artifact_change(selected_path):
                    """Point the download button at the newly selected artifact."""
                    return download_button_state(selected_path)

                output_directory.change(
                    on_output_directory_change,
                    inputs=[output_directory, artifact_browser],
                    outputs=[
                        output_directory,
                        artifact_browser,
                        artifact_download_button,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )
                artifact_browser.change(
                    on_artifact_change,
                    inputs=[artifact_browser],
                    outputs=[artifact_download_button],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                with gr.Row(elem_id="artifact-download-row"):
                    model_download_button = gr.DownloadButton(
                        "⬇️ Download model file",
                        value=None,
                        visible=False,
                        elem_classes=["artifact-download-button"],
                    )
                    scaler_download_button = gr.DownloadButton(
                        "⬇️ Download scaler file",
                        value=None,
                        visible=False,
                        elem_classes=["artifact-download-button"],
                    )
                    metadata_download_button = gr.DownloadButton(
                        "⬇️ Download metadata file",
                        value=None,
                        visible=False,
                        elem_classes=["artifact-download-button"],
                    )
                    tensorboard_download_button = gr.DownloadButton(
                        "⬇️ Download TensorBoard logs",
                        value=None,
                        visible=False,
                        elem_classes=["artifact-download-button"],
                    )
                model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
                scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
                metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
                tensorboard_download_button.file_name = "tensorboard_logs.zip"

                tensorboard_toggle = gr.Checkbox(
                    value=True,
                    label="Enable TensorBoard logging (creates downloadable archive)",
                )

                def _suggest_model_filename(choice: str, current_value: str):
                    """Swap in the default filename for the chosen architecture.

                    The name is only replaced if the current value is empty or
                    is one of the known per-architecture defaults, so custom
                    user names are kept.
                    """
                    choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
                    suggested = MODEL_FILENAME_BY_TYPE.get(
                        choice_key, Path(LOCAL_MODEL_FILE).name
                    )
                    known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
                    current_name = Path(current_value).name if current_value else ""
                    if current_name and current_name not in known_defaults:
                        return gr.update()
                    return gr.update(value=suggested)

                model_selector.change(
                    _suggest_model_filename,
                    inputs=[model_selector, model_name],
                    outputs=model_name,
                )

                with gr.Row():
                    train_button = gr.Button("🛠️ Start Training", variant="primary")
                    progress_button = gr.Button(
                        "📊 Check Progress", variant="secondary"
                    )

                # Training status display
                training_status = gr.Textbox(label="Training Status", interactive=False)
                report_output = gr.Dataframe(
                    label="Classification report", interactive=False
                )
                history_output = gr.JSON(label="Training history")
                confusion_output = gr.Dataframe(
                    label="Confusion matrix", interactive=False
                )

                # Message area at the bottom for progress updates
                with gr.Accordion("📋 Progress Messages", open=True):
                    progress_messages = gr.Textbox(
                        label="Training Messages",
                        lines=8,
                        max_lines=20,
                        interactive=False,
                        autoscroll=True,
                        placeholder="Click 'Check Progress' to see training updates...",
                    )
                    with gr.Row():
                        gr.Button("🗑️ Clear Messages", variant="secondary").click(
                            lambda: "", outputs=[progress_messages]
                        )

                def _run_training(
                    file_paths,
                    label_column,
                    model_choice,
                    sequence_length,
                    stride,
                    validation_split,
                    batch_size,
                    epochs,
                    output_dir,
                    model_filename,
                    scaler_filename,
                    metadata_filename,
                    enable_tensorboard,
                ):
                    """Train on the cached database CSVs and refresh the loaded artifacts.

                    Failures are reported in the status box, not raised. The
                    returned tuple matches the ``outputs`` list of
                    ``train_button.click``.
                    """
                    base_dir = normalise_output_directory(output_dir)
                    try:
                        base_dir.mkdir(parents=True, exist_ok=True)
                        model_path = resolve_output_path(
                            base_dir,
                            model_filename,
                            Path(LOCAL_MODEL_FILE).name,
                        )
                        scaler_path = resolve_output_path(
                            base_dir,
                            scaler_filename,
                            Path(LOCAL_SCALER_FILE).name,
                        )
                        metadata_path = resolve_output_path(
                            base_dir,
                            metadata_filename,
                            Path(LOCAL_METADATA_FILE).name,
                        )
                        model_path.parent.mkdir(parents=True, exist_ok=True)
                        scaler_path.parent.mkdir(parents=True, exist_ok=True)
                        metadata_path.parent.mkdir(parents=True, exist_ok=True)

                        # Create status file path for progress tracking
                        status_file = model_path.parent / "training_status.txt"

                        # Initialize status
                        with open(status_file, "w", encoding="utf-8") as f:
                            f.write("Starting training setup...")

                        if not file_paths:
                            raise ValueError(
                                "No training CSVs were found in the database cache. "
                                "Use 'Reload dataset from database' and try again."
                            )

                        with open(status_file, "w", encoding="utf-8") as f:
                            f.write("Loading and validating CSV files...")

                        available_paths = [
                            path for path in file_paths if Path(path).exists()
                        ]
                        missing_paths = [
                            Path(path).name
                            for path in file_paths
                            if not Path(path).exists()
                        ]
                        if not available_paths:
                            raise ValueError(
                                "Database training dataset is unavailable. Reload the dataset and retry."
                            )
                        dfs = [load_measurement_csv(path) for path in available_paths]
                        combined = pd.concat(dfs, ignore_index=True)

                        label_column = (label_column or LABEL_COLUMN).strip()
                        if not label_column:
                            raise ValueError("Label column name cannot be empty.")

                        # Normalise the UI choice ("CNN-LSTM") to the internal
                        # key ("cnn_lstm") BEFORE the small-dataset check. That
                        # check compares against internal keys and would never
                        # match the raw radio value.
                        model_choice = (
                            (model_choice or "CNN-LSTM").lower().replace("-", "_")
                        )
                        if model_choice not in {"cnn_lstm", "tcn", "svm"}:
                            raise ValueError(
                                "Select CNN-LSTM, TCN, or SVM for the model architecture."
                            )

                        # Validate data size and provide recommendations
                        total_samples = len(combined)
                        if total_samples < 100:
                            print(
                                f"Warning: Only {total_samples} samples. Recommend at least 1000 for good results."
                            )
                            if model_choice in ("cnn_lstm", "tcn"):
                                print(
                                    "Automatically switching to SVM for small dataset compatibility."
                                )
                                model_choice = "svm"
                                print(
                                    "Model type changed to SVM for better small dataset performance."
                                )
                        if total_samples < 10:
                            raise ValueError(
                                f"Insufficient data: {total_samples} samples. Need at least 10 samples for training."
                            )

                        with open(status_file, "w", encoding="utf-8") as f:
                            f.write(
                                f"Starting {model_choice.upper()} training with {len(combined)} samples..."
                            )

                        # Start training
                        result = train_from_dataframe(
                            combined,
                            label_column=label_column,
                            feature_columns=None,
                            sequence_length=int(sequence_length),
                            stride=int(stride),
                            validation_split=float(validation_split),
                            batch_size=int(batch_size),
                            epochs=int(epochs),
                            model_type=model_choice,
                            model_path=model_path,
                            scaler_path=scaler_path,
                            metadata_path=metadata_path,
                            enable_tensorboard=bool(enable_tensorboard),
                        )

                        refresh_artifacts(
                            Path(result["model_path"]),
                            Path(result["scaler_path"]),
                            Path(result["metadata_path"]),
                        )

                        report_df = classification_report_to_dataframe(
                            result["classification_report"]
                        )
                        confusion_df = confusion_matrix_to_dataframe(
                            result["confusion_matrix"], result["class_names"]
                        )
                        tensorboard_dir = result.get("tensorboard_log_dir")
                        tensorboard_zip = result.get("tensorboard_zip_path")

                        architecture = result["model_type"].replace("_", "-").upper()
                        status = (
                            f"Training complete using a {architecture} architecture. "
                            f"{result['num_sequences']} windows derived from "
                            f"{result['num_samples']} rows across {len(available_paths)} file(s)."
                            f" Artifacts saved to:"
                            f"\n• Model: {result['model_path']}\n"
                            f"• Scaler: {result['scaler_path']}\n"
                            f"• Metadata: {result['metadata_path']}"
                        )
                        status += f"\nLabel column used: {result.get('label_column', label_column)}"
                        if tensorboard_dir:
                            status += (
                                f"\nTensorBoard logs directory: {tensorboard_dir}"
                                f'\nRun `tensorboard --logdir "{tensorboard_dir}"` to inspect the training curves.'
                                "\nDownload the archive below to explore the run offline."
                            )
                        if missing_paths:
                            skipped = ", ".join(missing_paths)
                            status = f"⚠️ Skipped missing files: {skipped}\n" + status

                        artifact_choices, selected_artifact = gather_artifact_choices(
                            str(base_dir), result["model_path"]
                        )

                        return (
                            status,
                            report_df,
                            result["history"],
                            confusion_df,
                            download_button_state(result["model_path"]),
                            download_button_state(result["scaler_path"]),
                            download_button_state(result["metadata_path"]),
                            download_button_state(tensorboard_zip),
                            gr.update(value=result.get("label_column", label_column)),
                            gr.update(
                                choices=artifact_choices, value=selected_artifact
                            ),
                            download_button_state(selected_artifact),
                        )
                    except Exception as exc:
                        artifact_choices, selected_artifact = gather_artifact_choices(
                            str(base_dir)
                        )
                        return (
                            f"Training failed: {exc}",
                            pd.DataFrame(),
                            {},
                            pd.DataFrame(),
                            download_button_state(None),
                            download_button_state(None),
                            download_button_state(None),
                            download_button_state(None),
                            gr.update(),
                            gr.update(
                                choices=artifact_choices, value=selected_artifact
                            ),
                            download_button_state(selected_artifact),
                        )

                def _check_progress(output_dir, model_filename, current_messages):
                    """Check training progress by reading status file and accumulate messages."""
                    from datetime import datetime

                    # Resolve the directory exactly like _run_training does, so
                    # both handlers agree on where training_status.txt lives.
                    model_path = resolve_output_path(
                        normalise_output_directory(output_dir),
                        model_filename,
                        Path(LOCAL_MODEL_FILE).name,
                    )
                    status_file = model_path.parent / "training_status.txt"
                    status_message = read_training_status(str(status_file))

                    # Add timestamp to the message
                    timestamp = datetime.now().strftime("%H:%M:%S")
                    new_message = f"[{timestamp}] {status_message}"

                    # Accumulate messages, keeping last 50 lines to prevent overflow
                    if not current_messages:
                        return new_message
                    lines = current_messages.split("\n")
                    lines.append(new_message)
                    return "\n".join(lines[-50:])

                train_button.click(
                    _run_training,
                    inputs=[
                        training_files_state,
                        label_input,
                        model_selector,
                        sequence_length_train,
                        stride_train,
                        validation_train,
                        batch_train,
                        epochs_train,
                        output_directory,
                        model_name,
                        scaler_name,
                        metadata_name,
                        tensorboard_toggle,
                    ],
                    outputs=[
                        training_status,
                        report_output,
                        history_output,
                        confusion_output,
                        model_download_button,
                        scaler_download_button,
                        metadata_download_button,
                        tensorboard_download_button,
                        label_input,
                        artifact_browser,
                        artifact_download_button,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                progress_button.click(
                    _check_progress,
                    inputs=[output_directory, model_name, progress_messages],
                    outputs=[progress_messages],
                )

                year_selector.change(
                    on_year_change,
                    inputs=[year_selector],
                    outputs=[
                        month_selector,
                        day_selector,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                month_selector.change(
                    on_month_change,
                    inputs=[year_selector, month_selector],
                    outputs=[day_selector, available_files, repo_status],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                day_selector.change(
                    on_day_change,
                    inputs=[year_selector, month_selector, day_selector],
                    outputs=[available_files, repo_status],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                download_button.click(
                    download_selected_files,
                    inputs=[
                        year_selector,
                        month_selector,
                        day_selector,
                        available_files,
                        label_input,
                    ],
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                year_download_button.click(
                    download_year_bundle,
                    inputs=[year_selector, label_input],
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                month_download_button.click(
                    download_month_bundle,
                    inputs=[year_selector, month_selector, label_input],
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                day_download_button.click(
                    download_day_bundle,
                    inputs=[year_selector, month_selector, day_selector, label_input],
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                def _reload_dataset(current_label):
                    """Force-refresh both the local CSV cache and the remote browser."""
                    local = load_repository_training_files(
                        current_label, force_refresh=True
                    )
                    remote = refresh_remote_browser(force_refresh=True)
                    return (*local, *remote)

                dataset_refresh.click(
                    _reload_dataset,
                    inputs=[label_input],
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        year_selector,
                        month_selector,
                        day_selector,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                clear_cache_button.click(
                    clear_downloaded_cache,
                    inputs=[label_input],
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        year_selector,
                        month_selector,
                        day_selector,
                        available_files,
                        repo_status,
                    ],
                    concurrency_limit=EVENT_CONCURRENCY_LIMIT,
                )

                def _initialise_dataset():
                    """Populate the dataset widgets on page load, using cached data if present."""
                    local = load_repository_training_files(
                        LABEL_COLUMN, force_refresh=False
                    )
                    remote = refresh_remote_browser(force_refresh=False)
                    return (*local, *remote)

                demo.load(
                    _initialise_dataset,
                    inputs=None,
                    outputs=[
                        training_files_state,
                        training_files_summary,
                        label_input,
                        dataset_info,
                        year_selector,
                        month_selector,
                        day_selector,
                        available_files,
                        repo_status,
                    ],
                    queue=False,
                )

    return demo
# --------------------------------------------------------------------------------------
# Launch helpers
# --------------------------------------------------------------------------------------
def resolve_server_port() -> int:
    """Return the port to bind the server to.

    ``$PORT`` is consulted first, then ``$GRADIO_SERVER_PORT``. Unset or
    empty variables are skipped. Values that are not integers are reported
    and skipped. Falls back to 7860.
    """
    for variable_name in ("PORT", "GRADIO_SERVER_PORT"):
        raw_port = os.environ.get(variable_name)
        if not raw_port:
            continue
        try:
            return int(raw_port)
        except ValueError:
            print(f"Ignoring invalid port value from {variable_name}: {raw_port}")
    return 7860
def main():
    """Build the UI, configure its queue and launch the Gradio server.

    Every stage reports failures on stdout instead of raising:

    * A build failure prints a traceback and aborts.
    * A queue failure is logged and startup continues.
    * An ``OSError`` while binding the requested port triggers one retry
      without an explicit port.
    """
    import traceback

    print("Building Gradio interface...")
    try:
        demo = build_interface()
    except Exception as e:
        print(f"Failed to build interface: {e}")
        traceback.print_exc()
        return
    print("Interface built successfully")

    print("Setting up queue...")
    try:
        demo.queue(max_size=QUEUE_MAX_SIZE)
    except Exception as e:
        print(f"Failed to configure queue: {e}")
    else:
        print("Queue configured")

    try:
        port = resolve_server_port()
        print(f"Launching Gradio app on port {port}")
        demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)
    except OSError as exc:
        print("Failed to launch on requested port:", exc)
        try:
            demo.launch(server_name="0.0.0.0", show_error=True)
        except Exception as e:
            print(f"Failed to launch completely: {e}")
    except Exception as e:
        print(f"Unexpected launch error: {e}")
        traceback.print_exc()
if __name__ == "__main__":
    # Import sys explicitly. ``os.sys`` only works because the os module
    # happens to import sys internally; it is not part of os's public API.
    import sys

    print("=" * 50)
    print("PMU Fault Classification App Starting")
    print(f"Python version: {sys.version}")
    print(f"Working directory: {os.getcwd()}")
    print(f"HUB_REPO: {HUB_REPO}")
    print(f"Model available: {MODEL is not None}")
    print(f"Scaler available: {SCALER is not None}")
    print("=" * 50)
    main()