# Time_RCD / app.py
# Oliver Le
# update the plot resolution
# 51ed775
import io
import zipfile
from pathlib import Path
from typing import List, Tuple, Literal, Optional
from evaluation.metrics import get_metrics
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from model_wrapper import run_Time_RCD
REPO_ID = "thu-sail-lab/Time-RCD"

# Checkpoint files that must exist (relative to the current working directory)
# before inference can run.
CHECKPOINT_FILES = [
    "checkpoints/full_mask_anomaly_head_pretrain_checkpoint_best.pth",
    "checkpoints/dataset_10_20.pth",
    "checkpoints/full_mask_10_20.pth",
    "checkpoints/dataset_15_56.pth",
    "checkpoints/full_mask_15_56.pth",
]


def ensure_checkpoints() -> None:
    """Make sure every file in CHECKPOINT_FILES exists locally.

    If all of them are already present this is a no-op. Otherwise
    ``checkpoints.zip`` is downloaded from the Hugging Face Hub repo REPO_ID
    (trying it as a model repo first and, on HfHubHTTPError, as a dataset
    repo) into ``.cache/hf`` and extracted into the current working directory.
    """
    if all(Path(relative).exists() for relative in CHECKPOINT_FILES):
        return

    def _fetch_archive(repo_kind: str) -> str:
        # Same archive name and cache location for both repo kinds.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename="checkpoints.zip",
            repo_type=repo_kind,
            cache_dir=".cache/hf",
        )

    try:
        archive_path = _fetch_archive("model")
    except HfHubHTTPError:
        archive_path = _fetch_archive("dataset")

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(".")
BASE_DIR = Path(__file__).resolve().parent
SAMPLE_DATASET_DIR = BASE_DIR / "sample_datasets"

# Internal identifiers for where ground-truth labels come from; these are the
# values of LABEL_SOURCE_CHOICES below.
LabelSource = Literal["same_file", "separate_file", "none"]

# Column names treated as the label column (matched after lower-casing).
LABEL_COLUMN_CANDIDATES = ("label", "labels")

# UI radio caption -> LabelSource identifier.
LABEL_SOURCE_CHOICES = {
    "Value + label in same file": "same_file",
    "Labels in separate file": "separate_file",
    "No labels provided": "none",
}

