tempoPFN / src /synthetic_generation /kernel_synth /kernel_generator_wrapper.py
Vladyslav Moroshan
Apply ruff formatting
96e1a32
from typing import Any
import numpy as np
from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import KernelGeneratorParams
from src.synthetic_generation.kernel_synth.kernel_synth import KernelSynthGenerator
class KernelGeneratorWrapper(GeneratorWrapper):
    """
    Wrapper around KernelSynthGenerator that produces batches of synthetic time series.

    Each of the ``batch_size`` series in a batch is produced by one
    KernelSynthGenerator instance, and the results are stacked with ``np.array``
    into a TimeSeriesContainer. Configuration comes from a KernelGeneratorParams
    dataclass.
    """

    def __init__(self, params: KernelGeneratorParams):
        """
        Parameters
        ----------
        params : KernelGeneratorParams
            Generator configuration; ``length`` and ``max_kernels`` are read from
            it when sampling batch parameters.
        """
        super().__init__(params)
        # Keep a reference annotated with the concrete params type so that
        # accesses such as self.params.length / self.params.max_kernels are typed.
        self.params: KernelGeneratorParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Sample parameter values for batch generation with KernelSynthGenerator.

        Starts from the base-class sample and adds the fixed ``length`` and
        ``max_kernels`` values taken from ``self.params``.

        Parameters
        ----------
        batch_size : int
            Number of series in the batch; forwarded to the base-class sampler.

        Returns
        -------
        dict[str, Any]
            Sampled parameters. ``generate_batch`` also reads ``"start"`` and
            ``"frequency"`` from this dict, so the base-class sampler is expected
            to provide them.
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {
                "length": self.params.length,
                "max_kernels": self.params.max_kernels,
            }
        )
        return params

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """
        Generate a batch of synthetic time series using KernelSynthGenerator.

        Parameters
        ----------
        batch_size : int
            Number of time series to generate.
        seed : int, optional
            Random seed for this batch (default: None). When given, it is passed
            to ``_set_random_seeds`` and to the KernelSynthGenerator constructor,
            and series ``i`` is generated with ``random_seed=seed + i``. When
            None, every series is generated with ``random_seed=None``.
        params : dict[str, Any], optional
            Pre-sampled parameters to use. Must contain ``"length"``,
            ``"max_kernels"``, ``"start"`` and ``"frequency"``. If None,
            parameters are sampled via ``_sample_parameters``.

        Returns
        -------
        TimeSeriesContainer
            Container whose ``values`` are the generated series stacked with
            ``np.array``, together with ``start`` and ``frequency`` from
            ``params``.
        """
        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = KernelSynthGenerator(
            length=params["length"],
            max_kernels=params["max_kernels"],
            random_seed=seed,
        )

        # Each series gets its own deterministic seed derived from the batch seed.
        # NOTE(review): seeds overlap across consecutive batch seeds (batch seed s,
        # series 1 and batch seed s+1, series 0 both use s+1) — confirm this
        # correlation between batches is acceptable.
        batch_values = [
            generator.generate_time_series(
                random_seed=None if seed is None else seed + i
            )
            for i in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )