VincentCroft committed on
Commit e9cfe70 · Parent: 61d758d

Fix Gradio download button initialization

Files changed (3):
  1. app.py +965 -84
  2. model/.gitkeep +0 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -18,12 +18,13 @@ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
 
 import re
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import gradio as gr
 import joblib
 import numpy as np
 import pandas as pd
+import requests
 from huggingface_hub import hf_hub_download
 from tensorflow.keras.models import load_model
 
@@ -44,6 +45,9 @@ LOCAL_MODEL_FILE = os.environ.get("PMU_MODEL_FILE", "pmu_cnn_lstm_model.keras")
 LOCAL_SCALER_FILE = os.environ.get("PMU_SCALER_FILE", "pmu_feature_scaler.pkl")
 LOCAL_METADATA_FILE = os.environ.get("PMU_METADATA_FILE", "pmu_metadata.json")
 
+MODEL_OUTPUT_DIR = Path(os.environ.get("PMU_MODEL_DIR", "model")).resolve()
+MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
 HUB_REPO = os.environ.get("PMU_HUB_REPO", "")
 HUB_MODEL_FILENAME = os.environ.get("PMU_HUB_MODEL_FILENAME", LOCAL_MODEL_FILE)
 HUB_SCALER_FILENAME = os.environ.get("PMU_HUB_SCALER_FILENAME", LOCAL_SCALER_FILE)
@@ -75,6 +79,8 @@ def download_from_hub(filename: str) -> Optional[Path]:
 def resolve_artifact(local_name: str, env_var: str, hub_filename: str) -> Optional[Path]:
     print(f"Resolving artifact: {local_name}, env: {env_var}, hub: {hub_filename}")
     candidates = [Path(local_name)] if local_name else []
+    if local_name:
+        candidates.append(MODEL_OUTPUT_DIR / Path(local_name).name)
     env_value = os.environ.get(env_var)
     if env_value:
         candidates.append(Path(env_value))
@@ -173,8 +179,12 @@ DEFAULT_WINDOW_STRIDE: int = DEFAULT_STRIDE
 MODEL_TYPE: str = "cnn_lstm"
 MODEL_FORMAT: str = "keras"
 
+def _model_output_path(filename: str) -> str:
+    return str(MODEL_OUTPUT_DIR / Path(filename).name)
+
+
 MODEL_FILENAME_BY_TYPE: Dict[str, str] = {
-    "cnn_lstm": LOCAL_MODEL_FILE,
+    "cnn_lstm": Path(LOCAL_MODEL_FILE).name,
    "tcn": "pmu_tcn_model.keras",
    "svm": "pmu_svm_model.joblib",
 }
@@ -183,6 +193,213 @@ REQUIRED_PMU_COLUMNS: Tuple[str, ...] = tuple(DEFAULT_FEATURE_COLUMNS)
 TRAINING_UPLOAD_DIR = Path(os.environ.get("PMU_TRAINING_UPLOAD_DIR", "training_uploads"))
 TRAINING_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
 
+TRAINING_DATA_REPO = os.environ.get("PMU_TRAINING_DATA_REPO", "VincentCroft/ThesisModelData")
+TRAINING_DATA_BRANCH = os.environ.get("PMU_TRAINING_DATA_BRANCH", "main")
+TRAINING_DATA_DIR = Path(os.environ.get("PMU_TRAINING_DATA_DIR", "training_dataset"))
+TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+GITHUB_CONTENT_CACHE: Dict[str, List[Dict[str, Any]]] = {}
+
+
+APP_CSS = """
+#available-files-grid .wrap {
+    display: grid;
+    grid-template-columns: repeat(4, minmax(0, 1fr));
+    gap: 0.5rem;
+    max-height: 24rem;
+    overflow-y: auto;
+    padding-right: 0.25rem;
+}
+
+#available-files-grid {
+    position: relative;
+}
+
+#available-files-grid .wrap > div {
+    min-width: 0;
+}
+
+#available-files-grid .wrap label {
+    margin: 0;
+    display: flex;
+    align-items: center;
+    padding: 0.45rem 0.65rem;
+    border-radius: 0.65rem;
+    background-color: rgba(255, 255, 255, 0.05);
+    border: 1px solid rgba(255, 255, 255, 0.08);
+    transition: background-color 0.2s ease, border-color 0.2s ease;
+    min-height: 2.5rem;
+}
+
+#available-files-grid .wrap label:hover {
+    background-color: rgba(90, 200, 250, 0.16);
+    border-color: rgba(90, 200, 250, 0.4);
+}
+
+#available-files-grid .wrap label span {
+    overflow: hidden;
+    text-overflow: ellipsis;
+    white-space: nowrap;
+}
+
+#available-files-grid .gradio-loading {
+    position: absolute;
+    inset: 0;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    background: rgba(10, 14, 23, 0.72);
+    border-radius: 0.75rem;
+    z-index: 10;
+}
+
+#date-browser-row {
+    gap: 0.75rem;
+}
+
+#date-browser-row .date-browser-column {
+    flex: 1 1 0%;
+    min-width: 0;
+}
+
+#date-browser-row .date-browser-column > .gradio-dropdown,
+#date-browser-row .date-browser-column > .gradio-button {
+    width: 100%;
+}
+
+#date-browser-row .date-browser-column > .gradio-dropdown > div {
+    width: 100%;
+}
+
+#date-browser-row .date-browser-column .gradio-button {
+    justify-content: center;
+}
+
+#training-files-summary textarea {
+    max-height: 12rem;
+    overflow-y: auto;
+}
+
+#download-selected-button {
+    width: 100%;
+}
+
+#download-selected-button .gradio-button {
+    width: 100%;
+    justify-content: center;
+}
+
+#artifact-download-row {
+    gap: 0.75rem;
+}
+
+#artifact-download-row .artifact-download-button {
+    flex: 1 1 0%;
+    min-width: 0;
+}
+
+#artifact-download-row .artifact-download-button .gradio-button {
+    width: 100%;
+    justify-content: center;
+}
+"""
+
+
+def _github_cache_key(path: str) -> str:
+    return path or "__root__"
+
+
+def _github_api_url(path: str) -> str:
+    clean_path = path.strip("/")
+    base = f"https://api.github.com/repos/{TRAINING_DATA_REPO}/contents"
+    if clean_path:
+        return f"{base}/{clean_path}?ref={TRAINING_DATA_BRANCH}"
+    return f"{base}?ref={TRAINING_DATA_BRANCH}"
+
+
+def list_remote_directory(path: str = "", *, force_refresh: bool = False) -> List[Dict[str, Any]]:
+    key = _github_cache_key(path)
+    if not force_refresh and key in GITHUB_CONTENT_CACHE:
+        return GITHUB_CONTENT_CACHE[key]
+
+    url = _github_api_url(path)
+    response = requests.get(url, timeout=30)
+    if response.status_code != 200:
+        raise RuntimeError(
+            f"GitHub API request failed for `{path or '.'}` (status {response.status_code})."
+        )
+
+    payload = response.json()
+    if not isinstance(payload, list):
+        raise RuntimeError("Unexpected GitHub API payload. Expected a directory listing.")
+
+    GITHUB_CONTENT_CACHE[key] = payload
+    return payload
+
+
+def list_remote_years(force_refresh: bool = False) -> List[str]:
+    entries = list_remote_directory("", force_refresh=force_refresh)
+    years = [item["name"] for item in entries if item.get("type") == "dir"]
+    return sorted(years)
+
+
+def list_remote_months(year: str, *, force_refresh: bool = False) -> List[str]:
+    if not year:
+        return []
+    entries = list_remote_directory(year, force_refresh=force_refresh)
+    months = [item["name"] for item in entries if item.get("type") == "dir"]
+    return sorted(months)
+
+
+def list_remote_days(year: str, month: str, *, force_refresh: bool = False) -> List[str]:
+    if not year or not month:
+        return []
+    entries = list_remote_directory(f"{year}/{month}", force_refresh=force_refresh)
+    days = [item["name"] for item in entries if item.get("type") == "dir"]
+    return sorted(days)
+
+
+def list_remote_files(year: str, month: str, day: str, *, force_refresh: bool = False) -> List[str]:
+    if not year or not month or not day:
+        return []
+    entries = list_remote_directory(
+        f"{year}/{month}/{day}", force_refresh=force_refresh
+    )
+    files = [item["name"] for item in entries if item.get("type") == "file"]
+    return sorted(files)
+
+
+def download_repository_file(year: str, month: str, day: str, filename: str) -> Path:
+    if not filename:
+        raise ValueError("Filename cannot be empty when downloading repository data.")
+
+    relative_parts = [part for part in (year, month, day, filename) if part]
+    if len(relative_parts) < 4:
+        raise ValueError("Provide year, month, day, and filename to download a CSV.")
+
+    relative_path = "/".join(relative_parts)
+    raw_url = (
+        f"https://raw.githubusercontent.com/{TRAINING_DATA_REPO}/"
+        f"{TRAINING_DATA_BRANCH}/{relative_path}"
+    )
+
+    response = requests.get(raw_url, stream=True, timeout=120)
+    if response.status_code != 200:
+        raise RuntimeError(
+            f"Failed to download `{relative_path}` (status {response.status_code})."
+        )
+
+    target_dir = TRAINING_DATA_DIR.joinpath(year, month, day)
+    target_dir.mkdir(parents=True, exist_ok=True)
+    target_path = target_dir / filename
+
+    with open(target_path, "wb") as handle:
+        for chunk in response.iter_content(chunk_size=1 << 20):
+            if chunk:
+                handle.write(chunk)
+
+    return target_path
+
 
 def _normalise_header(name: str) -> str:
     return str(name).strip().lower()
@@ -218,7 +435,7 @@ def guess_label_from_columns(columns: Sequence[str], preferred: Optional[str] =
 def summarise_training_files(paths: Sequence[str], notes: Sequence[str]) -> str:
     lines = [Path(path).name for path in paths]
     lines.extend(notes)
-    return "\n".join(lines) if lines else "No training files selected."
+    return "\n".join(lines) if lines else "No training files available."
 
 
 def read_training_status(status_file_path: str) -> str:
@@ -259,6 +476,36 @@ def _persist_uploaded_file(file_obj) -> Optional[Path]:
     return destination
 
 
+def prepare_training_paths(
+    paths: Sequence[str], current_label: str, cleanup_missing: bool = False
+):
+    valid_paths: List[str] = []
+    notes: List[str] = []
+    columns_map: Dict[str, str] = {}
+    for path in paths:
+        try:
+            df = load_measurement_csv(path)
+        except Exception as exc:  # pragma: no cover - user file diagnostics
+            notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
+            if cleanup_missing:
+                try:
+                    Path(path).unlink(missing_ok=True)
+                except Exception:
+                    pass
+            continue
+        valid_paths.append(str(path))
+        for col in df.columns:
+            columns_map[_normalise_header(col)] = str(col)
+
+    summary = summarise_training_files(valid_paths, notes)
+    preferred = current_label or LABEL_COLUMN
+    dropdown_choices = sorted(columns_map.values()) if columns_map else [preferred or LABEL_COLUMN]
+    guessed = guess_label_from_columns(dropdown_choices, preferred)
+    dropdown_value = guessed or preferred or LABEL_COLUMN
+
+    return valid_paths, summary, gr.update(choices=dropdown_choices, value=dropdown_value)
+
+
 def append_training_files(new_files, existing_paths: Sequence[str], current_label: str):
     if isinstance(existing_paths, (str, Path)):
         paths: List[str] = [str(existing_paths)]
@@ -275,32 +522,411 @@ def append_training_files(new_files, existing_paths: Sequence[str], current_labe
         if path_str not in paths:
             paths.append(path_str)
 
-    valid_paths: List[str] = []
+    return prepare_training_paths(paths, current_label, cleanup_missing=True)
+
+
+def load_repository_training_files(current_label: str, force_refresh: bool = False):
+    if force_refresh:
+        # Clearing the cache is enough because downloads are now on-demand.
+        for cached in list(TRAINING_DATA_DIR.glob("*")):
+            # On refresh we keep previously downloaded files; no deletion required.
+            # The flag triggers downstream UI updates only.
+            break
+
+    csv_paths = sorted(
+        str(path)
+        for path in TRAINING_DATA_DIR.rglob("*.csv")
+        if path.is_file()
+    )
+    if not csv_paths:
+        message = (
+            "No local database CSVs are available yet. Use the database browser "
+            "below to download specific days before training."
+        )
+        default_label = current_label or LABEL_COLUMN or "Fault"
+        return (
+            [],
+            message,
+            gr.update(choices=[default_label], value=default_label),
+            message,
+        )
+
+    valid_paths, summary, label_update = prepare_training_paths(
+        csv_paths, current_label, cleanup_missing=False
+    )
+
+    info = (
+        f"Ready with {len(valid_paths)} CSV file(s) cached locally under "
+        f"the database cache `{TRAINING_DATA_DIR}`."
+    )
+
+    return valid_paths, summary, label_update, info
+
+
+def refresh_remote_browser(force_refresh: bool = False):
+    if force_refresh:
+        GITHUB_CONTENT_CACHE.clear()
+    try:
+        years = list_remote_years(force_refresh=force_refresh)
+        if years:
+            message = "Select a year, month, and day to list available CSV files."
+        else:
+            message = (
+                "⚠️ No directories were found in the database root. Verify the upstream "
+                "structure."
+            )
+        return (
+            gr.update(choices=years, value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            message,
+        )
+    except Exception as exc:
+        return (
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            f"⚠️ Failed to query database: {exc}",
+        )
+
+
+def on_year_change(year: Optional[str]):
+    if not year:
+        return (
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            "Select a year to continue.",
+        )
+    try:
+        months = list_remote_months(year)
+        message = (
+            f"Year `{year}` selected. Choose a month to drill down."
+            if months
+            else f"⚠️ No months available under `{year}`."
+        )
+        return (
+            gr.update(choices=months, value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            message,
+        )
+    except Exception as exc:
+        return (
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            f"⚠️ Failed to list months: {exc}",
+        )
+
+
+def on_month_change(year: Optional[str], month: Optional[str]):
+    if not year or not month:
+        return (
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            "Select a month to continue.",
+        )
+    try:
+        days = list_remote_days(year, month)
+        message = (
+            f"Month `{year}/{month}` ready. Pick a day to view files."
+            if days
+            else f"⚠️ No day folders found under `{year}/{month}`."
+        )
+        return (
+            gr.update(choices=days, value=None),
+            gr.update(choices=[], value=[]),
+            message,
+        )
+    except Exception as exc:
+        return (
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            f"⚠️ Failed to list days: {exc}",
+        )
+
+
+def on_day_change(year: Optional[str], month: Optional[str], day: Optional[str]):
+    if not year or not month or not day:
+        return (
+            gr.update(choices=[], value=[]),
+            "Select a day to load file names.",
+        )
+    try:
+        files = list_remote_files(year, month, day)
+        message = (
+            f"{len(files)} file(s) available for `{year}/{month}/{day}`."
+            if files
+            else f"⚠️ No CSV files found under `{year}/{month}/{day}`."
+        )
+        return (
+            gr.update(choices=files, value=[]),
+            message,
+        )
+    except Exception as exc:
+        return (
+            gr.update(choices=[], value=[]),
+            f"⚠️ Failed to list files: {exc}",
+        )
+
+
+def download_selected_files(
+    year: Optional[str],
+    month: Optional[str],
+    day: Optional[str],
+    filenames: Sequence[str],
+    current_label: str,
+):
+    if not filenames:
+        message = "Select at least one CSV before downloading."
+        local = load_repository_training_files(current_label)
+        return (*local, gr.update(), message)
+
+    success: List[str] = []
     notes: List[str] = []
-    columns_map: Dict[str, str] = {}
-    for path in paths:
+    for filename in filenames:
         try:
-            df = load_measurement_csv(path)
-        except Exception as exc:  # pragma: no cover - user file diagnostics
-            notes.append(f"⚠️ Skipped {Path(path).name}: {exc}")
+            path = download_repository_file(year or "", month or "", day or "", filename)
+            success.append(str(path))
+        except Exception as exc:
+            notes.append(f"⚠️ {filename}: {exc}")
+
+    local = load_repository_training_files(current_label)
+
+    message_lines = []
+    if success:
+        message_lines.append(
+            f"Downloaded {len(success)} file(s) to the database cache `{TRAINING_DATA_DIR}`."
+        )
+    if notes:
+        message_lines.extend(notes)
+    if not message_lines:
+        message_lines.append("No files were downloaded.")
+
+    return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+def download_day_bundle(
+    year: Optional[str],
+    month: Optional[str],
+    day: Optional[str],
+    current_label: str,
+):
+    if not (year and month and day):
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            "Select a year, month, and day before downloading an entire day.",
+        )
+
+    try:
+        files = list_remote_files(year, month, day)
+    except Exception as exc:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            f"⚠️ Failed to list CSVs for `{year}/{month}/{day}`: {exc}",
+        )
+
+    if not files:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            f"No CSV files were found for `{year}/{month}/{day}`.",
+        )
+
+    result = list(download_selected_files(year, month, day, files, current_label))
+    result[-1] = (
+        f"Downloaded all {len(files)} CSV file(s) for `{year}/{month}/{day}`.\n"
+        f"{result[-1]}"
+    )
+    return tuple(result)
+
+
+def download_month_bundle(
+    year: Optional[str], month: Optional[str], current_label: str
+):
+    if not (year and month):
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            "Select a year and month before downloading an entire month.",
+        )
+
+    try:
+        days = list_remote_days(year, month)
+    except Exception as exc:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            f"⚠️ Failed to enumerate days for `{year}/{month}`: {exc}",
+        )
+
+    if not days:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            f"No day folders were found for `{year}/{month}`.",
+        )
+
+    downloaded = 0
+    notes: List[str] = []
+    for day in days:
+        try:
+            files = list_remote_files(year, month, day)
+        except Exception as exc:
+            notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
+            continue
+        if not files:
+            notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
+            continue
+        for filename in files:
             try:
-                Path(path).unlink(missing_ok=True)
-            except Exception:
-                pass
+                download_repository_file(year, month, day, filename)
+                downloaded += 1
+            except Exception as exc:
+                notes.append(
+                    f"⚠️ {year}/{month}/{day}/{filename}: {exc}"
+                )
+
+    local = load_repository_training_files(current_label)
+    message_lines = []
+    if downloaded:
+        message_lines.append(
+            f"Downloaded {downloaded} CSV file(s) for `{year}/{month}` into the "
+            f"database cache `{TRAINING_DATA_DIR}`."
+        )
+    message_lines.extend(notes)
+    if not message_lines:
+        message_lines.append("No files were downloaded.")
+
+    return (*local, gr.update(value=[]), "\n".join(message_lines))
+
+
+def download_year_bundle(year: Optional[str], current_label: str):
+    if not year:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            "Select a year before downloading an entire year of CSVs.",
+        )
+
+    try:
+        months = list_remote_months(year)
+    except Exception as exc:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            f"⚠️ Failed to enumerate months for `{year}`: {exc}",
+        )
+
+    if not months:
+        local = load_repository_training_files(current_label)
+        return (
+            *local,
+            gr.update(),
+            f"No month folders were found for `{year}`.",
+        )
+
+    downloaded = 0
+    notes: List[str] = []
+    for month in months:
+        try:
+            days = list_remote_days(year, month)
+        except Exception as exc:
+            notes.append(f"⚠️ Failed to list `{year}/{month}`: {exc}")
             continue
-        valid_paths.append(path)
-        for col in df.columns:
-            columns_map[_normalise_header(col)] = str(col)
+        if not days:
+            notes.append(f"⚠️ No day folders in `{year}/{month}`.")
+            continue
+        for day in days:
+            try:
+                files = list_remote_files(year, month, day)
+            except Exception as exc:
+                notes.append(f"⚠️ Failed to list `{year}/{month}/{day}`: {exc}")
+                continue
+            if not files:
+                notes.append(f"⚠️ No CSV files in `{year}/{month}/{day}`.")
+                continue
+            for filename in files:
+                try:
+                    download_repository_file(year, month, day, filename)
+                    downloaded += 1
+                except Exception as exc:
+                    notes.append(
+                        f"⚠️ {year}/{month}/{day}/{filename}: {exc}"
+                    )
+
+    local = load_repository_training_files(current_label)
+    message_lines = []
+    if downloaded:
+        message_lines.append(
+            f"Downloaded {downloaded} CSV file(s) for `{year}` into the "
+            f"database cache `{TRAINING_DATA_DIR}`."
+        )
+    message_lines.extend(notes)
+    if not message_lines:
+        message_lines.append("No files were downloaded.")
+
+    return (*local, gr.update(value=[]), "\n".join(message_lines))
 
-    paths = valid_paths
-    summary = summarise_training_files(paths, notes)
-    column_choices = sorted(columns_map.values())
-    preferred = current_label or LABEL_COLUMN
-    guessed = guess_label_from_columns(column_choices, preferred)
-    dropdown_choices = column_choices if column_choices else [preferred or LABEL_COLUMN]
-    dropdown_value = guessed or preferred or LABEL_COLUMN
 
-    return paths, summary, gr.update(choices=dropdown_choices, value=dropdown_value)
+def clear_downloaded_cache(current_label: str):
+    status_message = ""
+    try:
+        if TRAINING_DATA_DIR.exists():
+            shutil.rmtree(TRAINING_DATA_DIR)
+        TRAINING_DATA_DIR.mkdir(parents=True, exist_ok=True)
+        status_message = (
+            f"Cleared all downloaded CSVs from database cache `{TRAINING_DATA_DIR}`."
+        )
+    except Exception as exc:
+        status_message = f"⚠️ Failed to clear database cache: {exc}"
+
+    local = load_repository_training_files(current_label, force_refresh=True)
+    remote = list(refresh_remote_browser(force_refresh=False))
+    if status_message:
+        previous = remote[-1]
+        if isinstance(previous, str) and previous:
+            remote[-1] = f"{status_message}\n{previous}"
+        else:
+            remote[-1] = status_message
+
+    return (*local, *remote)
+
+
+def normalise_output_directory(directory: Optional[str]) -> Path:
+    base = Path(directory or MODEL_OUTPUT_DIR)
+    base = base.expanduser()
+    if not base.is_absolute():
+        base = (Path.cwd() / base).resolve()
+    return base
+
+
+def resolve_output_path(
+    directory: Optional[Union[Path, str]], filename: Optional[str], fallback: str
+) -> Path:
+    if isinstance(directory, Path):
+        base = directory
+    else:
+        base = normalise_output_directory(directory)
+    candidate = Path(filename or "").expanduser()
+    if str(candidate):
+        if candidate.is_absolute():
+            return candidate
+        return (base / candidate).resolve()
+    return (base / fallback).resolve()
 
 
 def clear_training_files():
@@ -360,8 +986,9 @@ following ordered columns:
 15. `[338] UPMU_SUB22-C3:MAG` – phase C current magnitude
 16. `[339] UPMU_SUB22-C3:ANG` – phase C current angle
 
-Upload as many hourly CSV exports as needed—the training tab concatenates them
-before building sliding windows.
+The training tab automatically downloads the latest CSV exports from the
+`VincentCroft/ThesisModelData` repository and concatenates them before building
+sliding windows.
 
 ## Models Developed
 
@@ -729,7 +1356,17 @@ def build_interface() -> gr.Blocks:
         button_secondary_background_fill="#3f3f46",
         button_secondary_text_color="#f5f5f5",
     )
-    with gr.Blocks(title="Fault Classification - PMU Data", theme=theme) as demo:
+
+    def _normalise_directory_string(value: Optional[Union[str, Path]]) -> str:
+        if value is None:
+            return ""
+        path = Path(value).expanduser()
+        try:
+            return str(path.resolve())
+        except Exception:
+            return str(path)
+
+    with gr.Blocks(title="Fault Classification - PMU Data", theme=theme, css=APP_CSS) as demo:
         gr.Markdown("# Fault Classification for PMU & PV Data")
         gr.Markdown(
             "🖥️ TensorFlow is locked to CPU execution so the Space can run without CUDA drivers."
@@ -819,34 +1456,67 @@
         with gr.Tab("Training"):
            gr.Markdown("## Train or Fine-tune the Model")
            gr.Markdown(
-                "Upload one or more PMU CSV files to create a combined training dataset. "
-                "The files will be concatenated in upload order before generating sliding windows."
+                "Training data is automatically downloaded from the database. "
+                "Refresh the cache if new files are added upstream."
            )
 
            training_files_state = gr.State([])
            with gr.Row():
-                training_file_drop = gr.Files(
-                    label="Drag and drop PMU training CSVs",
-                    file_types=[".csv"],
-                    file_count="multiple",
-                    type="filepath",
-                )
-                with gr.Column(scale=1, min_width=180):
-                    training_upload = gr.UploadButton(
-                        "📂 Add training CSVs",
-                        file_types=[".csv"],
-                        file_count="multiple",
-                        type="filepath",
-                        variant="primary",
+                with gr.Column(scale=3):
+                    training_files_summary = gr.Textbox(
+                        label="Database training CSVs",
+                        value="Training dataset not loaded yet.",
+                        lines=4,
+                        interactive=False,
+                        elem_id="training-files-summary",
+                    )
+                with gr.Column(scale=2, min_width=240):
+                    dataset_info = gr.Markdown(
+                        "No local database CSVs downloaded yet.",
+                    )
+                    dataset_refresh = gr.Button(
+                        "🔄 Reload dataset from database",
+                        variant="secondary",
+                    )
+                    clear_cache_button = gr.Button(
+                        "🧹 Clear downloaded cache",
+                        variant="secondary",
                    )
-                clear_training = gr.Button("Clear list", variant="secondary")
 
-            training_files_summary = gr.Textbox(
-                label="Selected training CSVs",
-                value="No training files selected.",
-                lines=4,
-                interactive=False,
-            )
+            with gr.Accordion("📂 DataBaseBrowser", open=False):
+                gr.Markdown(
+                    "Browse the upstream database by date and download only the CSVs you need."
+                )
+                with gr.Row(elem_id="date-browser-row"):
+                    with gr.Column(scale=1, elem_classes=["date-browser-column"]):
+                        year_selector = gr.Dropdown(label="Year", choices=[])
+                        year_download_button = gr.Button(
+                            "⬇️ Download year CSVs", variant="secondary"
+                        )
+                    with gr.Column(scale=1, elem_classes=["date-browser-column"]):
+                        month_selector = gr.Dropdown(label="Month", choices=[])
+                        month_download_button = gr.Button(
+                            "⬇️ Download month CSVs", variant="secondary"
+                        )
+                    with gr.Column(scale=1, elem_classes=["date-browser-column"]):
+                        day_selector = gr.Dropdown(label="Day", choices=[])
+                        day_download_button = gr.Button(
+                            "⬇️ Download day CSVs", variant="secondary"
+                        )
+                available_files = gr.CheckboxGroup(
+                    label="Available CSV files",
+                    choices=[],
+                    value=[],
+                    elem_id="available-files-grid",
+                )
+                download_button = gr.Button(
+                    "⬇️ Download selected CSVs",
+                    variant="secondary",
+                    elem_id="download-selected-button",
+                )
+                repo_status = gr.Markdown(
+                    "Click 'Reload dataset from database' to fetch the directory tree."
+                )
 
            with gr.Row():
                label_input = gr.Dropdown(
@@ -879,10 +1549,8 @@ def build_interface() -> gr.Blocks:
                    label="Stride",
                )
 
-            model_default = (
-                str(MODEL_PATH)
-                if MODEL_PATH
-                else MODEL_FILENAME_BY_TYPE.get(MODEL_TYPE, LOCAL_MODEL_FILE)
+            model_default = MODEL_FILENAME_BY_TYPE.get(
+                MODEL_TYPE, Path(LOCAL_MODEL_FILE).name
            )
 
            with gr.Row():
@@ -909,16 +1577,54 @@
                )
 
            with gr.Row():
-                model_name = gr.Textbox(value=model_default, label="Model output filename")
+                output_directory = gr.Textbox(
+                    value=str(MODEL_OUTPUT_DIR),
+                    label="Output directory",
+                )
+                model_name = gr.Textbox(
+                    value=model_default,
+                    label="Model output filename",
+                )
                scaler_name = gr.Textbox(
-                    value=str(SCALER_PATH or LOCAL_SCALER_FILE),
+                    value=Path(LOCAL_SCALER_FILE).name,
                    label="Scaler output filename",
                )
                metadata_name = gr.Textbox(
-                    value=str(METADATA_PATH or LOCAL_METADATA_FILE),
+                    value=Path(LOCAL_METADATA_FILE).name,
                    label="Metadata output filename",
                )
 
+            with gr.Row(elem_id="artifact-download-row"):
+                model_download_button = gr.DownloadButton(
+                    "⬇️ Download model file",
+                    value=None,
+                    visible=False,
+                    elem_classes=["artifact-download-button"],
+                )
+                scaler_download_button = gr.DownloadButton(
+                    "⬇️ Download scaler file",
+                    value=None,
+                    visible=False,
+                    elem_classes=["artifact-download-button"],
+                )
+                metadata_download_button = gr.DownloadButton(
+                    "⬇️ Download metadata file",
+                    value=None,
+                    visible=False,
+                    elem_classes=["artifact-download-button"],
+                )
+                tensorboard_download_button = gr.DownloadButton(
+                    "⬇️ Download TensorBoard logs",
+                    value=None,
+                    visible=False,
+                    elem_classes=["artifact-download-button"],
+                )
+
+            model_download_button.file_name = Path(LOCAL_MODEL_FILE).name
+            scaler_download_button.file_name = Path(LOCAL_SCALER_FILE).name
+            metadata_download_button.file_name = Path(LOCAL_METADATA_FILE).name
+            tensorboard_download_button.file_name = "tensorboard_logs.zip"
+
            tensorboard_toggle = gr.Checkbox(
                value=True,
                label="Enable TensorBoard logging (creates downloadable archive)",
@@ -926,8 +1632,10 @@
 
            def _suggest_model_filename(choice: str, current_value: str):
                choice_key = (choice or "cnn_lstm").lower().replace("-", "_")
-                suggested = MODEL_FILENAME_BY_TYPE.get(choice_key, LOCAL_MODEL_FILE)
-                known_defaults = {Path(name).name for name in MODEL_FILENAME_BY_TYPE.values()}
+                suggested = MODEL_FILENAME_BY_TYPE.get(
+                    choice_key, Path(LOCAL_MODEL_FILE).name
+                )
+                known_defaults = set(MODEL_FILENAME_BY_TYPE.values())
                current_name = Path(current_value).name if current_value else ""
                if current_name and current_name not in known_defaults:
                    return gr.update()
@@ -948,10 +1656,6 @@
            report_output = gr.Dataframe(label="Classification report", interactive=False)
            history_output = gr.JSON(label="Training history")
            confusion_output = gr.Dataframe(label="Confusion matrix", interactive=False)
-            tensorboard_file = gr.File(
-                label="TensorBoard logs (.zip)",
-                interactive=False,
-            )
 
            # Message area at the bottom for progress updates
            with gr.Accordion("📋 Progress Messages", open=True):
@@ -978,21 +1682,56 @@
                validation_split,
                batch_size,
                epochs,
+                output_dir,
                model_filename,
                scaler_filename,
                metadata_filename,
                enable_tensorboard,
            ):
+                def _download_state(path: Optional[Union[str, Path]]):
+                    if not path:
+                        return gr.update(value=None, visible=False)
+                    candidate = Path(path)
+                    if candidate.exists():
+                        return gr.update(value=str(candidate), visible=True)
+                    return gr.update(value=None, visible=False)
+
                try:
+                    base_dir = normalise_output_directory(output_dir)
+                    base_dir.mkdir(parents=True, exist_ok=True)
+
+                    model_path = resolve_output_path(
+                        base_dir,
+                        model_filename,
+                        Path(LOCAL_MODEL_FILE).name,
+                    )
+                    scaler_path = resolve_output_path(
+                        base_dir,
+                        scaler_filename,
+                        Path(LOCAL_SCALER_FILE).name,
+                    )
+                    metadata_path = resolve_output_path(
+                        base_dir,
+                        metadata_filename,
+                        Path(LOCAL_METADATA_FILE).name,
+                    )
+
+                    model_path.parent.mkdir(parents=True, exist_ok=True)
+                    scaler_path.parent.mkdir(parents=True, exist_ok=True)
+                    metadata_path.parent.mkdir(parents=True, exist_ok=True)
+
                    # Create status file path for progress tracking
-                    status_file = Path(model_filename).parent / "training_status.txt"
+                    status_file = model_path.parent / "training_status.txt"
 
                    # Initialize status
                    with open(status_file, 'w') as f:
                        f.write("Starting training setup...")
 
                    if not file_paths:
-                        raise ValueError("Add at least one training CSV via the uploader before starting.")
+                        raise ValueError(
+                            "No training CSVs were found in the database cache. "
+                            "Use 'Reload dataset from database' and try again."
+                        )
 
                    with open(status_file, 'w') as f:
                        f.write("Loading and validating CSV files...")
@@ -1000,7 +1739,9 @@
                    available_paths = [path for path in file_paths if Path(path).exists()]
                    missing_paths = [Path(path).name for path in file_paths if not Path(path).exists()]
                    if not available_paths:
-                        raise ValueError("None of the referenced CSV files are available. Please upload them again.")
+                        raise ValueError(
+                            "Database training dataset is unavailable. Reload the dataset and retry."
+                        )
 
                    dfs = [load_measurement_csv(path) for path in available_paths]
                    combined = pd.concat(dfs, ignore_index=True)
@@ -1038,9 +1779,9 @@
                        batch_size=int(batch_size),
                        epochs=int(epochs),
                        model_type=model_choice,
-                        model_path=Path(model_filename),
-                        scaler_path=Path(scaler_filename),
-                        metadata_path=Path(metadata_filename),
+                        model_path=model_path,
+                        scaler_path=scaler_path,
+                        metadata_path=metadata_path,
                        enable_tensorboard=bool(enable_tensorboard),
                    )
 
@@ -1084,7 +1825,10 @@
                        report_df,
                        result["history"],
                        confusion_df,
-                        tensorboard_zip,
+                        _download_state(result["model_path"]),
+                        _download_state(result["scaler_path"]),
+                        _download_state(result["metadata_path"]),
+                        _download_state(tensorboard_zip),
                        gr.update(value=result.get("label_column", label_column)),
                    )
                except Exception as exc:
@@ -1093,13 +1837,19 @@
                        pd.DataFrame(),
                        {},
                        pd.DataFrame(),
-                        None,
+                        _download_state(None),
+                        _download_state(None),
+                        _download_state(None),
+                        _download_state(None),
                        gr.update(),
                    )
 
-            def _check_progress(model_filename, current_messages):
+            def _check_progress(output_dir, model_filename, current_messages):
                """Check training progress by reading status file and accumulate messages."""
-                status_file = Path(model_filename).parent / "training_status.txt"
+                model_path = resolve_output_path(
+                    output_dir, model_filename, Path(LOCAL_MODEL_FILE).name
+                )
+                status_file = model_path.parent / "training_status.txt"
                status_message = read_training_status(str(status_file))
 
                # Add timestamp to the message
@@ -1131,6 +1881,7 @@
                    validation_train,
                    batch_train,
                    epochs_train,
+                    output_directory,
                    model_name,
                    scaler_name,
                    metadata_name,
@@ -1141,7 +1892,10 @@
                    report_output,
                    history_output,
                    confusion_output,
-                    tensorboard_file,
+                    model_download_button,
+                    scaler_download_button,
+                    metadata_download_button,
+                    tensorboard_download_button,
                    label_input,
                ],
                concurrency_limit=EVENT_CONCURRENCY_LIMIT,
@@ -1149,25 +1903,152 @@
 
        progress_button.click(
            _check_progress,
-            inputs=[model_name, progress_messages],
+            inputs=[output_directory, model_name, progress_messages],
            outputs=[progress_messages],
        )
 
-        training_upload.upload(
-            append_training_files,
-            inputs=[training_upload, training_files_state, label_input],
-            outputs=[training_files_state, training_files_summary, label_input],
+        year_selector.change(
+            on_year_change,
+            inputs=[year_selector],
+            outputs=[month_selector, day_selector, available_files, repo_status],
            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
        )
-        training_file_drop.upload(
-            append_training_files,
-            inputs=[training_file_drop, training_files_state, label_input],
-            outputs=[training_files_state, training_files_summary, label_input],
+
+        month_selector.change(
+            on_month_change,
+            inputs=[year_selector, month_selector],
+            outputs=[day_selector, available_files, repo_status],
            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
        )
-        clear_training.click(
-            clear_training_files,
-            outputs=[training_files_state, training_files_summary, label_input, training_file_drop],
+
+        day_selector.change(
+            on_day_change,
+            inputs=[year_selector, month_selector, day_selector],
+            outputs=[available_files, repo_status],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        download_button.click(
+            download_selected_files,
+            inputs=[
+                year_selector,
+                month_selector,
+                day_selector,
+                available_files,
+                label_input,
+            ],
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                available_files,
+                repo_status,
+            ],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        year_download_button.click(
+            download_year_bundle,
+            inputs=[year_selector, label_input],
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                available_files,
+                repo_status,
+            ],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        month_download_button.click(
+            download_month_bundle,
+            inputs=[year_selector, month_selector, label_input],
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                available_files,
+                repo_status,
+            ],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        day_download_button.click(
+            download_day_bundle,
+            inputs=[year_selector, month_selector, day_selector, label_input],
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                available_files,
+                repo_status,
+            ],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        def _reload_dataset(current_label):
+            local = load_repository_training_files(current_label, force_refresh=True)
+            remote = refresh_remote_browser(force_refresh=True)
+            return (*local, *remote)
+
+        dataset_refresh.click(
+            _reload_dataset,
+            inputs=[label_input],
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                year_selector,
+                month_selector,
+                day_selector,
+                available_files,
+                repo_status,
+            ],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        clear_cache_button.click(
+            clear_downloaded_cache,
+            inputs=[label_input],
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                year_selector,
+                month_selector,
+                day_selector,
+                available_files,
+                repo_status,
+            ],
+            concurrency_limit=EVENT_CONCURRENCY_LIMIT,
+        )
+
+        def _initialise_dataset():
+            local = load_repository_training_files(LABEL_COLUMN, force_refresh=False)
+            remote = refresh_remote_browser(force_refresh=False)
+            return (*local, *remote)
+
+        demo.load(
+            _initialise_dataset,
+            inputs=None,
+            outputs=[
+                training_files_state,
+                training_files_summary,
+                label_input,
+                dataset_info,
+                year_selector,
+                month_selector,
+                day_selector,
+                available_files,
+                repo_status,
+            ],
+            queue=False,
        )
 
    return demo
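Note on the fix itself: the `gr.DownloadButton` components above are created with `value=None` and `visible=False`, and the training handler reveals them via `gr.update(...)` only once the corresponding artifact actually exists on disk, instead of binding the buttons to file paths that may not exist at render time. A minimal, self-contained sketch of that pattern (the component and file names here are illustrative, not the app's exact wiring):

    import gradio as gr
    from pathlib import Path

    with gr.Blocks() as demo:
        train = gr.Button("Train")
        # Start hidden with no file bound; attach a real path only after training.
        download = gr.DownloadButton("⬇️ Download model file", value=None, visible=False)

        def _after_training():
            artifact = Path("model/pmu_cnn_lstm_model.keras")  # hypothetical output path
            if artifact.exists():
                return gr.update(value=str(artifact), visible=True)
            return gr.update(value=None, visible=False)

        train.click(_after_training, outputs=[download])

    demo.launch()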
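For context, the new database browser walks the `VincentCroft/ThesisModelData` repository as a year/month/day tree through the public GitHub contents API and caches each JSON listing in-process (unauthenticated API calls are rate-limited by GitHub, which is one reason the app caches listings). A condensed sketch of the same listing flow, with the error handling and caching trimmed:

    import requests

    REPO = "VincentCroft/ThesisModelData"
    BRANCH = "main"

    def list_dir(path: str = ""):
        # One directory level per call, pinned to a branch; equivalent to the
        # app's manually built "?ref=" query string.
        url = f"https://api.github.com/repos/{REPO}/contents/{path}".rstrip("/")
        response = requests.get(url, params={"ref": BRANCH}, timeout=30)
        response.raise_for_status()
        return response.json()  # a list of {"name": ..., "type": ...} entries

    years = sorted(e["name"] for e in list_dir() if e["type"] == "dir")
    print(years)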
model/.gitkeep ADDED
File without changes
requirements.txt CHANGED
@@ -6,3 +6,4 @@ scikit-learn
 huggingface_hub
 matplotlib
 joblib
+requests
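
`requests` is the only new dependency; it backs both the contents-API listings and the raw-file downloads. The download path streams each CSV to the local cache in 1 MiB chunks rather than buffering whole files in memory. A trimmed sketch mirroring `download_repository_file`:

    import requests
    from pathlib import Path

    def fetch_csv(relative_path: str, dest: Path) -> Path:
        # Raw-file URL; streaming keeps memory flat for large hourly exports.
        url = (
            "https://raw.githubusercontent.com/"
            f"VincentCroft/ThesisModelData/main/{relative_path}"
        )
        response = requests.get(url, stream=True, timeout=120)
        response.raise_for_status()
        dest.parent.mkdir(parents=True, exist_ok=True)
        with open(dest, "wb") as handle:
            for chunk in response.iter_content(chunk_size=1 << 20):  # 1 MiB
                if chunk:  # skip keep-alive chunks
                    handle.write(chunk)
        return dest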