"""
A script to run a backtest of PVNet for specific sites.

Use:

- This script uses hydra to construct the config, just like in `run.py`, so you need to make
  sure that the data config is set up appropriately for the model being run in this script.
- The PVNet model checkpoint, the time range over which to make predictions, the site IDs to
  produce forecasts for, and the output directory where the results are saved are all defined
  near the top of the script as hard-coded user variables. These should be changed before
  running.

```
python scripts/backtest_sites.py
```
"""
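# Note: since this is a hydra entry point, config values can also be overridden on the command
# line in the usual hydra `key=value` style. The line below is an illustrative sketch only --
# the exact keys depend on your own config, but `datamodule.num_workers` and
# `datamodule.prefetch_factor` are read in `main()` below:
#
#   python scripts/backtest_sites.py datamodule.num_workers=2 datamodule.prefetch_factor=2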

try:
    import torch.multiprocessing as mp

    mp.set_start_method("spawn", force=True)
    mp.set_sharing_strategy("file_system")
except RuntimeError:
    pass

import json
import logging
import os
import sys

import hydra
import numpy as np
import pandas as pd
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from ocf_data_sampler.sample.base import batch_to_tensor, copy_batch_to_device
from ocf_datapipes.batch import (
    BatchKey,
    NumpyBatch,
    stack_np_examples_into_batch,
)
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.load.pv.pv import OpenPVFromNetCDFIterDataPipe
from ocf_datapipes.training.common import create_t0_and_loc_datapipes
from ocf_datapipes.training.pvnet_site import (
    DictDatasetIterDataPipe,
    _get_datapipes_dict,
    construct_sliced_data_pipeline,
    split_dataset_dict_dp,
)
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig
from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import SiteLocationLookup

# Directory path to save the forecast results to
output_dir = "PLACEHOLDER"

# Local directory of the PVNet checkpoint to load. Either this or `hf_model_id` must be set.
model_chckpoint_dir = "PLACEHOLDER"

# Alternatively, set these to load the model from HuggingFace
hf_revision = None
hf_token = None
hf_model_id = None

# Forecasts will be made for all available init-times between these
start_datetime = "2022-05-08 00:00"
end_datetime = "2022-05-08 00:30"

# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Set the device to run the model on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Time resolution of the predictions, in minutes
FREQ_MINS = 30

# Forecast values are set to zero when the solar elevation is below this threshold
MIN_DAY_ELEVATION = 0

# The site IDs to produce forecasts for
ALL_SITE_IDS = []
ALL_SITE_IDS.sort()
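# Illustrative example of how the user variables above might be filled in. These values are
# hypothetical and only show the expected formats:
#
#   output_dir = "/path/to/backtest_results"
#   model_chckpoint_dir = "/path/to/pvnet_checkpoint"
#   start_datetime = "2022-05-08 00:00"
#   end_datetime = "2022-05-09 00:00"
#   ALL_SITE_IDS = [1, 2, 3]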


@functional_datapipe("pad_forward_pv")
class PadForwardPVIterDataPipe(IterDataPipe):
    """
    Pads the forecast part of the PV data.

    The sun position is calculated from the PV time index, and for t0 times close to the end of
    the PV data it can have the wrong shape, since the PV data runs out of values to slice for
    the forecast part.
    """

    def __init__(
        self,
        pv_dp: IterDataPipe,
        forecast_duration: np.timedelta64,
        history_duration: np.timedelta64,
        time_resolution_minutes: np.timedelta64,
    ):
        """Init"""
        super().__init__()
        self.pv_dp = pv_dp
        self.forecast_duration = forecast_duration
        self.history_duration = history_duration
        self.time_resolution_minutes = time_resolution_minutes

        self.min_seq_length = history_duration // time_resolution_minutes

    def __iter__(self):
        """Iter"""
        for xr_data in self.pv_dp:
            # The full time index each sample should span
            t_end = (
                xr_data.time_utc.data[0]
                + self.history_duration
                + self.forecast_duration
                + self.time_resolution_minutes
            )
            time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes)

            if len(xr_data.time_utc.data) < self.min_seq_length:
                raise ValueError("Not enough PV data to predict")

            yield xr_data.reindex(time_utc=time_idx, fill_value=-1)


def load_model_from_hf(model_id: str, revision: str, token: str):
    """Loads a model from HuggingFace"""
    model_file = hf_hub_download(
        repo_id=model_id,
        filename=PYTORCH_WEIGHTS_NAME,
        revision=revision,
        token=token,
    )

    config_file = hf_hub_download(
        repo_id=model_id,
        filename=CONFIG_NAME,
        revision=revision,
        token=token,
    )

    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = hydra.utils.instantiate(config)

    # Load the weights onto whichever device is available
    state_dict = torch.load(model_file, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    return model


def preds_to_dataarray(preds, model, valid_times, site_ids):
    """Put a numpy array of predictions into a DataArray"""
    if model.use_quantile_regression:
        output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
        output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
    else:
        output_labels = ["forecast_mw"]
        preds = preds[..., np.newaxis]

    da = xr.DataArray(
        data=preds,
        dims=["pv_system_id", "target_datetime_utc", "output_label"],
        coords=dict(
            pv_system_id=site_ids,
            target_datetime_utc=valid_times,
            output_label=output_labels,
        ),
    )
    return da

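# Illustrative example (not executed): for a model trained with output_quantiles
# [0.1, 0.5, 0.9], the "output_label" coordinate produced above would be
# ["forecast_mw_plevel_10", "forecast_mw", "forecast_mw_plevel_90"], since the 50th
# percentile is relabelled "forecast_mw".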

def get_sites_ds(config_path: str) -> xr.Dataset:
    """Load site data from the path in the data config.

    Args:
        config_path: Path to the data configuration file

    Returns:
        xarray.Dataset of PVLive truths and capacities
    """
    config = load_yaml_configuration(config_path)
    site_datapipe = OpenPVFromNetCDFIterDataPipe(pv=config.input_data.pv)
    ds_sites = next(iter(site_datapipe))

    return ds_sites


def get_available_t0_times(start_datetime, end_datetime, config_path):
    """Filter a list of t0 init-times to those for which all required input data is available.

    Args:
        start_datetime: First potential t0 time
        end_datetime: Last potential t0 time
        config_path: Path to the data config file

    Returns:
        pandas.DatetimeIndex of the init-times for which all required inputs are available
    """
    start_datetime = pd.Timestamp(start_datetime)
    end_datetime = pd.Timestamp(end_datetime)

    # Open all the input data sources
    datapipes_dict = _get_datapipes_dict(config_path, production=False)

    # Pop out the config
    config = datapipes_dict.pop("config")

    # The init-times we would like to make forecasts for
    potential_init_times = pd.date_range(start_datetime, end_datetime, freq=f"{FREQ_MINS}min")

    # Create a fake PV dataset with a single dummy site spanning the whole requested period
    # (plus history and forecast buffers), so that PV availability does not constrain the
    # t0 filtering below
    history_duration = pd.Timedelta(config.input_data.pv.history_minutes, "min")
    forecast_duration = pd.Timedelta(config.input_data.pv.forecast_minutes, "min")
    buffered_potential_init_times = pd.date_range(
        start_datetime - history_duration, end_datetime + forecast_duration, freq=f"{FREQ_MINS}min"
    )
    ds_fake_site = (
        buffered_potential_init_times.to_frame().to_xarray().rename({"index": "time_utc"})
    )
    ds_fake_site = ds_fake_site.rename({0: "site_pv_power_mw"})
    ds_fake_site = ds_fake_site.expand_dims("pv_system_id", axis=1)
    ds_fake_site = ds_fake_site.assign_coords(
        pv_system_id=[0],
        latitude=("pv_system_id", [0]),
        longitude=("pv_system_id", [0]),
    )
    ds_fake_site = ds_fake_site.site_pv_power_mw.astype(float) * 1e-18

    datapipes_dict["pv"] = IterableWrapper([ds_fake_site])

    # Use this function to find the init-times for which all the other inputs are available
    location_pipe, t0_datapipe = create_t0_and_loc_datapipes(
        datapipes_dict,
        configuration=config,
        key_for_t0="pv",
        shuffle=False,
    )

    # Iterate through the returned datapipe to collect the available init-times
    available_init_times = [t0 for _, t0 in zip(location_pipe, t0_datapipe)]
    available_init_times = pd.to_datetime(available_init_times)

    logger.info(
        f"{len(available_init_times)} out of {len(potential_init_times)} "
        "requested init-times have required input data"
    )

    return available_init_times


def get_loctimes_datapipes(config_path):
    """Create location and init-time datapipes.

    Args:
        config_path: Path to the data config file

    Returns:
        tuple: A tuple of datapipes
            - Datapipe yielding locations
            - Datapipe yielding init-times
    """
    # Set up the site ID -> location lookup
    ds_sites = get_sites_ds(config_path)
    site_id_to_loc = SiteLocationLookup(ds_sites.longitude, ds_sites.latitude)

    # Filter the init-times to those for which we have all the input data
    available_target_times = get_available_t0_times(
        start_datetime,
        end_datetime,
        config_path,
    )
    num_t0s = len(available_target_times)

    # Save the init-times which predictions will be made for
    available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv")

    # For each init-time the location datapipe yields every site location in turn
    location_pipe = IterableWrapper(
        [[site_id_to_loc(site_id) for site_id in ALL_SITE_IDS]]
    ).repeat(num_t0s)
    location_pipe = location_pipe.sharding_filter()
    location_pipe = location_pipe.unbatch(unbatch_level=1)

    # The init-time datapipe yields each init-time once per site, so the two pipes stay in step
    t0_datapipe = IterableWrapper(
        [[t0 for site_id in ALL_SITE_IDS] for t0 in available_target_times]
    )
    t0_datapipe = t0_datapipe.sharding_filter()
    t0_datapipe = t0_datapipe.unbatch(unbatch_level=1)

    t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_SITE_IDS))
    location_pipe = location_pipe.set_length(num_t0s * len(ALL_SITE_IDS))

    return location_pipe, t0_datapipe


class ModelPipe:
    """A class to conveniently make and process predictions from batches"""

    def __init__(self, model, ds_site: xr.Dataset):
        """A class to conveniently make and process predictions from batches

        Args:
            model: PVNet site-level model
            ds_site: xarray Dataset of PV site true values and capacities
        """
        self.model = model
        self.ds_site = ds_site

    def predict_batch(self, batch: NumpyBatch) -> xr.DataArray:
        """Run the batch through the model and compile the predictions into an xarray DataArray

        Args:
            batch: A batch of samples with inputs for each site for the same init-time

        Returns:
            xarray.DataArray of forecasts for all sites for the batch
        """
        # Unpack the init-time and the number of forecast steps from the batch
        id0 = batch[BatchKey.pv_t0_idx]
        t0 = batch[BatchKey.pv_time_utc].cpu().numpy().astype("datetime64[s]")[0, id0]
        n_valid_times = len(batch[BatchKey.pv_time_utc][0, id0 + 1 :])
        model = self.model

        # The valid times for this forecast
        valid_times = pd.to_datetime(
            [t0 + np.timedelta64((i + 1) * FREQ_MINS, "m") for i in range(n_valid_times)]
        )

        # The site capacities, used to scale the normalised predictions
        site_capacities = self.ds_site.nominal_capacity_wp.values

        # Un-normalise the solar elevation and keep only the forecast part
        elevation = batch[BatchKey.pv_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN
        elevation = elevation[:, id0 + 1 :]

        # Mask of the timestamps where the sun is below MIN_DAY_ELEVATION
        da_sundown_mask = xr.DataArray(
            data=elevation < MIN_DAY_ELEVATION,
            dims=["pv_system_id", "target_datetime_utc"],
            coords=dict(
                pv_system_id=ALL_SITE_IDS,
                target_datetime_utc=valid_times,
            ),
        )

        with torch.no_grad():
            # Run the batch through the model to get normalised predictions for all sites
            device_batch = copy_batch_to_device(batch_to_tensor(batch), device)
            y_normed_site = model(device_batch).detach().cpu().numpy()
            da_normed_site = preds_to_dataarray(y_normed_site, model, valid_times, ALL_SITE_IDS)

        # Multiply the normalised forecasts by the site capacities and clip negative values
        da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None]

        # Set the forecast to zero where the sun is down
        da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)

        da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
            init_time_utc=np.array([t0], dtype="datetime64[ns]")
        )

        return da_abs_site


def get_datapipe(config_path: str) -> IterDataPipe:
    """Construct a datapipe yielding batches of concurrent samples for all sites

    Args:
        config_path: Path to the data configuration file

    Returns:
        Datapipe yielding one NumpyBatch of concurrent samples for all sites per init-time
    """
    # Construct the location and init-time datapipes
    location_pipe, t0_datapipe = get_loctimes_datapipes(config_path)

    # Number of batches - i.e. the number of init-times
    num_batches = len(t0_datapipe) // len(ALL_SITE_IDS)

    # Construct the sliced data pipeline
    data_pipeline = construct_sliced_data_pipeline(
        config_path,
        location_pipe,
        t0_datapipe,
    )

    # Pad the forecast part of the PV data so every sample has the same length
    config = load_yaml_configuration(config_path)
    data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
        forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"),
        history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"),
        time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"),
    )

    data_pipeline = DictDatasetIterDataPipe(
        {k: v for k, v in data_pipeline.items() if k != "config"},
    ).map(split_dataset_dict_dp)

    data_pipeline = data_pipeline.pvnet_site_convert_to_numpy_batch()

    # Batch so that each batch contains all sites for a single init-time
    data_pipeline = (
        data_pipeline.batch(len(ALL_SITE_IDS))
        .map(stack_np_examples_into_batch)
        .map(batch_to_tensor)
    )
    data_pipeline = data_pipeline.set_length(num_batches)

    return data_pipeline


@hydra.main(config_path="../configs", config_name="config.yaml", version_base="1.2")
def main(config: DictConfig):
    """Runs the backtest"""

    dataloader_kwargs = dict(
        shuffle=False,
        batch_size=None,  # batched in the datapipe step
        sampler=None,
        batch_sampler=None,
        num_workers=config.datamodule.num_workers,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        prefetch_factor=config.datamodule.prefetch_factor,
        persistent_workers=False,
    )

    # Create the output directory (raises if it already exists)
    os.makedirs(output_dir)

    # Create the datapipe and dataloader
    batch_pipe = get_datapipe(config.datamodule.configuration)
    num_batches = len(batch_pipe)

    # Load the site data, which holds the site capacities
    ds_site = get_sites_ds(config.datamodule.configuration)

    dataloader = DataLoader(batch_pipe, **dataloader_kwargs)

    # Load the model from either a local checkpoint or HuggingFace
    if model_chckpoint_dir:
        model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
    elif hf_model_id:
        model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
    else:
        raise ValueError("Provide a model checkpoint or a HuggingFace model")

    model = model.eval().to(device)

    # Create an object to make and compile predictions for each batch
    model_pipe = ModelPipe(model, ds_site)

    # Loop through the batches of init-times
    pbar = tqdm(total=num_batches)
    for i, batch in zip(range(num_batches), dataloader):
        try:
            # Make predictions for all sites for this init-time
            ds_abs_all = model_pipe.predict_batch(batch)

            t0 = ds_abs_all.init_time_utc.values[0]

            # Save the predictions for this init-time
            filename = f"{output_dir}/{t0}.nc"
            ds_abs_all.to_netcdf(filename)

            pbar.update()
        except Exception as e:
            print(f"Exception {e} at batch {i}")

    # Close down
    pbar.close()
    del dataloader


if __name__ == "__main__":
    main()