# dl_us_sos_inversion / __main__.py
# Domain: medical imaging — ultrasound
# (upload metadata: laughingrice, "Upload 11 files", commit 6ce7d82)
"""
Deep learning framework for sound speed inversion
"""
import json
import git
import argparse
import pathlib
import glob
import os
import h5py
import loader
import run_logger
import net
import torch
import torch.utils.data as td
import pytorch_lightning as pl
# ----------------------------
# Setup command line arguments
# ----------------------------
# Build the CLI in layers: script-level options first, then let each
# collaborating module (Lightning trainer, data loader, network, image
# callback) register its own arguments on the same parser.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--test_files', nargs='?',
                    help='Test data (file pattern) to process / data to evaluate ')
parser.add_argument('--train_files', nargs='?',
                    help='Train data (file pattern) to process, only evaluate test if empty')
parser.add_argument('--test_fname', default='output.h5',
                    help='Filename into which to write testing output -- will be overwritten')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--experiment', default='DeepLearning US', help='experiment name')
parser.add_argument('--tags', nargs='?',
                    help='Optional run tags, should evaluate to dictionary via json.loads')
parser.add_argument('--load_ver', type=str, help='Network weights to load')
parser.add_argument('--conf', type=str, action='append',
                    help='Config file(s) to import (overridden by command line arguments)')
parser.add_argument('--conf_export', type=str, help='Filename where to store settings')
parser = pl.Trainer.add_argparse_args(parser)
parser = loader.Loader.add_argparse_args(parser)
parser = net.Net.add_model_specific_args(parser)
parser = run_logger.ImgCB.add_argparse_args(parser)

args = parser.parse_args()

if args.conf is not None:
    # Fold every config file into the parser defaults, then parse again so
    # explicit command-line flags win over config-file values.
    for cfg_fname in args.conf:
        with open(cfg_fname, 'r') as cfg_file:
            parser.set_defaults(**json.load(cfg_file))
    # Reload arguments to override config file values with command line values
    args = parser.parse_args()

if args.conf_export is not None:
    # Persist the fully-resolved settings for reproducibility.
    with open(args.conf_export, 'w') as out_file:
        json.dump(vars(args), out_file, indent=4, sort_keys=True)

if args.test_files is None and args.train_files is None:
    raise ValueError('At least one of train files or test files is required')
# ----------------------------
# Load data
# ----------------------------
ld = loader.Loader(**vars(args))
test_input, test_label, train_input, train_label = ld.load_data(
    test_file_pattern=args.test_files, train_file_pattern=args.train_files)

# Report what was loaded; any of the four tensors may be None.
for name, tensor in zip(
        ('test_input', 'test_label', 'train_input', 'train_label'),
        (test_input, test_label, train_input, train_label)):
    if tensor is None:
        print(f'{name}: None -- None')
    else:
        print(f'{name}: {tensor.shape} -- {tensor.dtype}')

loaders = []
if args.train_files is not None:
    # Training demands labels; when a test set is also supplied it serves as
    # the validation set, so it must be labeled too.
    if train_input is None or train_label is None or (test_input is not None and test_label is None):
        raise ValueError('Training requires labeled data')
    loaders.append(td.DataLoader(
        td.TensorDataset(train_input, train_label),
        batch_size=args.batch_size, shuffle=True, pin_memory=True))
if args.test_files is not None:
    # Test labels are optional -- pure prediction runs have none.
    tensors = (test_input,) if test_label is None else (test_input, test_label)
    loaders.append(td.DataLoader(
        td.TensorDataset(*tensors),
        args.batch_size, shuffle=False, pin_memory=True))
# ----------------------------
# Run
# ----------------------------
if args.train_files is not None:
    # --- Training: configure MLflow logging, checkpointing, callbacks ---
    if args.tags is None:
        args.tags = {}
    elif isinstance(args.tags, str):
        # Tags arrive as a JSON string on the command line;
        # isinstance() replaces the non-idiomatic `type(...) == str`.
        args.tags = json.loads(args.tags)

    try:
        # Tag the run with the current git commit for reproducibility.
        repo = git.Repo(search_parent_directories=True)
        args.tags.update({'commit': repo.head.object.hexsha})
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; anything else (no repo, detached head)
        # is still treated as best-effort.
        print('Not a git repo, not logging commit ID')

    mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
    mfl.log_hyperparams(args)

    # Snapshot every source file next to this script as run artifacts.
    src_dir = pathlib.Path(__file__).parent.absolute()
    for src_file in src_dir.glob('*.py'):
        mfl.experiment.log_artifact(mfl.run_id, str(src_file), 'source')

    chkpnt_cb = pl.callbacks.ModelCheckpoint(
        monitor='validate_mean',
        verbose=True,
        save_top_k=1,
        save_weights_only=True,
        mode='min',
        every_n_train_steps=1,
        filename='{epoch}-{validate_mean}-{train_mean}',
    )
    img_cb = run_logger.ImgCB(**vars(args))
    lr_logger = pl.callbacks.LearningRateMonitor()
    # Inject logger/callbacks into args so Trainer.from_argparse_args sees them.
    args.__dict__.update({'logger': mfl, 'callbacks': [chkpnt_cb, img_cb, lr_logger]})
else:
    # --- Test/predict only: write results to a fresh HDF5 file ---
    if os.path.exists(args.test_fname):
        os.remove(args.test_fname)
    args.__dict__.update({'callbacks': [run_logger.TestLogger(args.test_fname)]})

# Infer network input/output widths from whichever tensors are available.
# NOTE(review): uses shape[1], i.e. assumes (N, C, ...) layout -- confirm
# against loader.Loader's output convention.
if test_label is not None:
    args.n_outputs = test_label.shape[1]
elif train_label is not None:
    args.n_outputs = train_label.shape[1]
if test_input is not None:
    args.n_inputs = test_input.shape[1]
elif train_input is not None:
    args.n_inputs = train_input.shape[1]
# Instantiate the network, optionally restoring pretrained weights.
n = net.Net(**vars(args))
if args.load_ver is not None:
    # Load onto CPU; the Trainer moves the model to the target device later.
    state = torch.load(args.load_ver, map_location='cpu')['state_dict']
    n.load_state_dict(state)

trainer = pl.Trainer.from_argparse_args(args)

if args.train_files is not None:
    # Train; a second loader, when present, acts as the validation set.
    trainer.fit(n, *loaders)
    print(chkpnt_cb.best_model_path)
elif args.label_vars:
    # Labeled test data: run the evaluation loop.
    trainer.test(n, *loaders)
else:
    # Unlabeled data: predict and dump the stacked outputs to HDF5.
    predictions = trainer.predict(n, *loaders)
    with h5py.File(args.test_fname, "w") as F:
        F["predictions"] = torch.cat(predictions).numpy()