""" Deep learning framework for sound speed inversion """
# NOTE: the module docstring above is reused verbatim as the argparse
# description (``description=__doc__``), so it is runtime text.

import argparse
import glob
import json
import os
import pathlib

import git
import h5py
import pytorch_lightning as pl
import torch
import torch.utils.data as td

import loader
import net
import run_logger

# ----------------------------
# Setup command line arguments
# ----------------------------
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--test_files', nargs='?',
                    help='Test data (file pattern) to process / data to evaluate ')
parser.add_argument('--train_files', nargs='?',
                    help='Train data (file pattern) to process, only evaluate test if empty')
parser.add_argument('--test_fname', default='output.h5',
                    help='Filename into which to write testing output -- will be overwritten')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--experiment', default='DeepLearning US', help='experiment name')
parser.add_argument('--tags', nargs='?',
                    help='Optional run tags, should evaluate to dictionary via json.loads')
parser.add_argument('--load_ver', type=str, help='Network weights to load')
parser.add_argument('--conf', type=str, action='append',
                    help='Config file(s) to import (overridden by command line arguments)')
parser.add_argument('--conf_export', type=str, help='Filename where to store settings')

# Let the trainer, data loader, network and image callback register their
# own options on the same parser.
parser = pl.Trainer.add_argparse_args(parser)
parser = loader.Loader.add_argparse_args(parser)
parser = net.Net.add_model_specific_args(parser)
parser = run_logger.ImgCB.add_argparse_args(parser)
args = parser.parse_args()

if args.conf is not None:
    # Each config file becomes a new set of parser defaults; files listed
    # later override values from earlier ones.
    for conf_fname in args.conf:
        with open(conf_fname, 'r') as f:
            parser.set_defaults(**json.load(f))
    # Reload arguments to override config file values with command line values
    args = parser.parse_args()

if args.conf_export is not None:
    with open(args.conf_export, 'w') as f:
        json.dump(vars(args), f, indent=4, sort_keys=True)

if args.test_files is None and args.train_files is None:
    raise ValueError('At least one of train files or test files is required')

# ----------------------------
# Load data
# ----------------------------
ld = loader.Loader(**vars(args))
test_input, test_label, train_input, train_label = ld.load_data(
    test_file_pattern=args.test_files,
    train_file_pattern=args.train_files)

# Report what was loaded; any of the four tensors may be None.
for name, tensor in (
        ('test_input', test_input),
        ('test_label', test_label),
        ('train_input', train_input),
        ('train_label', train_label)):
    print(f'{name}: {tensor.shape if tensor is not None else None} -- {tensor.dtype if tensor is not None else None}')

# The loaders are passed positionally to the trainer below: the training
# loader (if any) first, then the test loader (if any).
loaders = []
if args.train_files is not None:
    # Training needs labels for the training set and, when test data is
    # present as well, for the test set too.
    if train_input is None or train_label is None or (test_input is not None and test_label is None):
        raise ValueError('Training requires labeled data')
    train_ds = td.TensorDataset(train_input, train_label)
    loaders.append(td.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, pin_memory=True))

if args.test_files is not None:
    # Labels are optional for test data.
    ds = [test_input]
    if test_label is not None:
        ds.append(test_label)
    test_ds = td.TensorDataset(*ds)
    loaders.append(td.DataLoader(test_ds, args.batch_size, shuffle=False, pin_memory=True))

# ----------------------------
# Run
# ----------------------------
if args.train_files is not None:
    # Normalise the run tags to a dictionary. They may arrive as a JSON
    # string from the command line or already parsed from a config file.
    if args.tags is None:
        args.tags = {}
    elif isinstance(args.tags, str):
        args.tags = json.loads(args.tags)
    if not isinstance(args.tags, dict):
        raise ValueError('--tags should evaluate to dictionary via json.loads')

    # Tag the run with the current commit when possible. This is best
    # effort: any failure while querying the repository is reported and
    # skipped, but KeyboardInterrupt / SystemExit are no longer swallowed.
    try:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
    except Exception:
        print('Not a git repo, not logging commit ID')
    else:
        args.tags.update({'commit': sha})

    mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
    mfl.log_hyperparams(args)

    # Store the Python sources that sit next to this script with the run.
    path = pathlib.Path(__file__).parent.absolute()
    files = glob.glob(str(path / '*.py'))
    for f in files:
        mfl.experiment.log_artifact(mfl.run_id, f, 'source')

    chkpnt_cb = pl.callbacks.ModelCheckpoint(
        monitor='validate_mean',
        verbose=True,
        save_top_k=1,
        save_weights_only=True,
        mode='min',
        every_n_train_steps=1,
        filename='{epoch}-{validate_mean}-{train_mean}',
    )
    img_cb = run_logger.ImgCB(**vars(args))
    lr_logger = pl.callbacks.LearningRateMonitor()
    # The trainer is built from ``args`` below, so logger and callbacks
    # are attached to the namespace.
    args.logger = mfl
    args.callbacks = [chkpnt_cb, img_cb, lr_logger]
else:
    # Evaluation only: start from a clean output file.
    try:
        os.remove(args.test_fname)
    except FileNotFoundError:
        pass
    args.callbacks = [run_logger.TestLogger(args.test_fname)]

# Take the network input/output sizes from dimension 1 of the loaded
# tensors, preferring test data over training data. When no matching
# tensor was loaded the value already present in ``args`` is left alone.
if test_label is not None:
    args.n_outputs = test_label.shape[1]
elif train_label is not None:
    args.n_outputs = train_label.shape[1]

if test_input is not None:
    args.n_inputs = test_input.shape[1]
elif train_input is not None:
    args.n_inputs = train_input.shape[1]

n = net.Net(**vars(args))
if args.load_ver is not None:
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    t = torch.load(args.load_ver, map_location='cpu')['state_dict']
    n.load_state_dict(t)

trainer = pl.Trainer.from_argparse_args(args)

if args.train_files is not None:
    trainer.fit(n, *loaders)
    print(chkpnt_cb.best_model_path)
elif args.label_vars:
    # ``label_vars`` is not defined in this file; presumably it is
    # registered by one of the add_argparse_args calls above -- TODO confirm.
    trainer.test(n, *loaders)
else:
    predictions = trainer.predict(n, *loaders)
    with h5py.File(args.test_fname, "w") as F:
        F["predictions"] = torch.cat(predictions).numpy()