import argparse
import os
import os.path as osp
import pickle
import sys
from pathlib import Path
from time import time

import numpy as np

from src.dataset.dataset import SimpleIterDataset
from src.dataset.functions_data import concat_events
from src.utils.paths import get_path
from src.utils.utils import to_filelist
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Input dataset location (resolved via get_path)")
parser.add_argument("--output", type=str, help="Output location for the preprocessed files")
parser.add_argument("--overwrite", action="store_true", help="Re-process directories that already exist in the output")
parser.add_argument("--dataset-cap", type=int, default=-1, help="Maximum number of events to read per file (-1 = no cap)")
parser.add_argument("--v2", action="store_true")  # v2 datasets additionally store parton-level particles and genParticles
parser.add_argument("--delphes", action="store_true")
args = parser.parse_args()
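# Example invocation (script and dataset names are hypothetical):
#   python preprocess_jets.py --input raw_jets --output jets_preprocessed --v2 --dataset-cap 1000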
path = get_path(args.input, "data")


def remove_from_list(lst):
    # Drop known non-dataset entries from a directory listing.
    out = []
    for item in lst:
        if item in ["hgcal", "data.txt", "test_file.root"]:
            continue
        out.append(item)
    return out
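# Usage sketch (hypothetical directory names):
#   remove_from_list(["jets_v1", "data.txt", "hgcal"])  # -> ["jets_v1"]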
def preprocess_dataset(datasets, output_path, config_file, dataset_cap):
    """Read the given input files through SimpleIterDataset and dump the
    concatenated events as numpy arrays plus a pickled metadata file."""
    class Args:
        # Minimal stand-in for the CLI namespace that to_filelist and
        # SimpleIterDataset expect.
        def __init__(self):
            self.data_train = datasets
            self.data_val = datasets
            self.data_config = config_file
            self.extra_selection = None
            self.train_val_split = 1.0
            self.data_fraction = 1
            self.file_fraction = 1
            self.fetch_by_files = False
            self.fetch_step = 1
            self.steps_per_epoch = None
            self.in_memory = False
            self.local_rank = None
            self.copy_inputs = False
            self.no_remake_weights = False
            self.batch_size = 10
            self.num_workers = 0
            self.demo = False
            self.laplace = False
            self.diffs = False
            self.class_edges = False

    args = Args()  # NB: shadows the module-level `args` inside this function
    train_range = (0, args.train_val_split)
    train_file_dict, train_files = to_filelist(args, 'train')
    train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
                                   extra_selection=args.extra_selection,
                                   remake_weights=True,
                                   load_range_and_fraction=(train_range, args.data_fraction),
                                   file_fraction=args.file_fraction,
                                   fetch_by_files=args.fetch_by_files,
                                   fetch_step=args.fetch_step,
                                   infinity_mode=False,
                                   in_memory=args.in_memory,
                                   async_load=False,
                                   name='train', jets=True)
    iterator = iter(train_data)
    t0 = time()
    data = []
    count = 0
    for event in iterator:
        data.append(event)
        count += 1
        if dataset_cap > 0 and count >= dataset_cap:
            break
    t1 = time()
    print("Took", t1 - t0, "s -", datasets[0])
    events = concat_events(data)  # TODO: this can be done in a nicer way, using less memory (see sketch below)
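    # Hypothetical lower-memory variant (assumes concat_events can also merge
    # its own outputs, which is not verified here):
    #   CHUNK = 1000
    #   merged = [concat_events(data[i:i + CHUNK]) for i in range(0, len(data), CHUNK)]
    #   events = concat_events(merged)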
    result = events.serialize()
    dir_name = datasets[0].split("/")[-2]  # name of the directory containing the input file
    save_to_dir = os.path.join(output_path, dir_name)
    Path(save_to_dir).mkdir(parents=True, exist_ok=True)
    for key in result[0]:
        # NB: the payload is .npy format (mmap-friendly), but the files keep
        # the .pkl extension that earlier pickle/torch-based versions used.
        with open(osp.join(save_to_dir, key + ".pkl"), "wb") as f:
            np.save(f, result[0][key].numpy())
    with open(osp.join(save_to_dir, "metadata.pkl"), "wb") as f:
        pickle.dump(result[1], f)
    print("Saved to", save_to_dir)
    print("Finished", dir_name)
    # Disabled round-trip timing check, kept for reference:
    '''
    from src.dataset.functions_data import EventCollection, EventJets, Event
    from src.dataset.dataset import EventDataset
    t2 = time()
    data1 = []
    for event in EventDataset(result[0], result[1]):
        data1.append(event)
    t3 = time()
    print("Took", t3 - t2, "s")
    print("Done")
    '''
output = get_path(args.output, "preprocessed_data")
for dataset_dir in remove_from_list(os.listdir(path)):  # skip known non-dataset entries
    if args.overwrite or not os.path.exists(os.path.join(output, dataset_dir)):
        config = get_path('config_files/config_jets.yaml', 'code')
        if args.v2:
            delphes_suffix = "_delphes" if args.delphes else ""
            config = get_path(f'config_files/config_jets_2{delphes_suffix}.yaml', 'code')
        for i, file in enumerate(sorted(os.listdir(os.path.join(path, dataset_dir)))):
            print("Preprocessing file", file)
            preprocess_dataset([os.path.join(path, dataset_dir, file)],
                               output + "_part" + str(i),
                               config_file=config,
                               dataset_cap=args.dataset_cap)
    else:
        print("Skipping", dataset_dir + ", already exists")
    sys.stdout.flush()  # flush so progress shows up promptly in batch logs