# BTC-Autoformer / app.py
# Author: thesven
# Last commit 040e502: "update to interactive plot"
# File size: 10.5 kB
# Standard library imports
from functools import lru_cache, partial
from typing import Optional, Iterable
# Third-party library imports
from transformers import PretrainedConfig, AutoformerForPrediction
import gradio as gr
import spaces
import torch
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# External imports
# GluonTS imports
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
    AddAgeFeature,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AsNumpyArray,
    Chain,
    ExpectedNumInstanceSampler,
    InstanceSplitter,
    RemoveFields,
    TestSplitSampler,
    Transformation,
    ValidationSplitSampler,
    VstackFeatures,
    RenameFields,
)
from gluonts.time_feature import time_features_from_frequency_str
from gluonts.transform.sampler import InstanceSampler
# Hugging Face Datasets imports
from datasets import Dataset, Features, Value, Sequence, load_dataset
# GluonTS Loader imports
from gluonts.dataset.loader import as_stacked_batches
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
def convert_to_pandas_period(date, freq):
    """Return *date* wrapped as a ``pandas.Period`` with frequency *freq*."""
    period = pd.Period(value=date, freq=freq)
    return period
def transform_start_field(batch, freq):
    """Turn every value in ``batch["start"]`` into a ``pandas.Period`` at *freq*.

    The batch dict is modified in place and also returned, so the function can
    be handed to ``Dataset.set_transform`` via ``functools.partial``.
    """
    periods = [pd.Period(start_value, freq) for start_value in batch["start"]]
    batch["start"] = periods
    return batch
def create_transformation(freq: str, config: PretrainedConfig, prediction_length: int) -> Transformation:
    """Assemble the GluonTS preprocessing chain that prepares raw entries for the model.

    Args:
        freq: Pandas frequency string used to derive calendar time features.
        config: Model config; its ``num_*_features`` and ``input_size`` fields
            decide which optional fields are kept and the target's rank.
        prediction_length: Forecast horizon; time/age features are extended by
            this many steps.

    Returns:
        A ``Chain`` of transformations (conceptually like torchvision's ``Compose``).
    """
    # Optional covariates and how many of each the model was configured with.
    optional_fields = (
        (FieldName.FEAT_STATIC_REAL, config.num_static_real_features),
        (FieldName.FEAT_DYNAMIC_REAL, config.num_dynamic_real_features),
        (FieldName.FEAT_STATIC_CAT, config.num_static_categorical_features),
    )
    unused_fields = [field for field, count in optional_fields if count == 0]

    # step 1: strip covariates the model does not consume
    steps = [RemoveFields(field_names=unused_fields)]

    # step 2: coerce the static covariates that remain into NumPy arrays
    if config.num_static_categorical_features > 0:
        steps.append(
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_CAT,
                expected_ndim=1,
                dtype=int,
            )
        )
    if config.num_static_real_features > 0:
        steps.append(
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_REAL,
                expected_ndim=1,
            )
        )

    # Temporal features that step 6 stacks together.
    stacked_time_inputs = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]
    if config.num_dynamic_real_features > 0:
        stacked_time_inputs.append(FieldName.FEAT_DYNAMIC_REAL)

    steps.extend(
        [
            # multivariate targets carry an extra dimension
            AsNumpyArray(
                field=FieldName.TARGET,
                expected_ndim=1 if config.input_size == 1 else 2,
            ),
            # step 3: fill NaNs in the target and record an observed-values mask
            # (True where observed); the model uses this mask so no loss is
            # incurred on unobserved values
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
            # step 4: calendar features derived from freq, extended by the horizon
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features_from_frequency_str(freq),
                pred_length=prediction_length,
            ),
            # step 5: a single running "age" counter along the series
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=prediction_length,
                log_scale=True,
            ),
            # step 6: stack all temporal features under FEAT_TIME
            VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=stacked_time_inputs,
            ),
            # step 7: rename fields to the Hugging Face model's argument names
            RenameFields(
                mapping={
                    FieldName.FEAT_STATIC_CAT: "static_categorical_features",
                    FieldName.FEAT_STATIC_REAL: "static_real_features",
                    FieldName.FEAT_TIME: "time_features",
                    FieldName.TARGET: "values",
                    FieldName.OBSERVED_VALUES: "observed_mask",
                }
            ),
        ]
    )
    return Chain(steps)
