"""
Constructs batches where each batch includes all GSPs and only a single timestamp.
Currently a slightly hacky implementation due to the way the configs are done. This script will use
the same config file currently set to train the model. In the datamodule config it is possible
to set the batch_output_dir and number of train/val batches, they can also be overriden in the
command as shown in the example below.
use:
```
python save_concurrent_samples.py \
+datamodule.sample_output_dir="/mnt/disks/concurrent_batches/concurrent_samples_sat_pred_test" \
+datamodule.num_train_samples=20 \
+datamodule.num_val_samples=20
```
"""

# Ensure this block of code runs only in the main process to avoid issues with worker processes.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be
    # compatible with dask's multiprocessing.
    mp.set_start_method("forkserver")

    # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is
    # important because libraries like Zarr may open many files, which can exhaust the file
    # descriptor limit if too many workers are used.
    mp.set_sharing_strategy("file_system")

import logging
import os
import shutil
import sys
import warnings

import hydra
import numpy as np
import torch
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from pvnet.utils import print_config

# ------- filter warning and set up config -------
warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
# -------------------------------------------------

class SaveFuncFactory:
    """Factory for creating a function to save a sample to disk."""

    def __init__(self, save_dir: str):
        """Initialise the factory with the directory samples will be written to."""
        self.save_dir = save_dir

    def __call__(self, sample, sample_num: int):
        """Save a sample to disk."""
        torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt")
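
# Note: the saved files are ordinary torch pickles, so a sample can be inspected later with
# torch.load (illustrative path only; the train/val layout is created in main() below). On
# recent PyTorch versions you may need weights_only=False to load non-tensor objects:
#
#     sample = torch.load("<sample_output_dir>/val/00000000.pt", weights_only=False)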

def save_samples_with_dataloader(
    dataset: Dataset,
    save_dir: str,
    num_samples: int,
    dataloader_kwargs: dict,
) -> None:
    """Save samples from a dataset using a dataloader."""
    save_func = SaveFuncFactory(save_dir)

    gsp_ids = np.array([loc.id for loc in dataset.locations])

    dataloader = DataLoader(dataset, **dataloader_kwargs)

    pbar = tqdm(total=num_samples)
    for i, sample in zip(range(num_samples), dataloader):
        check_sample(sample, gsp_ids)
        save_func(sample, i)
        pbar.update()
    pbar.close()

def check_sample(sample, gsp_ids):
    """Check that a sample is a valid concurrent batch covering all GSPs."""
    # Check all GSP IDs are included and in the correct order
    assert (sample["gsp_id"].flatten().numpy() == gsp_ids).all()
    # Check all GSPs share the same timestamp
    assert len(np.unique(sample["gsp_time_utc"][:, 0].numpy())) == 1

@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
def main(config: DictConfig) -> None:
    """Constructs and saves validation and training samples."""
    config_dm = config.datamodule

    print_config(config, resolve=False)

    # Set up the output directory
    os.makedirs(config_dm.sample_output_dir, exist_ok=False)

    # Copy across the configs which define the samples into the new sample directory
    with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(config_dm))

    shutil.copyfile(
        config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml"
    )

    # Define the keyword arguments going into the train and val dataloaders
    dataloader_kwargs = dict(
        shuffle=True,
        batch_size=None,
        sampler=None,
        batch_sampler=None,
        num_workers=config_dm.num_workers,
        collate_fn=None,
        pin_memory=False,  # Only using CPU to prepare samples so pinning is not beneficial
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        prefetch_factor=config_dm.prefetch_factor,
        persistent_workers=False,  # Not needed since we only enter the dataloader loop once
    )
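
    # Note: batch_size=None disables the DataLoader's automatic batching, so each item yielded by
    # PVNetUKConcurrentDataset passes through unchanged. Each item already contains every GSP for
    # a single timestamp, which is why no collate_fn or sampler is needed here.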

    if config_dm.num_val_samples > 0:
        print("----- Saving val samples -----")

        val_output_dir = f"{config_dm.sample_output_dir}/val"

        # Make directory for val samples
        os.mkdir(val_output_dir)

        # Get the dataset
        val_dataset = PVNetUKConcurrentDataset(
            config_dm.configuration,
            start_time=config_dm.val_period[0],
            end_time=config_dm.val_period[1],
        )

        # Save samples
        save_samples_with_dataloader(
            dataset=val_dataset,
            save_dir=val_output_dir,
            num_samples=config_dm.num_val_samples,
            dataloader_kwargs=dataloader_kwargs,
        )

        del val_dataset

    if config_dm.num_train_samples > 0:
        print("----- Saving train samples -----")

        train_output_dir = f"{config_dm.sample_output_dir}/train"

        # Make directory for train samples
        os.mkdir(train_output_dir)

        # Get the dataset
        train_dataset = PVNetUKConcurrentDataset(
            config_dm.configuration,
            start_time=config_dm.train_period[0],
            end_time=config_dm.train_period[1],
        )

        # Save samples
        save_samples_with_dataloader(
            dataset=train_dataset,
            save_dir=train_output_dir,
            num_samples=config_dm.num_train_samples,
            dataloader_kwargs=dataloader_kwargs,
        )

        del train_dataset

    print("----- Saving complete -----")

if __name__ == "__main__":
    main()