"""Gradio demo application for Time-RCD zero-shot time-series anomaly detection."""

import io
import zipfile
from pathlib import Path
from typing import List, Tuple, Literal, Optional

import gradio as gr
import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError

from evaluation.metrics import get_metrics
from model_wrapper import run_Time_RCD

REPO_ID = "thu-sail-lab/Time-RCD"

# Checkpoint files (relative to the working directory) expected to exist
# before inference is attempted.
CHECKPOINT_FILES = [
    "checkpoints/full_mask_anomaly_head_pretrain_checkpoint_best.pth",
    "checkpoints/dataset_10_20.pth",
    "checkpoints/full_mask_10_20.pth",
    "checkpoints/dataset_15_56.pth",
    "checkpoints/full_mask_15_56.pth",
]


def ensure_checkpoints() -> None:
    """Ensure that the required checkpoint files are present locally.

    If any path in ``CHECKPOINT_FILES`` is missing, ``checkpoints.zip`` is
    downloaded from ``REPO_ID`` (first as a model repo, then, on an HTTP
    error, as a dataset repo) and extracted into the current working
    directory.
    """
    missing = [path for path in CHECKPOINT_FILES if not Path(path).exists()]
    if not missing:
        return

    try:
        zip_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="checkpoints.zip",
            repo_type="model",
            cache_dir=".cache/hf",
        )
    except HfHubHTTPError:
        # Fall back to looking the archive up as a dataset repository.
        zip_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="checkpoints.zip",
            repo_type="dataset",
            cache_dir=".cache/hf",
        )

    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(".")


BASE_DIR = Path(__file__).resolve().parent
SAMPLE_DATASET_DIR = BASE_DIR / "sample_datasets"

LabelSource = Literal["same_file", "separate_file", "none"]

# Column names (compared case-insensitively) recognised as label columns.
LABEL_COLUMN_CANDIDATES = ("label", "labels")

# UI caption -> internal LabelSource value.
LABEL_SOURCE_CHOICES = {
    "Value + label in same file": "same_file",
    "Labels in separate file": "separate_file",
    "No labels provided": "none",
}

# UI caption -> bundled sample description.
SAMPLE_FILES: dict[str, dict[str, object]] = {
    "Sample: Univariate SED Medical": {
        "path": SAMPLE_DATASET_DIR / "235_SED_id_2_Medical_tr_2499_1st_3840.csv",
        "is_multivariate": False,
    },
    "Sample: Univariate UCR Medical": {
        "path": SAMPLE_DATASET_DIR / "353_UCR_id_51_Medical_tr_1875_1st_3198.csv",
        "is_multivariate": False,
    },
    "Sample: Univariate Yahoo WebService": {
        "path": SAMPLE_DATASET_DIR / "686_YAHOO_id_136_WebService_tr_500_1st_755.csv",
        "is_multivariate": False,
    },
    # "Sample: Multivariate MSL Sensor": {
    #     "path": SAMPLE_DATASET_DIR / "003_MSL_id_2_Sensor_tr_883_1st_1238.csv",
    #     "is_multivariate": True,
    # },
}


def _resolve_path(file_obj) -> Path:
    """Extract a pathlib.Path from the gradio file object.

    Accepts a ``Path`` (returned unchanged), a string (relative strings are
    resolved against ``BASE_DIR``), a dict with a ``"name"`` key, or any
    object exposing a non-empty ``name`` attribute.

    Raises:
        ValueError: if ``file_obj`` is ``None`` or no path can be derived.
    """
    if file_obj is None:
        raise ValueError("File object is None.")
    if isinstance(file_obj, Path):
        return file_obj
    if isinstance(file_obj, str):
        path = Path(file_obj)
        if not path.is_absolute():
            path = (BASE_DIR / path).resolve()
        return path

    # Gradio may pass dictionaries or objects with a 'name' attribute.
    if isinstance(file_obj, dict) and "name" in file_obj:
        return _resolve_path(file_obj["name"])

    name = getattr(file_obj, "name", None)
    if not name:
        raise ValueError("Unable to resolve uploaded file path.")
    return _resolve_path(name)


