import logging
from typing import Any

import numpy as np
import pandas as pd
import scipy.fft as fft
import torch
from gluonts.time_feature import time_features_from_frequency_str
from gluonts.time_feature._base import (
    day_of_month,
    day_of_month_index,
    day_of_week,
    day_of_week_index,
    day_of_year,
    hour_of_day,
    hour_of_day_index,
    minute_of_hour,
    minute_of_hour_index,
    month_of_year,
    month_of_year_index,
    second_of_minute,
    second_of_minute_index,
    week_of_year,
    week_of_year_index,
)
from gluonts.time_feature.holiday import (
    BLACK_FRIDAY,
    CHRISTMAS_DAY,
    CHRISTMAS_EVE,
    CYBER_MONDAY,
    EASTER_MONDAY,
    EASTER_SUNDAY,
    GOOD_FRIDAY,
    INDEPENDENCE_DAY,
    LABOR_DAY,
    MEMORIAL_DAY,
    NEW_YEARS_DAY,
    NEW_YEARS_EVE,
    THANKSGIVING,
    SpecialDateFeatureSet,
    exponential_kernel,
    squared_exponential_kernel,
)
from gluonts.time_feature.seasonality import get_seasonality
from scipy.signal import find_peaks

from src.data.constants import BASE_END_DATE, BASE_START_DATE
from src.data.frequency import (
    Frequency,
    validate_frequency_safety,
)
from src.utils.utils import device

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Frequency-specific feature groups. "normalized" features are standard GluonTS
# time features; "index" features are integer positions that are rescaled to
# [0, 1] in _compute_enhanced_features before use.
ENHANCED_TIME_FEATURES = {
    "high_freq": {
        "normalized": [
            second_of_minute,
            minute_of_hour,
            hour_of_day,
            day_of_week,
            day_of_month,
        ],
        "index": [
            second_of_minute_index,
            minute_of_hour_index,
            hour_of_day_index,
            day_of_week_index,
        ],
    },
    "medium_freq": {
        "normalized": [
            hour_of_day,
            day_of_week,
            day_of_month,
            day_of_year,
            month_of_year,
        ],
        "index": [
            hour_of_day_index,
            day_of_week_index,
            day_of_month_index,
            week_of_year_index,
        ],
    },
    "low_freq": {
        "normalized": [day_of_week, day_of_month, month_of_year, week_of_year],
        "index": [day_of_week_index, month_of_year_index, week_of_year_index],
    },
}

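# Illustrative sketch, not part of the original module: contrasts the two
# feature flavours configured above on a short daily index. The values come
# from GluonTS; only the ranges noted in the comments are assumptions.
def _example_normalized_vs_index_features() -> None:
    idx = pd.period_range("2024-01-01", periods=14, freq="D")
    normalized = day_of_week(idx)  # floats, roughly in [-0.5, 0.5]
    integer_index = day_of_week_index(idx)  # integer-valued positions 0..6
    logger.debug("normalized day_of_week: %s", normalized)
    logger.debug("index day_of_week: %s", integer_index)
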
# Named holiday collections; each is consumed by a SpecialDateFeatureSet in
# TimeFeatureGenerator.__init__.
HOLIDAY_FEATURE_SETS = {
    "us_business": [
        NEW_YEARS_DAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    "us_retail": [
        NEW_YEARS_DAY,
        EASTER_SUNDAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        BLACK_FRIDAY,
        CYBER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    "christian": [
        NEW_YEARS_DAY,
        GOOD_FRIDAY,
        EASTER_SUNDAY,
        EASTER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
}

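# Illustrative sketch, not part of the original module: shows how one of the
# holiday sets above is turned into smoothed indicator features, mirroring the
# way TimeFeatureGenerator uses SpecialDateFeatureSet further down (one row per
# holiday, one column per timestamp).
def _example_holiday_feature_set() -> None:
    dates = pd.date_range("2024-12-20", periods=10, freq="D")
    feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS["christian"], exponential_kernel(1.0))
    holiday_matrix = feature_set(dates)  # shape: (num_holidays, len(dates))
    logger.debug("holiday feature matrix shape: %s", holiday_matrix.shape)
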
class TimeFeatureGenerator:
    """
    Enhanced time feature generator that leverages full GluonTS capabilities.
    """

    def __init__(
        self,
        use_enhanced_features: bool = True,
        use_holiday_features: bool = True,
        holiday_set: str = "us_business",
        holiday_kernel: str = "exponential",
        holiday_kernel_alpha: float = 1.0,
        use_index_features: bool = True,
        k_max: int = 15,
        include_seasonality_info: bool = True,
        use_auto_seasonality: bool = False,
        max_seasonal_periods: int = 3,
    ):
        """
        Initialize the enhanced time feature generator.

        Parameters
        ----------
        use_enhanced_features : bool
            Whether to use frequency-specific enhanced features.
        use_holiday_features : bool
            Whether to include holiday features.
        holiday_set : str
            Which holiday set to use ('us_business', 'us_retail', 'christian').
        holiday_kernel : str
            Holiday kernel type ('indicator', 'exponential', 'squared_exponential').
        holiday_kernel_alpha : float
            Kernel parameter for the exponential kernels.
        use_index_features : bool
            Whether to include index-based features alongside normalized ones.
        k_max : int
            Maximum number of time features to pad to.
        include_seasonality_info : bool
            Whether to include seasonality information as features.
        use_auto_seasonality : bool
            Whether to use automatic FFT-based seasonality detection.
        max_seasonal_periods : int
            Maximum number of seasonal periods to detect automatically.
        """
        self.use_enhanced_features = use_enhanced_features
        self.use_holiday_features = use_holiday_features
        self.holiday_set = holiday_set
        self.use_index_features = use_index_features
        self.k_max = k_max
        self.include_seasonality_info = include_seasonality_info
        self.use_auto_seasonality = use_auto_seasonality
        self.max_seasonal_periods = max_seasonal_periods

        # Build the holiday feature set up front; an unrecognized holiday_set
        # silently disables holiday features.
        self.holiday_feature_set = None
        if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
            kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
            self.holiday_feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS[holiday_set], kernel_func)

    def _get_holiday_kernel(self, kernel_type: str, alpha: float):
        """Return the holiday kernel function for the requested kernel type."""
        if kernel_type == "exponential":
            return exponential_kernel(alpha)
        elif kernel_type == "squared_exponential":
            return squared_exponential_kernel(alpha)
        else:
            # Fall back to a plain indicator kernel (1.0 exactly on the holiday).
            return lambda x: float(x == 0)

    def _get_feature_category(self, freq_str: str) -> str:
        """Determine the feature category based on frequency."""
        if freq_str in ["s", "1min", "5min", "10min", "15min"]:
            return "high_freq"
        elif freq_str in ["h", "D"]:
            return "medium_freq"
        else:
            return "low_freq"

    def _compute_enhanced_features(self, period_index: pd.PeriodIndex, freq_str: str) -> np.ndarray:
        """Compute enhanced time features based on frequency."""
        if not self.use_enhanced_features:
            return np.array([]).reshape(len(period_index), 0)

        category = self._get_feature_category(freq_str)
        feature_config = ENHANCED_TIME_FEATURES[category]

        features = []

        # Normalized GluonTS features; features that do not apply to this
        # frequency are skipped.
        for feat_func in feature_config["normalized"]:
            try:
                feat_values = feat_func(period_index)
                features.append(feat_values)
            except Exception:
                continue

        # Index-based features, rescaled to [0, 1] by their maximum value.
        if self.use_index_features:
            for feat_func in feature_config["index"]:
                try:
                    feat_values = feat_func(period_index)
                    if feat_values.max() > 0:
                        feat_values = feat_values / feat_values.max()
                    features.append(feat_values)
                except Exception:
                    continue

        if features:
            return np.stack(features, axis=-1)
        else:
            return np.array([]).reshape(len(period_index), 0)

    def _compute_holiday_features(self, date_range: pd.DatetimeIndex) -> np.ndarray:
        """Compute holiday features for the given datetime index."""
        if not self.use_holiday_features or self.holiday_feature_set is None:
            return np.array([]).reshape(len(date_range), 0)

        try:
            # SpecialDateFeatureSet returns (num_holidays, time); transpose to
            # (time, num_holidays) to match the other feature blocks.
            holiday_features = self.holiday_feature_set(date_range)
            return holiday_features.T
        except Exception:
            return np.array([]).reshape(len(date_range), 0)

    def _detect_auto_seasonality(self, time_series_values: np.ndarray) -> list:
        """
        Detect seasonal periods automatically using FFT analysis.

        Parameters
        ----------
        time_series_values : np.ndarray
            Time series values used for seasonality detection.

        Returns
        -------
        list
            List of detected seasonal periods (in time steps).
        """
        if not self.use_auto_seasonality or len(time_series_values) < 10:
            return []

        try:
            # Drop NaNs before the spectral analysis.
            values = time_series_values[~np.isnan(time_series_values)]
            if len(values) < 10:
                return []

            # Remove a linear trend so low-frequency drift does not dominate.
            x = np.arange(len(values))
            coeffs = np.polyfit(x, values, 1)
            trend = np.polyval(coeffs, x)
            detrended = values - trend

            # Apply a Hann window to reduce spectral leakage.
            window = np.hanning(len(detrended))
            windowed = detrended * window

            # Zero-pad to twice the length for finer frequency resolution.
            padded_length = len(windowed) * 2
            padded_values = np.zeros(padded_length)
            padded_values[: len(windowed)] = windowed

            fft_values = fft.rfft(padded_values)
            fft_magnitudes = np.abs(fft_values)
            freqs = np.fft.rfftfreq(padded_length)

            # Ignore the DC component.
            fft_magnitudes[0] = 0.0

            # Keep only peaks above 5% of the strongest magnitude.
            threshold = 0.05 * np.max(fft_magnitudes)
            peak_indices, _ = find_peaks(fft_magnitudes, height=threshold)

            if len(peak_indices) == 0:
                return []

            # Take the strongest peaks first.
            sorted_indices = peak_indices[np.argsort(fft_magnitudes[peak_indices])[::-1]]
            top_indices = sorted_indices[: self.max_seasonal_periods]

            periods = []
            for idx in top_indices:
                if freqs[idx] > 0:
                    # freqs are in cycles per sample, so the period is simply
                    # 1 / f; zero-padding changes the resolution, not the axis.
                    period = round(1.0 / freqs[idx])
                    if 2 <= period <= len(values) // 2:
                        periods.append(period)

            return list(set(periods))

        except Exception:
            return []

    def _compute_seasonality_features(
        self,
        period_index: pd.PeriodIndex,
        freq_str: str,
        time_series_values: np.ndarray | None = None,
    ) -> np.ndarray:
        """Compute seasonality-aware features."""
        if not self.include_seasonality_info:
            return np.array([]).reshape(len(period_index), 0)

        all_seasonal_features = []

        # Sin/cos encoding of the standard seasonality implied by the frequency.
        try:
            seasonality = get_seasonality(freq_str)
            if seasonality > 1:
                positions = np.arange(len(period_index))
                sin_feat = np.sin(2 * np.pi * positions / seasonality)
                cos_feat = np.cos(2 * np.pi * positions / seasonality)
                all_seasonal_features.extend([sin_feat, cos_feat])
        except Exception:
            pass

        # Sin/cos encodings for automatically detected seasonal periods.
        if self.use_auto_seasonality and time_series_values is not None:
            auto_periods = self._detect_auto_seasonality(time_series_values)
            for period in auto_periods:
                try:
                    positions = np.arange(len(period_index))
                    sin_feat = np.sin(2 * np.pi * positions / period)
                    cos_feat = np.cos(2 * np.pi * positions / period)
                    all_seasonal_features.extend([sin_feat, cos_feat])
                except Exception:
                    continue

        if all_seasonal_features:
            return np.stack(all_seasonal_features, axis=-1)
        else:
            return np.array([]).reshape(len(period_index), 0)

    def compute_features(
        self,
        period_index: pd.PeriodIndex,
        date_range: pd.DatetimeIndex,
        freq_str: str,
        time_series_values: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Compute all time features for the given period index.

        Parameters
        ----------
        period_index : pd.PeriodIndex
            Period index for computing features.
        date_range : pd.DatetimeIndex
            Corresponding datetime index for holiday features.
        freq_str : str
            Frequency string.
        time_series_values : np.ndarray, optional
            Time series values for automatic seasonality detection.

        Returns
        -------
        np.ndarray
            Time features array of shape [time_steps, num_features].
        """
        all_features = []

        # Standard GluonTS features derived from the frequency string.
        try:
            standard_features = time_features_from_frequency_str(freq_str)
            if standard_features:
                std_feat = np.stack([feat(period_index) for feat in standard_features], axis=-1)
                all_features.append(std_feat)
        except Exception:
            pass

        # Frequency-specific enhanced features.
        enhanced_feat = self._compute_enhanced_features(period_index, freq_str)
        if enhanced_feat.shape[1] > 0:
            all_features.append(enhanced_feat)

        # Holiday features.
        holiday_feat = self._compute_holiday_features(date_range)
        if holiday_feat.shape[1] > 0:
            all_features.append(holiday_feat)

        # Seasonality features (standard and, optionally, auto-detected).
        seasonality_feat = self._compute_seasonality_features(period_index, freq_str, time_series_values)
        if seasonality_feat.shape[1] > 0:
            all_features.append(seasonality_feat)

        if all_features:
            combined_features = np.concatenate(all_features, axis=-1)
        else:
            # Fall back to a single zero feature so callers always get 2D output.
            combined_features = np.zeros((len(period_index), 1))

        return combined_features

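# Illustrative sketch, not part of the original module: end-to-end usage of
# TimeFeatureGenerator on a short daily range. The feature count depends on
# which feature groups are enabled, so only the time dimension is fixed.
def _example_time_feature_generator() -> None:
    dates = pd.date_range("2024-01-01", periods=30, freq="D")
    generator = TimeFeatureGenerator(holiday_set="us_business")
    features = generator.compute_features(dates.to_period("D"), dates, freq_str="D")
    logger.debug("time feature matrix shape: %s", features.shape)  # (30, num_features)
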
def compute_batch_time_features(
    start: list[np.datetime64],
    history_length: int,
    future_length: int,
    batch_size: int,
    frequency: list[Frequency],
    K_max: int = 6,
    time_feature_config: dict[str, Any] | None = None,
):
    """
    Compute time features from start timestamps and frequencies.

    Parameters
    ----------
    start : array-like, shape (batch_size,)
        Start timestamps for each batch item.
    history_length : int
        Length of the history sequence.
    future_length : int
        Length of the future (target) sequence.
    batch_size : int
        Batch size.
    frequency : array-like, shape (batch_size,)
        Frequency of each time series.
    K_max : int, optional
        Maximum number of time features to pad to (default: 6).
    time_feature_config : dict, optional
        Configuration for enhanced time features.

    Returns
    -------
    tuple
        (history_time_features, future_time_features), where each is a
        torch.Tensor of shape (batch_size, length, K_max).
    """
    feature_config = time_feature_config or {}
    feature_generator = TimeFeatureGenerator(**feature_config)

    history_features_list = []
    future_features_list = []
    total_length = history_length + future_length

    for i in range(batch_size):
        frequency_i = frequency[i]
        freq_str = frequency_i.to_pandas_freq(for_date_range=True)
        period_freq_str = frequency_i.to_pandas_freq(for_date_range=False)

        # Fall back to BASE_START_DATE if the requested window is unsafe for
        # this frequency.
        start_ts = pd.Timestamp(start[i])
        if not validate_frequency_safety(start_ts, total_length, frequency_i):
            logger.debug(
                f"Start date {start_ts} not safe for total_length={total_length}, frequency={frequency_i}. "
                f"Using BASE_START_DATE instead."
            )
            start_ts = BASE_START_DATE

        history_range = pd.date_range(start=start_ts, periods=history_length, freq=freq_str)

        # Shift the window back if it would extend beyond BASE_END_DATE.
        if history_range[-1] > BASE_END_DATE:
            safe_start = BASE_END_DATE - pd.tseries.frequencies.to_offset(freq_str) * total_length
            if safe_start < BASE_START_DATE:
                safe_start = BASE_START_DATE
            history_range = pd.date_range(start=safe_start, periods=history_length, freq=freq_str)

        future_start = history_range[-1] + pd.tseries.frequencies.to_offset(freq_str)
        future_range = pd.date_range(start=future_start, periods=future_length, freq=freq_str)

        history_period_idx = history_range.to_period(period_freq_str)
        future_period_idx = future_range.to_period(period_freq_str)

        history_features = feature_generator.compute_features(history_period_idx, history_range, freq_str)
        future_features = feature_generator.compute_features(future_period_idx, future_range, freq_str)

        # Pad or truncate to a fixed feature dimension.
        history_features = _pad_or_truncate_features(history_features, K_max)
        future_features = _pad_or_truncate_features(future_features, K_max)

        history_features_list.append(history_features)
        future_features_list.append(future_features)

    history_time_features = np.stack(history_features_list, axis=0)
    future_time_features = np.stack(future_features_list, axis=0)

    return (
        torch.from_numpy(history_time_features).float().to(device),
        torch.from_numpy(future_time_features).float().to(device),
    )

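# Illustrative sketch, not part of the original module: the per-series steps
# inside compute_batch_time_features, written against a plain pandas frequency
# string instead of the project's Frequency enum so it stands alone.
def _example_single_series_pipeline(history_length: int = 48, future_length: int = 24) -> None:
    freq_str = "h"
    start_ts = pd.Timestamp("2024-01-01")
    history_range = pd.date_range(start=start_ts, periods=history_length, freq=freq_str)
    future_start = history_range[-1] + pd.tseries.frequencies.to_offset(freq_str)
    future_range = pd.date_range(start=future_start, periods=future_length, freq=freq_str)

    generator = TimeFeatureGenerator()
    history_feat = _pad_or_truncate_features(
        generator.compute_features(history_range.to_period(freq_str), history_range, freq_str), K_max=6
    )
    future_feat = _pad_or_truncate_features(
        generator.compute_features(future_range.to_period(freq_str), future_range, freq_str), K_max=6
    )
    logger.debug("history %s, future %s", history_feat.shape, future_feat.shape)  # (48, 6), (24, 6)
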
def _pad_or_truncate_features(features: np.ndarray, K_max: int) -> np.ndarray:
    """Pad with zeros or truncate the feature dimension to exactly K_max."""
    seq_len, num_features = features.shape

    if num_features < K_max:
        # Pad with zero-valued feature columns.
        padding = np.zeros((seq_len, K_max - num_features))
        features = np.concatenate([features, padding], axis=-1)
    elif num_features > K_max:
        # Keep only the first K_max feature columns.
        features = features[:, :K_max]

    return features

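# Minimal self-check sketch, not part of the original module: exercises the
# FFT-based seasonality detection on a synthetic daily series with a weekly
# cycle, then pads the resulting feature matrix to a fixed width. Calling the
# private _detect_auto_seasonality directly is for illustration only.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo_dates = pd.date_range("2024-01-01", periods=28 * 6, freq="D")
    demo_values = np.sin(2 * np.pi * np.arange(len(demo_dates)) / 7) + 0.1 * rng.standard_normal(len(demo_dates))

    demo_generator = TimeFeatureGenerator(use_auto_seasonality=True, use_holiday_features=False)
    detected = demo_generator._detect_auto_seasonality(demo_values)
    logger.info("Detected seasonal periods (a weekly cycle should show up as a period near 7): %s", detected)

    demo_features = demo_generator.compute_features(
        demo_dates.to_period("D"), demo_dates, "D", time_series_values=demo_values
    )
    padded = _pad_or_truncate_features(demo_features, K_max=8)
    logger.info("Padded feature shape: %s", padded.shape)  # (168, 8)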