# Bundled samples keyed by UI caption; each entry records the file location
# under SAMPLE_DATASET_DIR and whether the series is multivariate.
SAMPLE_FILES: dict[str, dict[str, object]] = {
    caption: {"path": SAMPLE_DATASET_DIR / file_name, "is_multivariate": multivariate}
    for caption, file_name, multivariate in (
        ("Sample: Univariate SED Medical", "235_SED_id_2_Medical_tr_2499_1st_3840.csv", False),
        ("Sample: Univariate UCR Medical", "353_UCR_id_51_Medical_tr_1875_1st_3198.csv", False),
        ("Sample: Univariate Yahoo WebService", "686_YAHOO_id_136_WebService_tr_500_1st_755.csv", False),
        # ("Sample: Multivariate MSL Sensor", "003_MSL_id_2_Sensor_tr_883_1st_1238.csv", True),
    )
}
def _resolve_path(file_obj) -> Path:
"""Extract a pathlib.Path from the gradio file object."""
if file_obj is None:
raise ValueError("File object is None.")
if isinstance(file_obj, Path):
return file_obj
if isinstance(file_obj, str):
path = Path(file_obj)
if not path.is_absolute():
path = (BASE_DIR / path).resolve()
return path
# Gradio may pass dictionaries or objects with a 'name' attribute.
if isinstance(file_obj, dict) and "name" in file_obj:
return _resolve_path(file_obj["name"])
name = getattr(file_obj, "name", None)
if not name:
raise ValueError("Unable to resolve uploaded file path.")
return _resolve_path(name)
def _load_dataframe(path: Path) -> pd.DataFrame:
"""Load a dataframe from supported file types."""
if not path.exists():
raise ValueError(f"File not found: {path}. If this is a bundled sample, ensure it exists under {SAMPLE_DATASET_DIR}.")
suffix = path.suffix.lower()
if suffix == ".npy":
data = np.load(path, allow_pickle=False)
if data.ndim == 1:
data = data.reshape(-1, 1)
if not isinstance(data, np.ndarray):
raise ValueError("Loaded .npy data is not a numpy array.")
return pd.DataFrame(data)
if suffix not in {".csv", ".txt"}:
raise ValueError("Unsupported file type. Please upload a .csv, .txt, or .npy file.")
return pd.read_csv(path)
def _extract_label_column(df: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
"""Split a label column from dataframe if one of the candidate names exists."""
lower_to_original = {col.lower(): col for col in df.columns}
label_col = None
for candidate in LABEL_COLUMN_CANDIDATES:
if candidate in lower_to_original:
label_col = lower_to_original[candidate]
break
if label_col is None:
return df, None
label_series = pd.to_numeric(df[label_col], errors="raise")
feature_df = df.drop(columns=[label_col])
return feature_df, label_series
def _load_label_series(file_obj) -> pd.Series:
    """Load ground-truth labels from a dedicated label-file upload.

    Only numeric columns are considered. A column whose lower-cased name is in
    LABEL_COLUMN_CANDIDATES wins. Otherwise the file must contain exactly one
    numeric column, which is used as the labels. The returned series is named
    "label".

    Raises:
        ValueError: if the file has no numeric columns, or several numeric
            columns none of which is a label column; also whatever
            ``_resolve_path`` / ``_load_dataframe`` / ``pd.to_numeric`` raise.
    """
    path = _resolve_path(file_obj)
    df = _load_dataframe(path)
    numeric_df = df.select_dtypes(include=np.number)
    if numeric_df.empty:
        raise ValueError("Uploaded label file does not contain numeric columns.")
    # str() keeps non-string headers (e.g. integer column names) from crashing .lower().
    lower_to_original = {str(col).lower(): col for col in numeric_df.columns}
    for candidate in LABEL_COLUMN_CANDIDATES:
        if candidate in lower_to_original:
            column = lower_to_original[candidate]
            return pd.to_numeric(numeric_df[column], errors="raise").rename("label")
    if numeric_df.shape[1] > 1:
        raise ValueError(
            "Label file must contain exactly one numeric column or include a column named 'label'."
        )
    return pd.to_numeric(numeric_df.iloc[:, 0], errors="raise").rename("label")
def load_timeseries(
    value_file,
    feature_columns: List[str] | None,
    label_source: LabelSource,
    label_file=None,
) -> Tuple[pd.DataFrame, np.ndarray, Optional[pd.Series]]:
    """Load the uploaded value file and optional labels; return features/labels.

    Only numeric columns of the value file are kept. A column named like a
    label (see ``_extract_label_column``) is always removed from the features.
    It is used as the labels only when ``label_source == "same_file"``. If
    labels and features differ in length, both are truncated to the shorter
    length.

    Args:
        value_file: Path or Gradio upload object for the value file.
        feature_columns: Numeric columns to keep, in order; None/empty keeps all.
            Repeated names are kept once.
        label_source: "same_file", "separate_file", or "none".
        label_file: Label upload; required when label_source is "separate_file".

    Returns:
        ``(feature_df, array, label_series)``: the feature table (0-based
        index), its float32 2-D matrix, and the labels (or None).

    Raises:
        ValueError: for missing/unsupported files, a missing label column or
            label file, an unknown label_source, unknown selected columns, or
            when no feature columns or no rows remain.
    """
    value_path = _resolve_path(value_file)
    raw_df = _load_dataframe(value_path)
    feature_df = raw_df.select_dtypes(include=np.number)
    if feature_df.empty:
        raise ValueError("No numeric columns detected. Ensure your value file contains numeric values.")
    label_series: Optional[pd.Series] = None
    # The label column is never fed to the model, whatever the label source is.
    feature_df, embedded_label = _extract_label_column(feature_df)
    if label_source == "same_file":
        if embedded_label is None:
            raise ValueError(
                "Label column not found in the uploaded file. Expected a column named 'label'."
            )
        label_series = embedded_label
    elif label_source == "separate_file":
        if label_file is None:
            raise ValueError("Please upload a label file or switch the label source option.")
        label_series = _load_label_series(label_file)
    elif label_source == "none":
        label_series = None
    else:
        raise ValueError(f"Unsupported label source option: {label_source}")
    if feature_columns:
        # De-duplicate while keeping order so a column typed twice isn't fed twice.
        feature_columns = list(dict.fromkeys(feature_columns))
        missing = [col for col in feature_columns if col not in feature_df.columns]
        if missing:
            raise ValueError(f"Selected columns not found in the value file: {', '.join(missing)}")
        feature_df = feature_df[feature_columns]
    if feature_df.shape[1] == 0:
        # Happens when the label was the only numeric column.
        raise ValueError(
            "No feature columns left after removing the label column. "
            "Ensure your value file contains at least one numeric feature column."
        )
    feature_df = feature_df.reset_index(drop=True)
    if label_series is not None:
        label_series = label_series.reset_index(drop=True)
        if len(label_series) != len(feature_df):
            # Reconcile mismatched lengths by truncating both to the shorter one.
            min_length = min(len(label_series), len(feature_df))
            label_series = label_series.iloc[:min_length].reset_index(drop=True)
            feature_df = feature_df.iloc[:min_length, :].reset_index(drop=True)
    if len(feature_df) == 0:
        raise ValueError("No rows to analyse: the value file or the label file is empty.")
    # DataFrame.to_numpy always returns a 2-D (rows, columns) array.
    array = feature_df.to_numpy(dtype=np.float32)
    return feature_df, array, label_series
def _metrics_to_dataframe(metrics: dict[str, float]) -> pd.DataFrame:
if not metrics:
return pd.DataFrame({"Metric": [], "Value": []})
return pd.DataFrame(
{
"Metric": list(metrics.keys()),
"Value": [round(float(value), 4) for value in metrics.values()],
}
)
def infer(
    file_obj,
    is_multivariate: bool,
    window_size: int,
    batch_size: int,
    multi_size: str,
    feature_columns: List[str],
    label_source: LabelSource,
    label_file,
) -> Tuple[str, pd.DataFrame, plt.Figure, pd.DataFrame]:
    """Run Time-RCD on one dataset and package the results for the Gradio UI.

    Steps: make sure checkpoints exist, load features/labels, check that the
    feature count agrees with ``is_multivariate``, score on CPU with
    ``run_Time_RCD``, then build the result table, metrics, summary and plot.

    Returns:
        ``(summary_text, result_df, figure, metrics_df)``. ``result_df`` holds
        the features plus "anomaly_score", "anomaly_logit" and (with labels)
        "label", truncated to the shortest of scores/features/labels.

    Raises:
        ValueError: when the feature count contradicts ``is_multivariate``, and
            whatever ``load_timeseries`` raises.
    """
    ensure_checkpoints()
    numeric_df, array, labels = load_timeseries(
        file_obj, feature_columns or None, label_source=label_source, label_file=label_file
    )

    n_features = 1 if array.ndim <= 1 else array.shape[1]
    if is_multivariate and n_features == 1:
        raise ValueError(
            "Dataset check: only one feature column found, so please switch the Data type to 'Univariate' or upload a multivariate file with multiple feature columns."
        )
    if n_features > 1 and not is_multivariate:
        raise ValueError(
            "Dataset check: multiple feature columns detected, so please switch the Data type to 'Multivariate' or provide a univariate file with a single feature column."
        )

    scores, logits = run_Time_RCD(
        array,
        Multi=is_multivariate,
        win_size=window_size,
        batch_size=batch_size,
        random_mask="random_mask",
        size=multi_size,
        device="cpu",
    )
    flat_scores = np.asarray(scores).reshape(-1)
    flat_logits = np.asarray(logits).reshape(-1)

    # Keep only rows for which every available source has a value.
    lengths = [len(flat_scores), len(numeric_df)]
    if labels is not None:
        lengths.append(len(labels))
    keep = min(lengths)

    result_df = numeric_df.iloc[:keep, :].copy()
    row_index = result_df.index
    score_series = pd.Series(flat_scores[:keep], index=row_index, name="anomaly_score")
    logit_series = pd.Series(flat_logits[:keep], index=row_index, name="anomaly_logit")
    result_df["anomaly_score"] = score_series
    result_df["anomaly_logit"] = logit_series

    if labels is None:
        metrics_df = pd.DataFrame({"Metric": ["Info"], "Value": ["Labels not provided; metrics skipped."]})
    else:
        truth = labels.iloc[:keep]
        result_df["label"] = truth.to_numpy()
        metrics_df = _metrics_to_dataframe(get_metrics(score_series.to_numpy(), truth.to_numpy()))

    top_indices = score_series.nlargest(5).index.tolist()
    if top_indices:
        summary = "Top anomaly indices (by score): " + ", ".join(map(str, top_indices))
    else:
        summary = "No anomalies detected."
    if labels is None:
        summary += " Metrics skipped due to missing labels."

    figure = build_plot(result_df)
    return summary, result_df, figure, metrics_df
def build_plot(result_df: pd.DataFrame) -> plt.Figure:
    """Create a matplotlib plot of the first feature vs. anomaly score.

    The first column that is not "anomaly_score", "anomaly_logit" or "label"
    is drawn on the left axis. Rows with ``label > 0`` are marked on it. The
    anomaly score is drawn on a twin right axis, and one legend covers both
    axes. If there is no feature column, only the score is drawn.

    Args:
        result_df: Feature columns plus "anomaly_score" (and optionally
            "anomaly_logit" / "label").

    Returns:
        The figure (not registered with pyplot).
    """
    # plt.Figure is matplotlib.figure.Figure. Constructing it directly, instead
    # of via plt.subplots, keeps it out of pyplot's global figure registry:
    # a new figure is made on every inference and nothing calls plt.close,
    # so registered figures would accumulate for the life of the process.
    fig = plt.Figure(
        figsize=(12, 4),  # wider canvas
        dpi=200,  # higher resolution
        constrained_layout=True,  # the only layout engine; tight_layout is not also applied
    )
    ax_primary = fig.subplots()
    index = result_df.index
    reserved = {"anomaly_score", "anomaly_logit", "label"}
    feature_cols = [col for col in result_df.columns if col not in reserved]
    if feature_cols:
        primary_col = feature_cols[0]
        ax_primary.plot(
            index,
            result_df[primary_col],
            label=f"{primary_col}",
            color="#1f77b4",
            linewidth=1.0,
        )
        if "label" in result_df.columns:
            # Mark rows with a positive label on the feature trace.
            anomalies = result_df[result_df["label"] > 0]
            if not anomalies.empty:
                ax_primary.scatter(
                    anomalies.index,
                    anomalies[primary_col],
                    label="Label = 1",
                    color="#ff7f0e",
                    marker="o",
                    s=30,
                    alpha=0.85,
                )
    ax_primary.set_xlabel("Index")
    ax_primary.set_ylabel("Value")
    ax_primary.grid(alpha=0.2)
    ax_secondary = ax_primary.twinx()
    ax_secondary.plot(
        index,
        result_df["anomaly_score"],
        label="Anomaly Score",
        color="#d62728",
        linewidth=1.0,
    )
    ax_secondary.set_ylabel("Anomaly Score")
    # Merge both axes' entries into a single legend.
    handles_primary, labels_primary = ax_primary.get_legend_handles_labels()
    handles_secondary, labels_secondary = ax_secondary.get_legend_handles_labels()
    ax_primary.legend(
        handles_primary + handles_secondary,
        labels_primary + labels_secondary,
        loc="upper right",
    )
    # No fig.tight_layout() here: it would swap out the constrained layout
    # engine (matplotlib warns "The figure layout has changed to tight").
    return fig
def build_interface() -> gr.Blocks:
    """Define the Gradio UI.

    The UI contains:
      * a dataset selector (bundled samples or "Upload my own");
      * upload widgets for the value file and an optional label file;
      * data-type, window, batch and model-size controls;
      * outputs for the summary, per-row scores, plot and metrics.

    Validation errors raised by ``infer`` are shown to the user as ``gr.Error``.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="Time-RCD Zero-Shot Anomaly Detection") as demo:
        gr.Markdown(
            "# Time-RCD Zero-Shot Anomaly Detection\n"
            "Start with one of the bundled datasets or upload your own time series to run zero-shot anomaly detection."
        )
        bundled_choices = list(SAMPLE_FILES.keys())
        default_choice = bundled_choices[0] if bundled_choices else "Upload my own"
        data_selector = gr.Radio(
            choices=bundled_choices + ["Upload my own"],
            value=default_choice,
            label="Choose dataset",
        )
        with gr.Row():
            file_input = gr.File(
                label="Upload time series file (.csv, .txt, .npy)",
                file_types=[".csv", ".txt", ".npy"],
                visible=default_choice == "Upload my own",
            )
            label_source = gr.Radio(
                choices=list(LABEL_SOURCE_CHOICES.keys()),
                value="Value + label in same file",
                label="Label source",
            )
        with gr.Row():
            label_file_input = gr.File(
                label="Upload label file (.csv, .txt, .npy)",
                file_types=[".csv", ".txt", ".npy"],
                visible=False,
            )
            column_selector = gr.Textbox(
                label="Columns to use (comma-separated, optional)",
                placeholder="e.g. value,feature_1,feature_2",
            )
        # The bundled files are read from SAMPLE_DATASET_DIR ("sample_datasets").
        gr.Markdown(
            "Bundled datasets live in the `sample_datasets` folder and include labels unless noted. "
            "Select \"Upload my own\" to provide a custom file."
        )
        with gr.Row():
            multivariate = gr.Radio(
                choices=["Univariate", "Multivariate"],
                value=(
                    "Multivariate"
                    if bundled_choices and SAMPLE_FILES[default_choice]["is_multivariate"]
                    else "Univariate"
                ),
                label="Data type",
            )
            window_size_in = gr.Slider(
                minimum=128,
                maximum=20000,
                value=15000,
                step=128,
                label="Window size",
            )
            batch_size_in = gr.Slider(
                minimum=1,
                maximum=128,
                value=16,
                step=1,
                label="Batch size",
            )
        with gr.Row():
            multi_size_in = gr.Radio(
                choices=["full", "small"],
                value="full",
                label="Multivariate model size",
            )
        run_button = gr.Button("Run Inference", variant="primary")
        result_message = gr.Textbox(label="Summary", interactive=False)
        result_dataframe = gr.DataFrame(label="Anomaly Scores", interactive=False)
        plot_output = gr.Plot(label="Series vs. Anomaly Score")
        metrics_output = gr.DataFrame(label="Metrics", interactive=False)

        def _submit(
            data_choice,
            file_obj,
            label_source_choice,
            label_file_obj,
            multivariate_choice,
            win,
            batch,
            size,
            columns_text,
        ):
            """Collect the widget values, run ``infer`` and return its four outputs."""
            use_sample = data_choice != "Upload my own"
            if use_sample:
                sample_entry = SAMPLE_FILES[data_choice]
                value_obj = sample_entry["path"]
            else:
                value_obj = file_obj
            if value_obj is None:
                raise gr.Error("Please upload a time series file or choose a sample.")
            feature_columns = [col.strip() for col in columns_text.split(",") if col.strip()] if columns_text else []
            is_multi = multivariate_choice == "Multivariate"
            resolved_label_source = LABEL_SOURCE_CHOICES[label_source_choice]
            if resolved_label_source == "separate_file" and label_file_obj is None:
                raise gr.Error("Please upload a label file or change the label source option.")
            try:
                summary, df, fig, metrics = infer(
                    file_obj=value_obj,
                    is_multivariate=is_multi,
                    window_size=int(win),
                    batch_size=int(batch),
                    multi_size=size,
                    feature_columns=feature_columns,
                    label_source=resolved_label_source,
                    label_file=label_file_obj,
                )
            except ValueError as exc:
                # Gradio displays gr.Error messages to the user; other exceptions
                # may surface only as a generic error, hiding the validation
                # message (e.g. the "Dataset check: ..." hints).
                raise gr.Error(str(exc)) from exc
            return summary, df, fig, metrics

        def _toggle_label_file(option):
            """Show the label-file upload only for the separate-file label source."""
            return gr.update(visible=option == "Labels in separate file")

        def _handle_dataset_choice(choice):
            """Toggle the upload widget and preset the data type for bundled samples."""
            show_upload = choice == "Upload my own"
            if choice == "Upload my own":
                multi_update = gr.update()
            else:
                expected_multi = SAMPLE_FILES[choice]["is_multivariate"]
                multi_update = gr.update(value="Multivariate" if expected_multi else "Univariate")
            return gr.update(visible=show_upload), multi_update

        label_source.change(fn=_toggle_label_file, inputs=label_source, outputs=label_file_input)
        data_selector.change(fn=_handle_dataset_choice, inputs=data_selector, outputs=[file_input, multivariate])
        run_button.click(
            fn=_submit,
            inputs=[
                data_selector,
                file_input,
                label_source,
                label_file_input,
                multivariate,
                window_size_in,
                batch_size_in,
                multi_size_in,
                column_selector,
            ],
            outputs=[result_message, result_dataframe, plot_output, metrics_output],
        )
    return demo
# The UI is built at import time so the module-level `demo` object is
# available to anything that imports this module; it is only launched when
# the file is executed as a script.
demo = build_interface()
if __name__ == "__main__":
    demo.launch()