def _load_dataframe(path: Path) -> pd.DataFrame:
    """Load a dataframe from supported file types.

    ``.npy`` files are loaded with ``allow_pickle=False`` (1-D arrays become a
    single column); ``.csv`` and ``.txt`` files are parsed by ``pd.read_csv``.

    Raises:
        ValueError: if the file does not exist, has an unsupported suffix, or
            the ``.npy`` payload is not a numpy array.
    """
    if not path.exists():
        raise ValueError(f"File not found: {path}. If this is a bundled sample, ensure it exists under {SAMPLE_DATASET_DIR}.")

    suffix = path.suffix.lower()
    if suffix == ".npy":
        data = np.load(path, allow_pickle=False)
        # Validate the type before touching array-only attributes such as ndim.
        if not isinstance(data, np.ndarray):
            raise ValueError("Loaded .npy data is not a numpy array.")
        if data.ndim == 1:
            data = data.reshape(-1, 1)
        return pd.DataFrame(data)

    if suffix not in {".csv", ".txt"}:
        raise ValueError("Unsupported file type. Please upload a .csv, .txt, or .npy file.")

    return pd.read_csv(path)


def _extract_label_column(df: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
    """Split a label column from dataframe if one of the candidate names exists.

    Returns:
        ``(features, labels)``; ``labels`` is ``None`` and ``features`` is the
        input frame when no candidate column is present.
    """
    # Column labels may be integers (e.g. frames built from .npy arrays), so
    # compare on their string form rather than calling .lower() directly.
    lower_to_original = {str(col).lower(): col for col in df.columns}

    label_col = None
    for candidate in LABEL_COLUMN_CANDIDATES:
        if candidate in lower_to_original:
            label_col = lower_to_original[candidate]
            break

    if label_col is None:
        return df, None

    label_series = pd.to_numeric(df[label_col], errors="raise")
    feature_df = df.drop(columns=[label_col])
    return feature_df, label_series


def _load_label_series(file_obj) -> pd.Series:
    """Load labels from a dedicated upload.

    A numeric column named like one of ``LABEL_COLUMN_CANDIDATES`` is
    preferred; otherwise the file must contain exactly one numeric column.
    The returned series is named ``"label"``.

    Raises:
        ValueError: if there are no numeric columns, or several numeric
            columns and none is named as a label column.
    """
    path = _resolve_path(file_obj)
    df = _load_dataframe(path)
    numeric_df = df.select_dtypes(include=np.number)
    if numeric_df.empty:
        raise ValueError("Uploaded label file does not contain numeric columns.")

    # str() guards against integer column labels coming from .npy files.
    lower_to_original = {str(col).lower(): col for col in numeric_df.columns}
    for candidate in LABEL_COLUMN_CANDIDATES:
        if candidate in lower_to_original:
            column = lower_to_original[candidate]
            return pd.to_numeric(numeric_df[column], errors="raise").rename("label")

    if numeric_df.shape[1] > 1:
        raise ValueError(
            "Label file must contain exactly one numeric column or include a column named 'label'."
        )

    series = pd.to_numeric(numeric_df.iloc[:, 0], errors="raise").rename("label")
    return series


def load_timeseries(
    value_file,
    feature_columns: List[str] | None,
    label_source: LabelSource,
    label_file=None,
) -> Tuple[pd.DataFrame, np.ndarray, Optional[pd.Series]]:
    """Load the uploaded value file, optional label file, and return features/labels.

    Only numeric columns are kept. A column named like one of
    ``LABEL_COLUMN_CANDIDATES`` is always removed from the features; it is
    used as the labels only when ``label_source == "same_file"``.

    Args:
        value_file: anything accepted by ``_resolve_path``.
        feature_columns: optional subset of columns to keep, matched against
            the string form of the column names.
        label_source: where the labels come from.
        label_file: label upload, required when ``label_source`` is
            ``"separate_file"``.

    Returns:
        ``(feature_df, array, labels)`` where ``array`` is the float32 numpy
        view of ``feature_df`` and ``labels`` is ``None`` when not provided.
        Features and labels are truncated to their common length.

    Raises:
        ValueError: on missing/invalid inputs (no numeric columns, no feature
            columns, missing label column/file, unknown selected columns,
            unsupported ``label_source``).
    """
    value_path = _resolve_path(value_file)
    raw_df = _load_dataframe(value_path)
    feature_df = raw_df.select_dtypes(include=np.number)
    if feature_df.empty:
        raise ValueError("No numeric columns detected. Ensure your value file contains numeric values.")

    label_series: Optional[pd.Series] = None
    feature_df, embedded_label = _extract_label_column(feature_df)

    if label_source == "same_file":
        if embedded_label is None:
            raise ValueError(
                "Label column not found in the uploaded file. Expected a column named 'label'."
            )
        label_series = embedded_label
    elif label_source == "separate_file":
        if label_file is None:
            raise ValueError("Please upload a label file or switch the label source option.")
        label_series = _load_label_series(label_file)
    elif label_source == "none":
        label_series = None
    else:
        raise ValueError(f"Unsupported label source option: {label_source}")

    # A file holding only a label column leaves nothing to score.
    if feature_df.shape[1] == 0:
        raise ValueError(
            "No feature columns found. Ensure your value file contains at least one numeric column besides the label."
        )

    if feature_columns:
        # Match on the string form so integer column labels (from .npy
        # files) can be selected with the names typed into the UI.
        by_name = {str(col): col for col in feature_df.columns}
        missing = [str(col) for col in feature_columns if str(col) not in by_name]
        if missing:
            raise ValueError(f"Selected columns not found in the value file: {', '.join(missing)}")
        feature_df = feature_df[[by_name[str(col)] for col in feature_columns]]

    feature_df = feature_df.reset_index(drop=True)
    if label_series is not None:
        label_series = label_series.reset_index(drop=True)
        if len(label_series) != len(feature_df):
            # Align by truncating both to the shorter length.
            min_length = min(len(label_series), len(feature_df))
            label_series = label_series.iloc[:min_length].reset_index(drop=True)
            feature_df = feature_df.iloc[:min_length, :].reset_index(drop=True)

    array = feature_df.to_numpy(dtype=np.float32)
    if array.ndim == 1:
        array = array.reshape(-1, 1)

    return feature_df, array, label_series


def _metrics_to_dataframe(metrics: dict[str, float]) -> pd.DataFrame:
    """Convert a metric-name -> value mapping into a two-column display frame.

    Values are cast to ``float`` and rounded to 4 decimals; an empty mapping
    yields an empty frame with the same ``Metric``/``Value`` columns.
    """
    if not metrics:
        return pd.DataFrame({"Metric": [], "Value": []})
    return pd.DataFrame(
        {
            "Metric": list(metrics.keys()),
            "Value": [round(float(value), 4) for value in metrics.values()],
        }
    )


def infer(
    file_obj,
    is_multivariate: bool,
    window_size: int,
    batch_size: int,
    multi_size: str,
    feature_columns: List[str],
    label_source: LabelSource,
    label_file,
) -> Tuple[str, pd.DataFrame, plt.Figure, pd.DataFrame]:
    """Run Time-RCD inference and produce outputs for the Gradio UI.

    Returns:
        ``(summary, result_df, figure, metrics_df)``: a text summary naming
        the five highest-scoring indices, the features with
        ``anomaly_score``/``anomaly_logit`` (and ``label`` when available)
        appended, the plot from ``build_plot``, and a metrics table (or an
        informational row when no labels were supplied).

    Raises:
        ValueError: if loading fails or the number of feature columns does
            not match the selected univariate/multivariate mode.
    """
    ensure_checkpoints()
    numeric_df, array, labels = load_timeseries(
        file_obj, feature_columns or None, label_source=label_source, label_file=label_file
    )

    num_features = array.shape[1] if array.ndim > 1 else 1
    if is_multivariate and num_features == 1:
        raise ValueError(
            "Dataset check: only one feature column found, so please switch the Data type to 'Univariate' or upload a multivariate file with multiple feature columns."
        )
    if not is_multivariate and num_features > 1:
        raise ValueError(
            "Dataset check: multiple feature columns detected, so please switch the Data type to 'Multivariate' or provide a univariate file with a single feature column."
        )

    kwargs = {
        "Multi": is_multivariate,
        "win_size": window_size,
        "batch_size": batch_size,
        "random_mask": "random_mask",
        "size": multi_size,
        "device": "cpu",
    }

    scores, logits = run_Time_RCD(array, **kwargs)
    score_vector = np.asarray(scores).reshape(-1)
    logit_vector = np.asarray(logits).reshape(-1)

    # Model output, features and labels may differ in length; keep the
    # common prefix so every column of the result lines up.
    valid_length = min(len(score_vector), len(numeric_df))
    if labels is not None:
        valid_length = min(valid_length, len(labels))

    result_df = numeric_df.iloc[:valid_length, :].copy()
    score_series = pd.Series(score_vector[:valid_length], index=result_df.index, name="anomaly_score")
    logit_series = pd.Series(logit_vector[:valid_length], index=result_df.index, name="anomaly_logit")
    result_df["anomaly_score"] = score_series
    result_df["anomaly_logit"] = logit_series

    metrics_df: pd.DataFrame
    if labels is not None:
        label_series = labels.iloc[:valid_length]
        result_df["label"] = label_series.to_numpy()
        metrics = get_metrics(score_series.to_numpy(), label_series.to_numpy())
        metrics_df = _metrics_to_dataframe(metrics)
    else:
        metrics_df = pd.DataFrame({"Metric": ["Info"], "Value": ["Labels not provided; metrics skipped."]})

    top_indices = score_series.nlargest(5).index.tolist()
    highlight_message = (
        "Top anomaly indices (by score): " + ", ".join(str(idx) for idx in top_indices)
        if len(top_indices) > 0
        else "No anomalies detected."
    )
    if labels is None:
        highlight_message += " Metrics skipped due to missing labels."

    figure = build_plot(result_df)

    return highlight_message, result_df, figure, metrics_df


def build_plot(result_df: pd.DataFrame) -> plt.Figure:
    """Create a matplotlib plot of the first feature vs. anomaly score.

    The first column that is not ``anomaly_score``/``anomaly_logit``/``label``
    is drawn on the left axis, rows with ``label > 0`` are marked on it, and
    ``anomaly_score`` is drawn on a twin right axis.

    Raises:
        ValueError: if ``result_df`` contains no feature column to plot.
    """
    feature_cols = [
        col for col in result_df.columns if col not in {"anomaly_score", "anomaly_logit", "label"}
    ]
    if not feature_cols:
        raise ValueError("No feature column available to plot.")
    primary_col = feature_cols[0]

    # constrained_layout handles spacing; tight_layout() must not be called
    # on top of it, since the two layout engines conflict.
    fig, ax_primary = plt.subplots(
        figsize=(12, 4),  # wider canvas
        dpi=200,  # higher resolution
        constrained_layout=True,
    )

    index = result_df.index

    ax_primary.plot(
        index,
        result_df[primary_col],
        label=f"{primary_col}",
        color="#1f77b4",
        linewidth=1.0,
    )

    if "label" in result_df.columns:
        anomalies = result_df[result_df["label"] > 0]
        if not anomalies.empty:
            ax_primary.scatter(
                anomalies.index,
                anomalies[primary_col],
                label="Label = 1",
                color="#ff7f0e",
                marker="o",
                s=30,
                alpha=0.85,
            )

    ax_primary.set_xlabel("Index")
    ax_primary.set_ylabel("Value")
    ax_primary.grid(alpha=0.2)

    ax_secondary = ax_primary.twinx()
    ax_secondary.plot(
        index,
        result_df["anomaly_score"],
        label="Anomaly Score",
        color="#d62728",
        linewidth=1.0,
    )
    ax_secondary.set_ylabel("Anomaly Score")

    # Merge both axes' entries into a single legend.
    handles_primary, labels_primary = ax_primary.get_legend_handles_labels()
    handles_secondary, labels_secondary = ax_secondary.get_legend_handles_labels()
    ax_primary.legend(
        handles_primary + handles_secondary,
        labels_primary + labels_secondary,
        loc="upper right",
    )

    # Drop pyplot's global reference so figures do not accumulate across
    # requests in the long-running server; the Figure object itself stays
    # usable for rendering by the caller.
    plt.close(fig)
    return fig


def build_interface() -> gr.Blocks:
    """Define the Gradio UI.

    Builds the dataset/label selectors, model parameter controls and output
    widgets, and wires the run button to ``infer``.
    """
    with gr.Blocks(title="Time-RCD Zero-Shot Anomaly Detection") as demo:
        gr.Markdown(
            "# Time-RCD Zero-Shot Anomaly Detection\n"
            "Start with one of the bundled datasets or upload your own time series to run zero-shot anomaly detection."
        )

        bundled_choices = list(SAMPLE_FILES.keys())
        default_choice = bundled_choices[0] if bundled_choices else "Upload my own"

        data_selector = gr.Radio(
            choices=bundled_choices + ["Upload my own"],
            value=default_choice,
            label="Choose dataset",
        )

        with gr.Row():
            file_input = gr.File(
                label="Upload time series file (.csv, .txt, .npy)",
                file_types=[".csv", ".txt", ".npy"],
                visible=default_choice == "Upload my own",
            )
            label_source = gr.Radio(
                choices=list(LABEL_SOURCE_CHOICES.keys()),
                value="Value + label in same file",
                label="Label source",
            )

        with gr.Row():
            label_file_input = gr.File(
                label="Upload label file (.csv, .txt, .npy)",
                file_types=[".csv", ".txt", ".npy"],
                visible=False,
            )
            column_selector = gr.Textbox(
                label="Columns to use (comma-separated, optional)",
                placeholder="e.g. value,feature_1,feature_2",
            )

        # Name the folder the samples are actually read from.
        gr.Markdown(
            f"Bundled datasets live in the `{SAMPLE_DATASET_DIR.name}` folder and include labels unless noted. "
            "Select \"Upload my own\" to provide a custom file."
        )

        with gr.Row():
            multivariate = gr.Radio(
                choices=["Univariate", "Multivariate"],
                value=(
                    "Multivariate"
                    if bundled_choices and SAMPLE_FILES[default_choice]["is_multivariate"]
                    else "Univariate"
                ),
                label="Data type",
            )
            window_size_in = gr.Slider(
                minimum=128,
                maximum=20000,
                value=15000,
                step=128,
                label="Window size",
            )
            batch_size_in = gr.Slider(
                minimum=1,
                maximum=128,
                value=16,
                step=1,
                label="Batch size",
            )

        with gr.Row():
            multi_size_in = gr.Radio(
                choices=["full", "small"],
                value="full",
                label="Multivariate model size",
            )

        run_button = gr.Button("Run Inference", variant="primary")

        result_message = gr.Textbox(label="Summary", interactive=False)
        result_dataframe = gr.DataFrame(label="Anomaly Scores", interactive=False)
        plot_output = gr.Plot(label="Series vs. Anomaly Score")
        metrics_output = gr.DataFrame(label="Metrics", interactive=False)

        def _submit(
            data_choice,
            file_obj,
            label_source_choice,
            label_file_obj,
            multivariate_choice,
            win,
            batch,
            size,
            columns_text,
        ):
            """Translate the widget values into an ``infer`` call."""
            use_sample = data_choice != "Upload my own"
            if use_sample:
                sample_entry = SAMPLE_FILES[data_choice]
                value_obj = sample_entry["path"]
            else:
                value_obj = file_obj

            if value_obj is None:
                raise gr.Error("Please upload a time series file or choose a sample.")

            feature_columns = [col.strip() for col in columns_text.split(",") if col.strip()] if columns_text else []
            is_multi = multivariate_choice == "Multivariate"
            resolved_label_source = LABEL_SOURCE_CHOICES[label_source_choice]
            if resolved_label_source == "separate_file" and label_file_obj is None:
                raise gr.Error("Please upload a label file or change the label source option.")

            try:
                summary, df, fig, metrics = infer(
                    file_obj=value_obj,
                    is_multivariate=is_multi,
                    window_size=int(win),
                    batch_size=int(batch),
                    multi_size=size,
                    feature_columns=feature_columns,
                    label_source=resolved_label_source,
                    label_file=label_file_obj,
                )
            except ValueError as exc:
                # Validation problems are reported as ValueError; re-raise
                # them as gr.Error so the message is shown to the user.
                raise gr.Error(str(exc)) from exc
            return summary, df, fig, metrics

        def _toggle_label_file(option):
            """Show the label upload only for the separate-file option."""
            return gr.update(visible=option == "Labels in separate file")

        def _handle_dataset_choice(choice):
            """Toggle the upload widget and preset the data type for samples."""
            show_upload = choice == "Upload my own"
            if choice == "Upload my own":
                multi_update = gr.update()
            else:
                expected_multi = SAMPLE_FILES[choice]["is_multivariate"]
                multi_update = gr.update(value="Multivariate" if expected_multi else "Univariate")
            return gr.update(visible=show_upload), multi_update

        label_source.change(fn=_toggle_label_file, inputs=label_source, outputs=label_file_input)
        data_selector.change(fn=_handle_dataset_choice, inputs=data_selector, outputs=[file_input, multivariate])

        run_button.click(
            fn=_submit,
            inputs=[
                data_selector,
                file_input,
                label_source,
                label_file_input,
                multivariate,
                window_size_in,
                batch_size_in,
                multi_size_in,
                column_selector,
            ],
            outputs=[result_message, result_dataframe, plot_output, metrics_output],
        )

    return demo


demo = build_interface()

if __name__ == "__main__":
    demo.launch()