| import sys |
| import os |
| file_path = os.getcwd() |
| sys.path.append(file_path) |
| import os |
| import argparse |
| import yaml |
| import gc |
|
|
| import torch |
| import dgl |
| from dgl.data import DGLDataset |
| from dgl.dataloading import GraphDataLoader |
| from torch.utils.data import SubsetRandomSampler, SequentialSampler |
|
|
class CustomPreBatchedDataset(DGLDataset):
    """Batched view over a masked, optionally chunked subset of ``start_dataset``.

    A boolean mask produced by ``mask_fn`` selects elements of
    ``start_dataset``; the surviving indices are split into ``chunks``
    contiguous pieces of (almost) equal size and only piece ``chunkno`` is
    kept.  ``process()`` exposes a ``GraphDataLoader`` over exactly those
    elements as ``self.dataloader``.

    Args:
        start_dataset: Underlying dataset.  It must provide ``name`` and
            ``save_dir`` attributes, ``__len__`` and integer indexing.
        batch_size: Batch size handed to every ``GraphDataLoader``.
        chunkno: Zero-based index of the chunk to keep.
        chunks: Total number of chunks; ``1`` disables chunking.
        mask_fn: Callable taking the dataset and returning a boolean mask of
            length ``len(start_dataset)``.  Defaults to keeping everything.
        drop_last: Forwarded to the ``GraphDataLoader`` built in ``process()``.
        shuffle: When true, the selected elements are visited in random order.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, start_dataset, batch_size, chunkno=0, chunks=1, mask_fn=None, drop_last=False, shuffle=False, **kwargs):
        # All attributes read by process() are assigned before the base-class
        # initializer runs.
        # NOTE(review): DGLDataset.__init__ appears to call process() itself
        # when no cache is found -- confirm against the installed DGL version.
        self.start_dataset = start_dataset
        self.batch_size = batch_size
        self.mask_fn = mask_fn or (lambda x: torch.ones(len(x), dtype=torch.bool))
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.chunkno = chunkno
        self.chunks = chunks
        super().__init__(name=start_dataset.name + '_custom_prebatched', save_dir=start_dataset.save_dir)

    def _masked_indices(self):
        """Return a 1-D tensor with the ``start_dataset`` indices that pass the mask."""
        mask = self.mask_fn(self.start_dataset)
        return torch.arange(len(self.start_dataset))[mask]

    def _chunk_bounds(self, total):
        """Return ``(start, end)`` of this chunk within ``total`` masked indices.

        Both bounds are clamped to ``total`` so that ``end - start`` is never
        negative, even when ``chunkno`` points past the last non-empty chunk.
        """
        if self.chunks == 1:
            return 0, total
        # Ceiling division: the last chunk may be shorter than the others.
        chunk_size = (total + self.chunks - 1) // self.chunks
        start = min(self.chunkno * chunk_size, total)
        end = min((self.chunkno + 1) * chunk_size, total)
        return start, end

    def process(self):
        """Build ``self.dataloader`` over the masked elements of this chunk."""
        indices = self._masked_indices()
        print(f"Number of elements after masking: {len(indices)}")

        total = len(indices)
        start, end = self._chunk_bounds(total)
        if self.chunks == 1:
            chunk_indices = indices
            print(f"Chunks=1, using all {total} indices.")
        else:
            chunk_indices = indices[start:end]
            print(f"Working on chunk {self.chunkno}/{self.chunks}: indices {start}:{end} (total {len(chunk_indices)})")

        # Plain Python ints, so the dataset is never indexed with 0-d tensors.
        selected = chunk_indices.tolist()

        if self.shuffle:
            sampler = SubsetRandomSampler(selected)
        else:
            # SequentialSampler(x) iterates range(len(x)) and would ignore the
            # index values; the index list itself is a valid ordered sampler.
            sampler = selected

        self.dataloader = GraphDataLoader(
            self.start_dataset,
            sampler=sampler,
            batch_size=self.batch_size,
            drop_last=self.drop_last
        )

    def __getitem__(self, idx):
        """Return one collated batch made of the requested ``start_dataset`` items.

        ``idx`` is a single integer or an iterable of integers indexing
        ``start_dataset`` directly.  Only the first batch is returned, so at
        most ``batch_size`` of the requested items are included.
        """
        if isinstance(idx, int):
            idx = [idx]
        # Use the indices themselves as the sampler so the requested items --
        # not items 0..len(idx)-1 -- are fetched.
        wanted = [int(i) for i in idx]
        dloader = GraphDataLoader(self.start_dataset, sampler=wanted, batch_size=self.batch_size, drop_last=False)
        return next(iter(dloader))

    def __len__(self):
        """Return the number of masked elements that fall in this chunk."""
        start, end = self._chunk_bounds(len(self._masked_indices()))
        return end - start
|
|
def include_config(conf):
    """Merge the YAML files listed under ``conf['include']`` into ``conf`` in place.

    Files are applied in the order listed; keys from an included file
    overwrite keys already present in ``conf``.  The ``'include'`` key is
    removed afterwards.  Includes are not followed recursively: an
    ``'include'`` key inside an included file is discarded.

    Args:
        conf: Configuration mapping, modified in place.  ``conf['include']``
            may be a list of paths or a single path string.

    Returns:
        None.
    """
    if 'include' not in conf:
        return

    paths = conf['include']
    if isinstance(paths, str):
        # A scalar "include: file.yaml" would otherwise be iterated
        # character by character.
        paths = [paths]

    for path in paths:
        with open(path) as f:
            # SECURITY: FullLoader is not meant for untrusted input; only
            # include files you control (yaml.safe_load is the hardened option).
            included = yaml.load(f, Loader=yaml.FullLoader)
        # An empty YAML document loads as None, which dict.update() rejects.
        if included is not None:
            conf.update(included)
    del conf['include']
|
|
def load_config(config_file):
    """Load a YAML configuration file and resolve its ``include`` entries.

    Args:
        config_file: Path of the YAML file to read.

    Returns:
        The parsed configuration after ``include_config`` has merged any
        included files into it.  An empty file yields an empty dict.
    """
    with open(config_file) as f:
        # SECURITY: FullLoader is not meant for untrusted input; only load
        # configuration files you control.
        conf = yaml.load(f, Loader=yaml.FullLoader)
    if conf is None:
        # An empty YAML document parses to None; normalise it so the include
        # handling and callers always receive a mapping.
        conf = {}
    include_config(conf)
    return conf
|
|
def _resolve_destination(args, config):
    """Return the output path: default location, chunk suffix and extension applied."""
    if args.destination == '':
        base_dest = os.path.join(config['Training_Directory'], 'inference/', os.path.split(args.target)[1])
    else:
        base_dest = args.destination

    # Drop any existing extension text; the right one is re-added below.
    base_dest = base_dest.replace('.root', '').replace('.npz', '')
    if args.chunks > 1:
        base_dest = f"{base_dest}_chunk{args.chunkno}"
    return base_dest + ('.root' if args.write else '.npz')


def _strip_wrapper_prefixes(state_dict):
    """Return a copy of ``state_dict`` with 'module.' and '_orig_mod.' removed from its keys."""
    return {
        key.replace('module.', '').replace('_orig_mod.', ''): value
        for key, value in state_dict.items()
    }


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)


def main():
    """Run inference with one or more trained models over a single target file.

    Command-line arguments select the model configs (``--config``), the
    output branch name paired with each config (``--branch_name``), the input
    file (``--target``) and the output location (``--destination``).  The
    dataset is built from the first config, each model is restored from its
    best or requested checkpoint, and the scores are written either as new
    branches of a cloned ROOT tree (``--write``) or to an ``.npz`` archive.

    The function prints a message and returns early when the config and
    branch lists differ in length, or when the destination already exists
    and ``--clobber`` was not given.
    """
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--config', type=str, nargs='+', required=True, help="List of config files")
    add_arg('--target', type=str, required=True)
    add_arg('--destination', type=str, default='')
    add_arg('--chunkno', type=int, default=0)
    add_arg('--chunks', type=int, default=1)
    add_arg('--write', action='store_true')
    add_arg('--ckpt', type=int, default=-1)
    add_arg('--var', type=str, default='Test_AUC')
    add_arg('--mode', type=str, default='max')
    add_arg('--clobber', action='store_true')
    add_arg('--tree', type=str, default='')
    add_arg('--branch_name', type=str, nargs='+', required=True, help="List of branch names corresponding to configs")
    args = parser.parse_args()

    if len(args.config) != len(args.branch_name):
        print("configs and branch names do not match")
        return

    # The first config decides the dataset, batch size and default output dir.
    config = load_config(args.config[0])
    args.destination = _resolve_destination(args, config)

    if os.path.exists(args.destination):
        print(f'File {args.destination} already exists.')
        if args.clobber:
            print('Clobbering.')
        else:
            print('Exiting.')
            return
    else:
        print(f'Writing to {args.destination}')

    # Heavy imports are deferred until after the cheap argument and
    # destination checks, and their cost is reported.
    import time
    start = time.time()
    import ROOT
    import torch
    from array import array
    import numpy as np
    from root_gnn_base import batched_dataset as dataset  # noqa: F401 (kept from the original import list)
    from root_gnn_base import utils
    end = time.time()
    print('Imports finished in {:.2f} seconds'.format(end - start))

    start = time.time()
    dset_config = config['Datasets'][list(config['Datasets'].keys())[0]]

    # Replace these dataset classes with the counterparts used for inference.
    class_overrides = {
        'LazyDataset': 'EdgeDataset',
        'LazyMultiLabelDataset': 'MultiLabelDataset',
        'PhotonIDDataset': 'UnlazyPhotonIDDataset',
        'kNNDataset': 'UnlazyKNNDataset',
    }
    dset_config['class'] = class_overrides.get(dset_config['class'], dset_config['class'])

    target_dir, target_name = os.path.split(args.target)
    dset_args = dset_config['args']
    dset_args['raw_dir'] = target_dir
    dset_args['file_names'] = target_name
    dset_args['save'] = False
    # NOTE(review): the chunk settings are given both to the dataset here and
    # to CustomPreBatchedDataset below -- confirm this does not chunk twice.
    dset_args['chunks'] = args.chunks
    dset_args['process_chunks'] = [args.chunkno,]
    dset_args['selections'] = []
    dset_args['save_dir'] = os.path.dirname(args.destination)
    if args.tree != '':
        dset_args['tree_name'] = args.tree

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dstart = time.time()
    dset = utils.buildFromConfig(dset_config)
    dend = time.time()
    print('Dataset finished in {:.2f} seconds'.format(dend - dstart))

    print(dset)

    batch_size = config['Training']['batch_size']
    lstart = time.time()
    loader = CustomPreBatchedDataset(
        dset,
        batch_size,
        chunkno=args.chunkno,
        chunks=args.chunks
    )
    loader.process()
    lend = time.time()
    print('Loader finished in {:.2f} seconds'.format(lend - lstart))

    # One sample batch is passed to the model builder below.
    sample_graph, _, _, global_sample = loader[0]

    print('dset length =', len(dset))
    print('loader length =', len(loader))

    all_scores = {}
    all_labels = {}
    all_tracking = {}
    with torch.no_grad():
        for config_file, branch in zip(args.config, args.branch_name):
            config = load_config(config_file)
            model = utils.buildFromConfig(config['Model'], {'sample_graph' : sample_graph, 'sample_global': global_sample}).to(device)
            if args.ckpt < 0:
                # Honour --mode; it was previously parsed but ignored in
                # favour of a hard-coded 'max'.
                _, checkpoint = utils.get_best_epoch(config, var=args.var, mode=args.mode, device=device)
            else:
                _, checkpoint = utils.get_specific_epoch(config, args.ckpt, device=device)

            model.load_state_dict(_strip_wrapper_prefixes(checkpoint['model_state_dict']))
            model.eval()

            end = time.time()
            print('Model and dataset finished in {:.2f} seconds'.format(end - start))
            print('Starting inference')
            start = time.time()

            finish_fn = torch.nn.Sigmoid()
            if 'Loss' in config:
                finish_fn = utils.buildFromConfig(config['Loss']['finish'])
            # Predictions are stored untouched for ContrastiveClusterFinish;
            # every other finish function is applied to them first.
            apply_finish = finish_fn.__class__.__name__ != "ContrastiveClusterFinish"

            scores = []
            labels = []
            tracking_info = []

            for batch, label, track, global_feats in loader.dataloader:
                pred = model(batch.to(device), global_feats.to(device))
                if apply_finish:
                    pred = finish_fn(pred)
                scores.append(pred.detach().cpu().numpy())
                labels.append(label.detach().cpu().numpy())
                tracking_info.append(track.detach().cpu().numpy())

            scores = np.concatenate(scores)
            labels = np.concatenate(labels)
            tracking_info = np.concatenate(tracking_info)
            end = time.time()

            print('Inference finished in {:.2f} seconds'.format(end - start))
            all_scores[branch] = scores
            all_labels[branch] = labels
            all_tracking[branch] = tracking_info

    if args.write:
        from ROOT import std

        infile = ROOT.TFile.Open(args.target)
        tree = infile.Get(dset_config['args']['tree_name'])

        _ensure_parent_dir(args.destination)

        # The output file is opened before CloneTree(0), as in the original
        # statement order.
        outfile = ROOT.TFile.Open(args.destination, 'RECREATE')
        outtree = tree.CloneTree(0)

        # One buffer per new branch: a std::vector<float> for multi-valued
        # scores, a single-element float array for scalar scores.
        branch_vectors = {}
        for branch, scores in all_scores.items():
            if isinstance(scores[0], (list, tuple, np.ndarray)) and len(scores[0]) > 1:
                branch_vectors[branch] = std.vector('float')()
                outtree.Branch(branch, branch_vectors[branch])
            else:
                branch_vectors[branch] = array('f', [0])
                outtree.Branch(branch, branch_vectors[branch], f'{branch}/F')

        # NOTE(review): assumes scores[i] corresponds to tree entry i, i.e.
        # the loader covered every entry in tree order -- confirm for chunked
        # or filtered runs.
        n_entries = tree.GetEntries()
        for i in range(n_entries):
            tree.GetEntry(i)

            for branch, scores in all_scores.items():
                branch_data = branch_vectors[branch]
                if isinstance(branch_data, array):
                    branch_data[0] = float(scores[i])
                else:
                    branch_data.clear()
                    for value in scores[i]:
                        branch_data.push_back(float(value))

            outtree.Fill()

        print(f'Writing to file {args.destination}')
        print(f'Input entries: {tree.GetEntries()}, Output entries: {outtree.GetEntries()}')
        print(f'Wrote scores to {args.branch_name}')
        outtree.Write()
        outfile.Close()
        infile.Close()
    else:
        _ensure_parent_dir(args.destination)
        # The values are dicts keyed by branch name, so numpy stores them as
        # object arrays.
        np.savez(args.destination, scores=all_scores, labels=all_labels, tracking_info=all_tracking)
|
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()