| import sys |
| import os |
# Make modules under the current working directory importable. Presumably the
# script is launched from the repository root so that the project-local
# ``root_gnn_base`` package imported below resolves — TODO confirm.
file_path = os.getcwd()
sys.path.extend([file_path])
|
|
| import root_gnn_base.utils as utils |
| import argparse |
| from root_gnn_base.batched_dataset import PreBatchedDataset |
| from root_gnn_base.batched_dataset import LazyPreBatchedDataset |
|
|
def _parse_args():
    """Parse the command-line options of the pre-batching script."""
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--config', type=str, required=True)
    add_arg('--dataset', type=str, required=True)
    add_arg('--chunk', type=int, default=0)
    # NOTE(review): within this script the flag only decides whether the
    # dataset is built for every chunk (flag given) or restricted to
    # ``process_chunks=[--chunk]`` (flag absent); whether buildFromConfig
    # shuffles anything is not visible here — confirm.
    add_arg('--shuffle_mode', action='store_true', help='Shuffle the dataset before training.')
    # store_false: ``args.drop_last`` is True by default and becomes False
    # when ``--drop_last`` is passed on the command line.
    add_arg('--drop_last', action='store_false', help='Set drop_last to False if the flag is provided. Defaults to True.')
    return parser.parse_args()


def _prebatch_folds(dset, dset_config, config, batch_size, chunk, drop_last):
    """Wrap the train and test folds of ``dset`` in pre-batched datasets.

    ``LazyPreBatchedDataset`` is used when ``dset_config["class"]`` is
    ``"LazyMultiLabelDataset"``; otherwise ``PreBatchedDataset`` is used.
    The constructed objects are discarded, so presumably their constructors
    persist the batched output as a side effect — TODO confirm.

    Args:
        dset: Dataset built from ``dset_config``.
        dset_config: The ``Datasets`` entry for this dataset; must contain
            ``"folding"`` and ``"class"``, and may set ``shuffle_chunks``
            (default 10) and ``padding_mode`` (default ``'STEPS'``).
        config: Full configuration; ``config['Model']['args']['hid_size']``
            is forwarded as ``hidden_size``.
        batch_size: Batch size passed to the pre-batching wrapper.
        chunk: Chunk index forwarded as ``chunkno``.
        drop_last: Forwarded unchanged to the wrapper.
    """
    shuffle_chunks = dset_config.get('shuffle_chunks', 10)
    padding_mode = dset_config.get('padding_mode', 'STEPS')
    fold_conf = dset_config["folding"]
    print(f"shuffle_chunks = {shuffle_chunks}, args.chunk = {chunk}, padding_mode = {padding_mode}")
    is_lazy = dset_config["class"] == "LazyMultiLabelDataset"
    hidden_size = config['Model']['args']['hid_size']

    # Train first, then test — the same order as the original calls.
    for split in ("train", "test"):
        mask_fn = utils.fold_selection(fold_conf, split)
        shared_kwargs = dict(
            suffix=utils.fold_selection_name(fold_conf, split),
            chunks=shuffle_chunks,
            chunkno=chunk,
            padding_mode=padding_mode,
            drop_last=drop_last,
            hidden_size=hidden_size,
        )
        if is_lazy:
            LazyPreBatchedDataset(start_dataset=dset, batch_size=batch_size, mask_fn=mask_fn, **shared_kwargs)
        else:
            # PreBatchedDataset takes the dataset, batch size and mask
            # function positionally.
            PreBatchedDataset(dset, batch_size, mask_fn, **shared_kwargs)


def main():
    """Pre-batch the train/test folds of one dataset named in the config.

    Loads ``--config``, selects the ``Datasets`` entry named by
    ``--dataset``, builds it (restricted to ``--chunk`` unless
    ``--shuffle_mode`` is given) and hands it to the pre-batching wrappers.
    """
    args = _parse_args()

    config = utils.load_config(args.config)
    dset_config = config['Datasets'][args.dataset]
    batch_size = config['Training']['batch_size']
    if args.shuffle_mode:
        dset = utils.buildFromConfig(dset_config)
    else:
        # Only process the chunk selected on the command line.
        dset = utils.buildFromConfig(dset_config, {'process_chunks': [args.chunk,]})
    # A dataset-level batch size overrides the training-level one.
    if 'batch_size' in dset_config:
        batch_size = dset_config['batch_size']

    _prebatch_folds(dset, dset_config, config, batch_size, args.chunk, args.drop_last)
|
|
# Run the pre-batching only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|