def create_instance_splitter(
    config: PretrainedConfig,
    mode: str,
    prediction_length: int,
    train_sampler: Optional[InstanceSampler] = None,
    validation_sampler: Optional[InstanceSampler] = None,
) -> Transformation:
    """Build an ``InstanceSplitter`` that cuts series into past/future windows.

    Args:
        config: Model config; ``context_length`` plus the largest entry of
            ``lags_sequence`` sets the length of the past window.
        mode: ``"train"``, ``"validation"`` or ``"test"``; selects the sampler.
        prediction_length: Length of the future window.
        train_sampler: Optional override for the sampler used in ``"train"`` mode.
        validation_sampler: Optional override for the sampler used in
            ``"validation"`` mode.

    Returns:
        The configured ``InstanceSplitter`` transformation.
    """
    assert mode in ["train", "validation", "test"]

    if mode == "train":
        sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
    elif mode == "validation":
        sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )
    else:
        sampler = TestSplitSampler()

    # The past window must cover the context plus the deepest lag feature.
    history_length = config.context_length + max(config.lags_sequence)

    return InstanceSplitter(
        target_field="values",
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=sampler,
        past_length=history_length,
        future_length=prediction_length,
        time_series_fields=["time_features", "observed_mask"],
    )
def create_test_dataloader(
    config: PretrainedConfig,
    freq: str,
    data: Dataset,
    batch_size: int,
    prediction_length: int,
    **kwargs,
):
    """Create batched test instances ready to feed to ``model.generate``.

    Args:
        config: Model config; decides which static feature inputs are included.
        freq: Pandas frequency string of the series.
        data: Dataset to transform and split.
        batch_size: Number of instances per batch.
        prediction_length: Forecast horizon.
        **kwargs: Accepted but unused.

    Returns:
        The iterable of stacked ``torch.tensor`` batches produced by
        ``as_stacked_batches``, keyed by the model's prediction input names.
    """
    input_names = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    optional_inputs = (
        ("static_categorical_features", config.num_static_categorical_features),
        ("static_real_features", config.num_static_real_features),
    )
    input_names.extend(name for name, count in optional_inputs if count > 0)

    preprocess = create_transformation(freq, config, prediction_length)
    prepared = preprocess.apply(data, is_train=False)

    # Test-mode splitter: samples only the final context window for the encoder.
    splitter = create_instance_splitter(
        config, "test", prediction_length=prediction_length
    )
    test_instances = splitter.apply(prepared, is_train=False)

    return as_stacked_batches(
        test_instances,
        batch_size=batch_size,
        output_type=torch.tensor,
        field_names=input_names,
    )
