"""Deep learning framework for sound speed inversion."""

import argparse
import glob
import json
import os
import pathlib

import git
import h5py
import torch
import torch.utils.data as td
import pytorch_lightning as pl

import loader
import net
import run_logger

parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--test_files', nargs='?', help='Test data (file pattern) to process / data to evaluate')
parser.add_argument('--train_files', nargs='?', help='Train data (file pattern) to process; only the test data is evaluated if empty')
parser.add_argument('--test_fname', default='output.h5', help='Filename into which to write testing output -- will be overwritten')

parser.add_argument('--batch_size', type=int, default=32, help='Batch size')

parser.add_argument('--experiment', default='DeepLearning US', help='Experiment name')
parser.add_argument('--tags', nargs='?', help='Optional run tags; should evaluate to a dictionary via json.loads')

parser.add_argument('--load_ver', type=str, help='Network weights (checkpoint) to load')

parser.add_argument('--conf', type=str, action='append', help='Config file(s) to import (overridden by command line arguments)')
parser.add_argument('--conf_export', type=str, help='Filename where to store the resolved settings')
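
# Each component (Trainer, data loader, network, image callback) registers its own command-line arguments.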
parser = pl.Trainer.add_argparse_args(parser)
parser = loader.Loader.add_argparse_args(parser)
parser = net.Net.add_model_specific_args(parser)
parser = run_logger.ImgCB.add_argparse_args(parser)

args = parser.parse_args()
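
# Two-pass parsing: the first pass above only locates --conf; values from the
# config file(s) become new parser defaults, and the second pass below re-parses
# so that explicit command-line arguments still override the config files.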
if args.conf is not None:
    for conf_fname in args.conf:
        with open(conf_fname, 'r') as f:
            parser.set_defaults(**json.load(f))

args = parser.parse_args()

if args.conf_export is not None:
    with open(args.conf_export, 'w') as f:
        json.dump(vars(args), f, indent=4, sort_keys=True)

if args.test_files is None and args.train_files is None:
    raise ValueError('At least one of train files or test files is required')
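
# Example invocation (script and file names are hypothetical):
#   python main.py --train_files 'data/train_*.h5' --test_files 'data/val_*.h5' \
#       --batch_size 64 --conf conf/defaults.json
# where conf/defaults.json could contain, e.g., {"batch_size": 64, "experiment": "DeepLearning US"}.
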
ld = loader.Loader(**vars(args))
test_input, test_label, train_input, train_label = ld.load_data(test_file_pattern=args.test_files, train_file_pattern=args.train_files)

for name, tensor in (
        ('test_input', test_input),
        ('test_label', test_label),
        ('train_input', train_input),
        ('train_label', train_label)):
    print(f'{name}: {tensor.shape if tensor is not None else None} -- {tensor.dtype if tensor is not None else None}')
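
# Build dataloaders: training batches are shuffled, while test/validation data
# keeps its original order so predictions stay aligned with the inputs.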
loaders = []

if args.train_files is not None:
    if train_input is None or train_label is None or (test_input is not None and test_label is None):
        raise ValueError('Training requires labeled training data (and labeled validation data when test files are given)')

    train_ds = td.TensorDataset(train_input, train_label)
    loaders.append(td.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, pin_memory=True))

if args.test_files is not None:
    ds = [test_input]
    if test_label is not None:
        ds.append(test_label)

    test_ds = td.TensorDataset(*ds)
    loaders.append(td.DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, pin_memory=True))
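
# Loader order matters: during training the test loader (if any) is passed to
# trainer.fit() as the validation loader.
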
if args.train_files is not None:
    if args.tags is None:
        args.tags = {}
    elif isinstance(args.tags, str):
        args.tags = json.loads(args.tags)

    try:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        args.tags.update({'commit': sha})
    except git.InvalidGitRepositoryError:
        print('Not a git repo, not logging commit ID')
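
    # Log parameters to MLflow and archive the source files as run artifacts
    # for reproducibility.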
    mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
    mfl.log_hyperparams(args)

    path = pathlib.Path(__file__).parent.absolute()
    files = glob.glob(os.path.join(str(path), '*.py'))
    for f in files:
        mfl.experiment.log_artifact(mfl.run_id, f, 'source')
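
    # Keep only the best checkpoint according to the validation metric; 'validate_mean'
    # and 'train_mean' are assumed to be metrics logged by net.Net. With
    # every_n_train_steps=1 the checkpoint condition is evaluated at every training step.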
    chkpnt_cb = pl.callbacks.ModelCheckpoint(
        monitor='validate_mean',
        verbose=True,
        save_top_k=1,
        save_weights_only=True,
        mode='min',
        every_n_train_steps=1,
        filename='{epoch}-{validate_mean}-{train_mean}',
    )

    img_cb = run_logger.ImgCB(**vars(args))
    lr_logger = pl.callbacks.LearningRateMonitor()

    args.__dict__.update({'logger': mfl, 'callbacks': [chkpnt_cb, img_cb, lr_logger]})
else:
    # Evaluation/prediction only: remove any stale output file before writing.
    if os.path.exists(args.test_fname):
        os.remove(args.test_fname)

    args.__dict__.update({'callbacks': [run_logger.TestLogger(args.test_fname)]})
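
# Infer the network's input and output widths from whichever split is available;
# dimension 1 is taken to be the channel/feature axis.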
if test_label is not None:
    args.n_outputs = test_label.shape[1]
elif train_label is not None:
    args.n_outputs = train_label.shape[1]

if test_input is not None:
    args.n_inputs = test_input.shape[1]
elif train_input is not None:
    args.n_inputs = train_input.shape[1]
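
# Build the network; --load_ver points at a Lightning checkpoint, from which
# only the state_dict is restored.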
n = net.Net(**vars(args))
if args.load_ver is not None:
    t = torch.load(args.load_ver, map_location='cpu')['state_dict']
    n.load_state_dict(t)

trainer = pl.Trainer.from_argparse_args(args)
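
# Train (the test loader, if present, serves as validation), evaluate against
# labels, or run blind prediction and dump the outputs to HDF5.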
if args.train_files is not None:
    trainer.fit(n, *loaders)

    print(chkpnt_cb.best_model_path)
elif args.label_vars:
    trainer.test(n, *loaders)
else:
    predictions = trainer.predict(n, *loaders)
    with h5py.File(args.test_fname, "w") as F:
        F["predictions"] = torch.cat(predictions).numpy()