""" |
|
Constructs batches where each batch includes all GSPs and only a single timestamp. |
|
|
|
Currently a slightly hacky implementation due to the way the configs are done. This script will use |
|
the same config file currently set to train the model. In the datamodule config it is possible |
|
to set the batch_output_dir and number of train/val batches, they can also be overriden in the |
|
command as shown in the example below. |
|
|
|
use: |
|
``` |
|
python save_concurrent_samples.py \ |
|
+datamodule.sample_output_dir="/mnt/disks/concurrent_batches/concurrent_samples_sat_pred_test" \ |
|
+datamodule.num_train_samples=20 \ |
|
+datamodule.num_val_samples=20 |
|
``` |
|
|
|
""" |
|
|

if __name__ == "__main__":
    import torch.multiprocessing as mp

    # Use the "forkserver" start method for dataloader workers: unlike the
    # default "fork", it does not clone the full state of the parent process,
    # which is safer once threads have been started.
    mp.set_start_method("forkserver")

    # Share tensors between worker processes via the file system rather than
    # file descriptors, avoiding the open-file-descriptor limit when many
    # workers are used.
    mp.set_sharing_strategy("file_system")
|
|

import logging
import os
import shutil
import sys
import warnings

import hydra
import numpy as np
import torch
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from pvnet.utils import print_config

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

logger = logging.getLogger(__name__)

logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
|
|


class SaveFuncFactory:
    """Factory for creating a function to save a sample to disk."""

    def __init__(self, save_dir: str):
        """Initialise the factory with the directory samples will be saved to."""
        self.save_dir = save_dir

    def __call__(self, sample, sample_num: int):
        """Save a sample to disk."""
        torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt")
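
# A saved sample can be loaded back with `torch.load`. A minimal sketch, with an
# illustrative path (on PyTorch >= 2.6 you may need `weights_only=False`, since
# samples are arbitrary pickled objects rather than bare tensors):
#
#     sample = torch.load("samples/train/00000000.pt", weights_only=False)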
|
|


def save_samples_with_dataloader(
    dataset: Dataset,
    save_dir: str,
    num_samples: int,
    dataloader_kwargs: dict,
) -> None:
    """Save samples from a dataset using a dataloader."""
    save_func = SaveFuncFactory(save_dir)

    gsp_ids = np.array([loc.id for loc in dataset.locations])

    dataloader = DataLoader(dataset, **dataloader_kwargs)

    # zip() stops after num_samples batches, so only that many samples are drawn
    pbar = tqdm(total=num_samples)
    for i, sample in zip(range(num_samples), dataloader):
        check_sample(sample, gsp_ids)
        save_func(sample, i)
        pbar.update()
    pbar.close()
|
|


def check_sample(sample, gsp_ids):
    """Check that the sample is a valid concurrent batch covering all GSPs."""
    # The batch should contain every GSP exactly once, in the dataset's order
    assert (sample["gsp_id"].flatten().numpy() == gsp_ids).all()
    # All GSP examples in the batch should share the same first timestamp
    assert len(np.unique(sample["gsp_time_utc"][:, 0].numpy())) == 1
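
# Shape assumptions behind the checks above (not enforced elsewhere in this
# script): `sample["gsp_id"]` flattens to the N location ids in dataset order,
# and `sample["gsp_time_utc"]` is an (N, T) array of timestamps, so a valid
# concurrent sample has one row per GSP, all sharing the same first timestamp.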
|
|


@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
def main(config: DictConfig) -> None:
    """Constructs and saves validation and training samples."""
    config_dm = config.datamodule

    print_config(config, resolve=False)

    # Fail if the output directory already exists rather than overwrite samples
    # from a previous run
    os.makedirs(config_dm.sample_output_dir, exist_ok=False)

    # Save a copy of the datamodule config alongside the samples
    with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(config_dm))

    # Save a copy of the data configuration alongside the samples
    shutil.copyfile(
        config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml"
    )

    dataloader_kwargs = dict(
        shuffle=True,
        # batch_size=None disables the dataloader's automatic batching: each dataset
        # item is already a whole concurrent batch (all GSPs for one timestamp)
        batch_size=None,
        sampler=None,
        batch_sampler=None,
        num_workers=config_dm.num_workers,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        prefetch_factor=config_dm.prefetch_factor,
        persistent_workers=False,
    )
|
|

    if config_dm.num_val_samples > 0:
        print("----- Saving val samples -----")

        val_output_dir = f"{config_dm.sample_output_dir}/val"
        os.mkdir(val_output_dir)

        val_dataset = PVNetUKConcurrentDataset(
            config_dm.configuration,
            start_time=config_dm.val_period[0],
            end_time=config_dm.val_period[1],
        )

        save_samples_with_dataloader(
            dataset=val_dataset,
            save_dir=val_output_dir,
            num_samples=config_dm.num_val_samples,
            dataloader_kwargs=dataloader_kwargs,
        )

        del val_dataset
|
|

    if config_dm.num_train_samples > 0:
        print("----- Saving train samples -----")

        train_output_dir = f"{config_dm.sample_output_dir}/train"
        os.mkdir(train_output_dir)

        train_dataset = PVNetUKConcurrentDataset(
            config_dm.configuration,
            start_time=config_dm.train_period[0],
            end_time=config_dm.train_period[1],
        )

        save_samples_with_dataloader(
            dataset=train_dataset,
            save_dir=train_output_dir,
            num_samples=config_dm.num_train_samples,
            dataloader_kwargs=dataloader_kwargs,
        )

        del train_dataset

    print("----- Saving complete -----")
|
|


if __name__ == "__main__":
    main()
|