def plot(ts_index, test_dataset, forecasts, prediction_length, freq="1D"):
    """Build an interactive Plotly figure of one series and its forecast.

    Args:
        ts_index: Index of the series in both ``test_dataset`` and ``forecasts``.
        test_dataset: Indexable dataset whose items carry ``"start"`` and
            ``"target"`` entries.
        forecasts: Per-series forecast samples. ``forecasts[ts_index]`` is
            assumed to be ``(num_samples, horizon)`` as returned by
            ``model.generate(...).sequences`` (TODO confirm for multivariate
            models); a 1-D array is treated as a single sample.
        prediction_length: Number of forecast steps to draw after the last
            observed value.
        freq: Pandas frequency string of the series (daily by default).

    Returns:
        A Plotly figure with an "Actual" trace and a "Prediction" trace that
        shows the median across forecast samples.
    """
    # Index once: the dataset may apply an on-the-fly transform on every access.
    entry = test_dataset[ts_index]
    target = entry["target"]
    target_length = len(target)

    # Timestamps for the observed history followed by the forecast horizon.
    index = pd.period_range(
        start=entry["start"],
        periods=target_length + prediction_length,
        freq=freq,
    ).to_timestamp()

    # The model emits sampled paths; draw the median as the point forecast
    # rather than an arbitrary single sample.
    samples = np.atleast_2d(np.asarray(forecasts[ts_index]))
    point_forecast = np.median(samples, axis=0)[:prediction_length]

    actual_data = go.Scatter(
        x=index[:target_length],
        y=target,
        name="Actual",
        mode='lines',
    )
    forecast_data = go.Scatter(
        x=index[target_length:target_length + len(point_forecast)],
        y=point_forecast,
        name="Prediction",
        mode='lines',
    )

    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(actual_data, row=1, col=1)
    fig.add_trace(forecast_data, row=1, col=1)
    fig.update_layout(
        xaxis_title="Date",
        yaxis_title="Value",
        title="Actual vs. Predicted Values",
        xaxis_rangeslider_visible=True,
    )
    return fig
@lru_cache(maxsize=1)
def _load_forecasts():
    """Download the data and model, run Autoformer once and cache the result.

    The forecasts do not depend on the slider value, so they are computed on
    the first call and reused until the process restarts.

    Returns:
        ``(test_dataset, forecasts, prediction_length)``, where ``forecasts`` is
        the vertically stacked ``outputs.sequences`` of every batch and
        ``prediction_length`` is the model's trained horizon.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    freq = "1D"  # Daily frequency

    dataset = load_dataset("thesven/BTC-Daily-Avg-Market-Value")
    test_dataset = dataset['test']
    test_dataset.set_transform(partial(transform_start_field, freq=freq))

    model = AutoformerForPrediction.from_pretrained("thesven/BTC-Autoformer-v1")
    config = model.config
    # The decoder's horizon is fixed at training time, and the future time
    # features must match it, so take the length from the config.
    prediction_length = config.prediction_length

    test_dataloader = create_test_dataloader(
        config=config,
        freq=freq,
        data=test_dataset,
        batch_size=64,
        prediction_length=prediction_length,
    )

    model.to(device)
    model.eval()

    forecasts = []
    with torch.no_grad():
        for batch in test_dataloader:
            outputs = model.generate(
                static_categorical_features=batch["static_categorical_features"].to(device)
                if config.num_static_categorical_features > 0
                else None,
                static_real_features=batch["static_real_features"].to(device)
                if config.num_static_real_features > 0
                else None,
                past_time_features=batch["past_time_features"].to(device),
                past_values=batch["past_values"].to(device),
                future_time_features=batch["future_time_features"].to(device),
                past_observed_mask=batch["past_observed_mask"].to(device),
            )
            forecasts.append(outputs.sequences.cpu().numpy())

    return test_dataset, np.vstack(forecasts), prediction_length


def do_prediction(days_to_predict: int):
    """Gradio callback: plot the first test series with the requested horizon.

    Args:
        days_to_predict: Slider value. It is clamped to the range
            ``[1, model horizon]`` because the model cannot forecast beyond
            the horizon it was trained for.

    Returns:
        The Plotly figure produced by ``plot``.
    """
    test_dataset, forecasts, model_horizon = _load_forecasts()
    # Gradio sliders may deliver floats, so convert before slicing.
    horizon = max(1, min(int(days_to_predict), model_horizon))
    return plot(0, test_dataset, forecasts, horizon)
# Gradio UI: a single slider drives do_prediction, which returns a Plotly figure.
interface = gr.Interface(
    fn=do_prediction,
    inputs=gr.Slider(minimum=1, maximum=30, step=1, label="Days to Predict"),
    outputs="plot",
    title="Prediction Plot",
    description="Adjust the slider to set the number of days to predict.",
    # Flagging takes a mode string; the boolean form is deprecated.
    allow_flagging="never",  # Disable flagging for simplicity
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    interface.launch()