from typing import Union

import pandas as pd

from .outbreak_detection import (
    LSTMforOutbreakDetection,
    ARIMAforOutbreakDetection,
    IQRforOutbreakDetection,
)
from .plotting.visualization import plot_anomalies
from .utils import prepare_time_series_dataframe

# Human-readable threshold-strategy labels mapped to the integer code passed to
# LSTMforOutbreakDetection as its ``threshold_method`` argument.
THRESHOLD_METHODS = {
    "IQR on (ground truth - forecast)": 0,
    "IQR on |ground truth - forecast|": 1,
    "IQR on |ground truth - forecast|/forecast": 2,
    "Percentile threshold on absolute loss": 3,
    "Percentile threshold on raw loss": 4,
}

# Default weights handed to the LSTM detector.
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory rather than this package's location — confirm intended.
LSTM_CHECKPOINT_PATH = "models/lstm_forec_40_11_06.pth"

# Detection methods accepted by ``detect_anomalies``.
SUPPORTED_METHODS = ("LSTM", "ARIMA", "IQR")


def _resolve_threshold_method(threshold_method: Union[str, int]) -> int:
    """Return the integer threshold code for a label or an already-valid code.

    Args:
        threshold_method: Either a key of THRESHOLD_METHODS or one of its values.

    Returns:
        The integer code understood by the LSTM detector.

    Raises:
        KeyError: if ``threshold_method`` is neither a known label nor a known code.
    """
    if isinstance(threshold_method, str):
        if threshold_method in THRESHOLD_METHODS:
            return THRESHOLD_METHODS[threshold_method]
    elif threshold_method in THRESHOLD_METHODS.values():
        # Callers that already hold the numeric code (as the original ``int``
        # annotation suggested) can pass it straight through.
        return int(threshold_method)
    raise KeyError(
        f"Unknown threshold_method {threshold_method!r}; expected one of "
        f"{list(THRESHOLD_METHODS)} or a code in "
        f"{sorted(THRESHOLD_METHODS.values())}"
    )


def _build_detector(
    method: str,
    k: int,
    percentile: float,
    threshold_method: Union[str, int],
    checkpoint_path: str,
):
    """Construct only the detector selected by ``method``.

    Building detectors lazily means an 'ARIMA' or 'IQR' request never
    instantiates the LSTM detector (or validates its threshold argument).

    Raises:
        KeyError: if ``method`` is not in SUPPORTED_METHODS, or if ``method``
            is 'LSTM' and ``threshold_method`` is unknown.
    """
    if method == "LSTM":
        return LSTMforOutbreakDetection(
            checkpoint_path=checkpoint_path,
            k=k,
            percentile=percentile,
            threshold_method=_resolve_threshold_method(threshold_method),
        )
    if method == "ARIMA":
        return ARIMAforOutbreakDetection(k=k)
    if method == "IQR":
        return IQRforOutbreakDetection(k=k)
    # KeyError kept for compatibility with the original dict-lookup behavior.
    raise KeyError(
        f"Unknown detection method {method!r}; expected one of {list(SUPPORTED_METHODS)}"
    )


def detect_anomalies(
    file_path: str,
    method: str,
    k: int,
    percentile: float,
    threshold_method: Union[str, int],
    checkpoint_path: str = LSTM_CHECKPOINT_PATH,
):
    """Detect anomalies in a CSV time series and return a plot of the result.

    Args:
        file_path: Path to the CSV file containing the time series. It is read
            with ``pandas.read_csv`` and then passed through
            ``prepare_time_series_dataframe``.
        method: Detector to use: 'LSTM', 'ARIMA' or 'IQR'.
        k: Method-dependent parameter forwarded to whichever detector is chosen
            (previously documented as a number of neighbors or window size).
        percentile: Percentile threshold. Only the 'LSTM' detector uses it.
        threshold_method: Thresholding strategy for the 'LSTM' detector. Pass
            either a label from THRESHOLD_METHODS (e.g. "Percentile threshold
            on raw loss") or its integer code. Ignored by 'ARIMA' and 'IQR'.
        checkpoint_path: Weights path given to the 'LSTM' detector. Defaults to
            LSTM_CHECKPOINT_PATH.

    Returns:
        The figure produced by ``plot_anomalies`` (documented as a
        ``plotly.graph_objects.Figure``) with the detected anomalies highlighted.

    Raises:
        KeyError: if ``method`` is unknown, or if ``method`` is 'LSTM' and
            ``threshold_method`` is unknown.
    """
    # Validate arguments and build the detector before reading the CSV, so
    # bad arguments fail fast.
    detector = _build_detector(method, k, percentile, threshold_method, checkpoint_path)

    df = prepare_time_series_dataframe(pd.read_csv(file_path))

    # The detector returns the (possibly augmented) frame plus the name of the
    # column holding its anomaly labels.
    test, new_label = detector.detect_anomalies(df)
    return plot_anomalies(test, anomaly_col=new_label)