medical imaging
ultrasound
laughingrice committed
Commit 6ce7d82
1 Parent(s): d10cbbd

Upload 11 files
README.md CHANGED
@@ -1,3 +1,45 @@
- ---
- license: mit
- ---
+ # Deep learning for speed of sound inversion in ultrasound imaging
+
+ This repository contains the code and models for the following papers:
+
+ 1. Feigin M, Freedman D, Anthony BW. A Deep Learning Framework for Single-Sided Sound Speed Inversion in Medical Ultrasound. IEEE Trans Biomed Eng. 2020;67(4):1142-1151. doi:10.1109/TBME.2019.2931195
+ 2. Feigin M, Zwecker M, Freedman D, Anthony BW. Detecting muscle activation using ultrasound speed of sound inversion with deep learning. In: 2020 42nd Annual International Conference of the IEEE Engineering in Medicine & Biology Society (EMBC). IEEE; 2020:2092-2095. doi:10.1109/EMBC44109.2020.9175237
+ 3. Feigin M, Freedman D, Anthony BW. Computing Speed-of-Sound from ultrasound: user-agnostic recovery and a new benchmark. IEEE Trans Biomed Eng. 2023; doi:TBF
+
+ Specifically, this repository contains the network code and the trained models for the algorithms and results presented in these papers.
+
+ The code was tested under Python 3.9. The Anaconda environment is defined in environment.yml (set up the environment with the command `conda env create -f environment.yml`).
+
+ ## Data
+
+ The dataset used is available on Hugging Face at https://huggingface.co/datasets/laughingrice/Ultrasound_planewave_sos_inversion
+
+ Variables in the files are in `[sample, layer, x/channel, y/sample]` order:
+
+ * `alpha_coeff` -- Attenuation (alpha) coefficient used for the simulations, full resolution
+ * `c0` -- Speed of sound used for the simulations, full resolution
+ * `data` -- Channel data (the results in the paper use the first 2048 samples, the 64 active channels, and the first layer, which holds the flat plane wave, to match existing physical hardware)
+ * `dx` -- Spatial step size of `c0` and `alpha_coeff`
+ * `f` -- Temporal sampling frequency of the channel data (40 MHz)
+
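+ As a minimal sketch of loading a file (assuming, as in `loader.py`, that the dataset files are HDF5 based and readable with `h5py`; the file name below is the sample file used in the Models section):
+
+ ```python
+ import h5py
+
+ with h5py.File('data/supplamentary_sample.mat', 'r') as f:
+     # List the stored variables with their shapes and types
+     for name in ('alpha_coeff', 'c0', 'data', 'dx', 'f'):
+         if name in f:
+             print(name, f[name].shape, f[name].dtype)
+
+     # Channel data, in [sample, layer, x/channel, y/sample] order
+     data = f['data'][...]
+ ```
+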
+ ## Models
+
+ Model files under the `models` directory, together with the matching execution parameters for the results presented in the papers, are as follows:
+
+ * `tbme_sos.pt` -- network weights for the network presented in [1]
+   `python . --test_files data/supplamentary_sample.mat --test_fname tbme_sos.h5 --load_ver models/tbme_sos.pt --net_type tbme`
+ * `embc_sos.pt` -- network weights for the network presented in [2]
+   `python . --test_files data/supplamentary_sample.mat --test_fname embc_sos.h5 --load_ver models/embc_sos.pt --net_type embc`
+ * `tbme2_sos.pt` -- network weights for the network presented in [3]
+   `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_sos.h5 --load_ver models/tbme2_sos.pt`
+ * `tbme2_sos_rand_gain.pt` -- [3] trained to recover the speed-of-sound map under a random gain profile and scaling
+   `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_sos_gain.h5 --load_ver models/tbme2_sos_rand_gain.pt`
+ * `tbme2_attn.pt` -- [3] trained to recover the attenuation coefficient
+   `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_attn.h5 --load_ver models/tbme2_attn.pt --label_vars alpha_coeff`
+ * `tbme2_sos_attn.pt` -- [3] trained to recover both the speed-of-sound map and the attenuation coefficient
+   `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_sos_attn.h5 --load_ver models/tbme2_sos_attn.pt --label_vars c0 alpha_coeff`
+ * `tbme2_phase_sos.pt` -- [3] trained to recover the speed-of-sound map using the IQ phase component
+   `python . --test_files data/supplamentary_sample.mat --test_fname tbme2_phase_sos.h5 --load_ver models/tbme2_phase_sos.pt --phase_inv 1`
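+
+ Each command writes the run output to the file given by `--test_fname`. As a rough sketch of reading it back (the layout follows `__main__.py` and `run_logger.TestLogger`; with the default `--label_vars c0` the run goes through `trainer.test`, while an empty `--label_vars` triggers `trainer.predict` and a single `predictions` dataset):
+
+ ```python
+ import h5py
+
+ with h5py.File('tbme2_sos.h5', 'r') as f:
+     if 'predictions' in f:
+         # Predict mode: the stacked network outputs
+         pred = f['predictions'][...]
+     else:
+         # Test mode: one dataset per batch (plus matching labels_* datasets)
+         batches = [f[k][...] for k in sorted(f) if k.startswith('batch_')]
+ ```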
__main__.py ADDED
@@ -0,0 +1,170 @@
+ """
+ Deep learning framework for sound speed inversion
+ """
+
+ import json
+ import git
+ import argparse
+ import pathlib
+ import glob
+ import os
+ import h5py
+
+ import loader
+ import run_logger
+ import net
+
+ import torch
+ import torch.utils.data as td
+ import pytorch_lightning as pl
+
+
+ # ----------------------------
+ # Setup command line arguments
+ # ----------------------------
+
+ parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('--test_files', nargs='?', help='Test data (file pattern) to process / data to evaluate')
+ parser.add_argument('--train_files', nargs='?', help='Train data (file pattern) to process; only evaluate the test data if empty')
+ parser.add_argument('--test_fname', default='output.h5', help='Filename into which to write the testing output -- will be overwritten')
+
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+
+ parser.add_argument('--experiment', default='DeepLearning US', help='Experiment name')
+ parser.add_argument('--tags', nargs='?', help='Optional run tags; should evaluate to a dictionary via json.loads')
+
+ parser.add_argument('--load_ver', type=str, help='Network weights to load')
+
+ parser.add_argument('--conf', type=str, action='append', help='Config file(s) to import (overridden by command line arguments)')
+ parser.add_argument('--conf_export', type=str, help='Filename where to store the settings')
+
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ parser = loader.Loader.add_argparse_args(parser)
+ parser = net.Net.add_model_specific_args(parser)
+ parser = run_logger.ImgCB.add_argparse_args(parser)
+
+ args = parser.parse_args()
+
+ if args.conf is not None:
+     for conf_fname in args.conf:
+         with open(conf_fname, 'r') as f:
+             parser.set_defaults(**json.load(f))
+
+     # Reload arguments to override config file values with command line values
+     args = parser.parse_args()
+
+ if args.conf_export is not None:
+     with open(args.conf_export, 'w') as f:
+         json.dump(vars(args), f, indent=4, sort_keys=True)
+
+ if args.test_files is None and args.train_files is None:
63
+ raise ValueError('At least one of train files or test files is required')
64
+
65
+ # ----------------------------
66
+ # Load data
67
+ # ----------------------------
68
+
69
+ ld = loader.Loader(**vars(args))
70
+ test_input, test_label, train_input, train_label = ld.load_data(test_file_pattern=args.test_files, train_file_pattern=args.train_files)
71
+
72
+ for name, tensor in (
73
+ ('test_input', test_input),
74
+ ('test_label', test_label),
75
+ ('train_input', train_input),
76
+ ('train_label', train_label)):
77
+ print(f'{name}: {tensor.shape if tensor is not None else None} -- {tensor.dtype if tensor is not None else None}')
78
+
79
+ loaders = []
80
+
81
+ if args.train_files is not None:
82
+ if train_input is None or train_label is None or (test_input is not None and test_label is None):
83
+ raise ValueError('Training requires labeled data')
84
+
85
+ train_ds = td.TensorDataset(train_input, train_label)
86
+ loaders.append(td.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, pin_memory=True))
87
+
88
+ if args.test_files is not None:
89
+ ds = [test_input]
90
+ if test_label is not None:
91
+ ds.append(test_label)
92
+
93
+ test_ds = td.TensorDataset(*ds)
94
+ loaders.append(td.DataLoader(test_ds, args.batch_size, shuffle=False, pin_memory=True))
95
+
96
+ # ----------------------------
97
+ # Run
98
+ # ----------------------------
99
+
100
+ if args.train_files is not None:
101
+ if args.tags is None:
102
+ args.tags = {}
103
+ elif type(args.tags) == str:
104
+ args.tags = json.loads(args.tags)
105
+
106
+ try:
107
+ repo = git.Repo(search_parent_directories=True)
108
+ sha = repo.head.object.hexsha
109
+ args.tags.update({'commit': sha})
110
+ except:
111
+ print('Not a git repo, not logging commit ID')
112
+
113
+ mfl = pl.loggers.MLFlowLogger(experiment_name=args.experiment, tags=args.tags)
114
+ mfl.log_hyperparams(args)
115
+
116
+ path = pathlib.Path(__file__).parent.absolute()
117
+ files = glob.glob(str(path) + os.sep + '*.py')
118
+ for f in files:
119
+ mfl.experiment.log_artifact(mfl.run_id, f, 'source')
120
+
121
+ chkpnt_cb = pl.callbacks.ModelCheckpoint(
122
+ monitor='validate_mean',
123
+ verbose=True,
124
+ save_top_k=1,
125
+ save_weights_only=True,
126
+ mode='min',
127
+ every_n_train_steps=1,
128
+ filename='{epoch}-{validate_mean}-{train_mean}',
129
+ )
130
+
131
+ img_cb = run_logger.ImgCB(**vars(args))
132
+ lr_logger = pl.callbacks.LearningRateMonitor()
133
+
134
+ args.__dict__.update({'logger': mfl, 'callbacks': [chkpnt_cb, img_cb, lr_logger]})
135
+ else:
136
+ if os.path.exists(args.test_fname):
137
+ os.remove(args.test_fname)
138
+
139
+ args.__dict__.update({'callbacks': [run_logger.TestLogger(args.test_fname)]})
140
+
141
+
142
+ if test_label is not None:
143
+ args.n_outputs = test_label.shape[1]
144
+ elif train_label is not None:
145
+ args.n_outputs = train_label.shape[1]
146
+
147
+ if test_input is not None:
148
+ args.n_inputs = test_input.shape[1]
149
+ elif train_input is not None:
150
+ args.n_inputs = train_input.shape[1]
151
+
152
+
153
+ n = net.Net(**vars(args))
154
+ if args.load_ver is not None:
155
+ t = torch.load(args.load_ver, map_location='cpu')['state_dict']
156
+ n.load_state_dict(t)
157
+
158
+ trainer = pl.Trainer.from_argparse_args(args)
159
+
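+ # When training, loaders == [train, test], so fit() uses the test loader as
+ # the validation set (the source of the 'validate_*' metrics monitored by the
+ # checkpoint callback above); otherwise loaders == [test] only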
+ if args.train_files is not None:
+     trainer.fit(n, *loaders)
+
+     print(chkpnt_cb.best_model_path)
+ elif args.label_vars:
+     trainer.test(n, *loaders)
+ else:
+     predictions = trainer.predict(n, *loaders)
+     with h5py.File(args.test_fname, "w") as F:
+         F["predictions"] = torch.cat(predictions).numpy()
+
environment.yml ADDED
@@ -0,0 +1,286 @@
+ name: DL-US-inversion
+ channels:
+   - conda-forge
+   - defaults
+   - pytorch
+   - gimli
+ dependencies:
+   - _libgcc_mutex=0.1=conda_forge
+   - _openmp_mutex=4.5=2_kmp_llvm
+   - abseil-cpp=20211102.0=h27087fc_1
+   - absl-py=1.1.0=pyhd8ed1ab_0
+   - aiohttp=3.8.1=py310h5764c6d_1
+   - aiosignal=1.2.0=pyhd8ed1ab_0
+   - alembic=1.8.1=pyhd8ed1ab_0
+   - aom=3.4.0=h27087fc_1
+   - appdirs=1.4.4=pyh9f0ad1d_0
+   - arrow-cpp=8.0.0=py310h893e394_4_cpu
+   - asn1crypto=1.5.1=pyhd8ed1ab_0
+   - asttokens=2.0.5=pyhd8ed1ab_0
+   - async-timeout=4.0.2=pyhd8ed1ab_0
+   - attrs=21.4.0=pyhd8ed1ab_0
+   - aws-c-cal=0.5.11=h95a6274_0
+   - aws-c-common=0.6.2=h7f98852_0
+   - aws-c-event-stream=0.2.7=h3541f99_13
+   - aws-c-io=0.10.5=hfb6a706_0
+   - aws-checksums=0.1.11=ha31a3da_7
+   - aws-sdk-cpp=1.8.186=hb4091e7_3
+   - backcall=0.2.0=pyh9f0ad1d_0
+   - backports=1.1=pyhd3eb1b0_0
+   - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+   - blas=2.115=mkl
+   - blas-devel=3.9.0=15_linux64_mkl
+   - blinker=1.4=py_1
+   - boost-cpp=1.79.0=h75c5d50_0
+   - brotli=1.0.9=h166bdaf_7
+   - brotli-bin=1.0.9=h166bdaf_7
+   - brotlipy=0.7.0=py310h5764c6d_1004
+   - bzip2=1.0.8=h7f98852_4
+   - c-ares=1.18.1=h7f98852_0
+   - ca-certificates=2022.6.15=ha878542_0
+   - cached-property=1.5.2=hd8ed1ab_1
+   - cached_property=1.5.2=pyha770c72_1
+   - cachetools=5.0.0=pyhd8ed1ab_0
+   - certifi=2022.6.15=py310hff52083_0
+   - cffi=1.15.1=py310h255011f_0
+   - charset-normalizer=2.1.0=pyhd8ed1ab_0
+   - click=8.1.3=py310hff52083_0
+   - cloudpickle=2.1.0=pyhd8ed1ab_0
+   - colorama=0.4.5=pyhd8ed1ab_0
+   - configparser=5.2.0=pyhd8ed1ab_0
+   - cryptography=37.0.1=py310h9ce1e76_0
+   - cudatoolkit=11.6.0=hecad31d_10
+   - cudnn=8.4.1.50=hed8a83a_0
+   - cycler=0.11.0=pyhd8ed1ab_0
+   - databricks-cli=0.17.0=pyhd8ed1ab_0
+   - decorator=5.1.1=pyhd8ed1ab_0
+   - docker-py=5.0.3=py310hff52083_2
+   - docker-pycreds=0.4.0=py_0
+   - entrypoints=0.4=pyhd8ed1ab_0
+   - executing=0.8.3=pyhd8ed1ab_0
+   - expat=2.4.8=h27087fc_0
+   - ffmpeg=5.0.1=gpl_h512afef_107
+   - flask=2.1.3=pyhd8ed1ab_0
+   - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
+   - font-ttf-inconsolata=3.000=h77eed37_0
+   - font-ttf-source-code-pro=2.038=h77eed37_0
+   - font-ttf-ubuntu=0.83=hab24e00_0
+   - fontconfig=2.14.0=h8e229c2_0
+   - fonts-conda-ecosystem=1=0
+   - fonts-conda-forge=1=0
+   - fonttools=4.34.4=py310h5764c6d_0
+   - freetype=2.11.0=h70c0345_0
+   - frozenlist=1.3.0=py310h5764c6d_1
+   - fsspec=2022.5.0=pyhd8ed1ab_0
+   - future=0.18.2=py310hff52083_5
+   - gettext=0.21.0=hf68c758_0
+   - gflags=2.2.2=he1b5a44_1004
+   - giflib=5.2.1=h36c2ea0_2
+   - gitdb=4.0.9=pyhd8ed1ab_0
+   - gitpython=3.1.27=pyhd8ed1ab_0
+   - glog=0.6.0=h6f12383_0
+   - gmp=6.2.1=h58526e2_0
+   - gnutls=3.7.6=hf3e180e_5
+   - google-auth=2.9.1=pyh6c4a22f_0
+   - google-auth-oauthlib=0.4.1=py_2
+   - greenlet=1.1.2=py310hd8f1fbe_2
+   - grpc-cpp=1.46.3=hbd84cd8_2
+   - grpcio=1.46.3=py310ha0b7d45_2
+   - gunicorn=20.1.0=py310hff52083_2
+   - h5py=3.7.0=nompi_py310h06dffec_100
+   - hdf5=1.12.1=nompi_h2386368_104
+   - htmlmin=0.1.12=py_1
+   - icu=70.1=h27087fc_0
+   - idna=3.3=pyhd8ed1ab_0
+   - imagehash=4.2.1=pyhd8ed1ab_0
+   - importlib-metadata=4.11.4=py310hff52083_0
+   - importlib_resources=5.8.0=pyhd8ed1ab_0
+   - ipython=8.4.0=py310hff52083_0
+   - itsdangerous=2.1.2=pyhd8ed1ab_0
+   - jedi=0.18.1=py310hff52083_1
+   - jinja2=3.1.2=pyhd8ed1ab_1
+   - joblib=1.1.0=pyhd8ed1ab_0
+   - jpeg=9e=h166bdaf_2
+   - keyutils=1.6.1=h166bdaf_0
+   - kiwisolver=1.4.4=py310hbf28c38_0
+   - krb5=1.19.3=h3790be6_0
+   - lame=3.100=h7f98852_1001
+   - lcms2=2.12=hddcbb42_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - lerc=3.0=h9c3ff4c_0
+   - libblas=3.9.0=15_linux64_mkl
+   - libbrotlicommon=1.0.9=h166bdaf_7
+   - libbrotlidec=1.0.9=h166bdaf_7
+   - libbrotlienc=1.0.9=h166bdaf_7
+   - libcblas=3.9.0=15_linux64_mkl
+   - libcrc32c=1.1.2=h9c3ff4c_0
+   - libcurl=7.83.1=h7bff187_0
+   - libdeflate=1.12=h166bdaf_0
+   - libdrm=2.4.112=h166bdaf_0
+   - libedit=3.1.20210910=h7f8727e_0
+   - libev=4.33=h516909a_1
+   - libevent=2.1.10=h9b69904_4
+   - libffi=3.4.2=h7f98852_5
+   - libgcc-ng=12.1.0=h8d9b700_16
+   - libgfortran-ng=12.1.0=h69a702a_16
+   - libgfortran5=12.1.0=hdcd56e2_16
+   - libgomp=12.1.0=h8d9b700_16
+   - libgoogle-cloud=1.40.2=hefc27d0_0
+   - libiconv=1.16=h516909a_0
+   - libidn2=2.3.3=h166bdaf_0
+   - liblapack=3.9.0=15_linux64_mkl
+   - liblapacke=3.9.0=15_linux64_mkl
+   - libllvm11=11.1.0=hf817b99_3
+   - libnghttp2=1.47.0=h727a467_0
+   - libnsl=2.0.0=h7f98852_0
+   - libpciaccess=0.16=h516909a_0
+   - libpng=1.6.37=h753d276_3
+   - libprotobuf=3.20.1=h6239696_0
+   - libssh2=1.10.0=ha56f1ee_2
+   - libstdcxx-ng=12.1.0=ha89aaad_16
+   - libtasn1=4.18.0=h166bdaf_1
+   - libthrift=0.16.0=h519c5ea_1
+   - libtiff=4.4.0=hc85c160_1
+   - libunistring=0.9.10=h7f98852_0
+   - libutf8proc=2.7.0=h7f98852_0
+   - libuuid=2.32.1=h7f98852_1000
+   - libva=2.15.0=h166bdaf_0
+   - libvpx=1.11.0=h9c3ff4c_3
+   - libwebp=1.2.2=h3452ae3_0
+   - libwebp-base=1.2.2=h7f98852_1
+   - libxcb=1.13=h7f98852_1004
+   - libxml2=2.9.14=h22db469_3
+   - libzlib=1.2.12=h166bdaf_2
+   - llvm-openmp=14.0.4=he0ac6c6_0
+   - llvmlite=0.38.1=py310h58363a5_0
+   - lz4-c=1.9.3=h9c3ff4c_1
+   - magma=2.5.4=h6103c52_2
+   - mako=1.2.1=pyhd8ed1ab_0
+   - markdown=3.4.1=pyhd8ed1ab_0
+   - markupsafe=2.1.1=py310h5764c6d_1
+   - matplotlib-base=3.5.2=py310h5701ce4_0
+   - matplotlib-inline=0.1.3=pyhd8ed1ab_0
+   - missingno=0.4.2=py_1
+   - mkl=2022.1.0=h84fe81f_915
+   - mkl-devel=2022.1.0=ha770c72_916
+   - mkl-include=2022.1.0=h84fe81f_915
+   - mlflow=1.27.0=py310ha13cd29_0
+   - multidict=6.0.2=py310h5764c6d_1
+   - multimethod=1.4=py_0
+   - munkres=1.1.4=pyh9f0ad1d_0
+   - nccl=2.12.12.1=h0800d71_0
+   - ncurses=6.3=h27087fc_1
+   - nettle=3.8=hc379101_0
+   - networkx=2.8.4=pyhd8ed1ab_0
+   - ninja=1.11.0=h924138e_0
+   - numba=0.55.0=py310h00e6091_0
+   - numpy=1.23.1=py310h53a5b5f_0
+   - oauthlib=3.2.0=pyhd8ed1ab_0
+   - openh264=2.2.0=h6239696_1
+   - openjpeg=2.4.0=hb52868f_1
+   - openssl=1.1.1q=h166bdaf_0
+   - orc=1.7.5=h6c59b99_0
+   - p11-kit=0.24.1=hc5aa10d_0
+   - packaging=21.3=pyhd8ed1ab_0
+   - pandas=1.4.3=py310h769672d_0
+   - pandas-profiling=3.2.0=pyhd8ed1ab_0
+   - parso=0.8.3=pyhd8ed1ab_0
+   - patsy=0.5.2=pyhd8ed1ab_0
+   - pexpect=4.8.0=pyh9f0ad1d_2
+   - phik=0.12.2=py310h7c64c84_0
+   - pickleshare=0.7.5=py_1003
+   - pillow=9.2.0=py310he619898_0
+   - pip=22.1.2=pyhd8ed1ab_0
+   - prometheus_client=0.14.1=pyhd8ed1ab_0
+   - prometheus_flask_exporter=0.20.2=pyhd8ed1ab_0
+   - prompt-toolkit=3.0.30=pyha770c72_0
+   - protobuf=3.20.1=py310hd8f1fbe_0
+   - pthread-stubs=0.4=h36c2ea0_1001
+   - ptyprocess=0.7.0=pyhd3deb0d_0
+   - pure_eval=0.2.2=pyhd8ed1ab_0
+   - pyarrow=8.0.0=py310h468efa6_0
+   - pyasn1=0.4.8=py_0
+   - pyasn1-modules=0.2.8=py_0
+   - pybind11-abi=4=hd8ed1ab_3
+   - pycparser=2.21=pyhd8ed1ab_0
+   - pydantic=1.9.1=py310h5764c6d_0
+   - pydeprecate=0.3.2=pyhd8ed1ab_0
+   - pygments=2.12.0=pyhd8ed1ab_0
+   - pyjwt=2.4.0=pyhd8ed1ab_0
+   - pyopenssl=22.0.0=pyhd8ed1ab_0
+   - pyparsing=3.0.9=pyhd8ed1ab_0
+   - pysocks=1.7.1=py310hff52083_5
+   - python=3.10.5=h582c2e5_0_cpython
+   - python-dateutil=2.8.2=pyhd8ed1ab_0
+   - python_abi=3.10=2_cp310
+   - pytorch=1.12.0=py3.10_cuda11.6_cudnn8.3.2_0
+   - pytorch-lightning=1.6.5=pyhd8ed1ab_0
+   - pytorch-mutex=1.0=cuda
+   - pytz=2022.1=pyhd8ed1ab_0
+   - pyu2f=0.1.5=pyhd8ed1ab_0
+   - pywavelets=1.3.0=py310hde88566_1
+   - pyyaml=6.0=py310h5764c6d_4
+   - querystring_parser=1.2.4=py_0
+   - re2=2022.06.01=h27087fc_0
+   - readline=8.1.2=h0f457ee_0
+   - requests=2.28.1=pyhd8ed1ab_0
+   - requests-oauthlib=1.3.1=pyhd8ed1ab_0
+   - rsa=4.8=pyhd8ed1ab_0
+   - s2n=1.0.10=h9b69904_0
+   - scikit-learn=1.1.1=py310hffb9edd_0
+   - scipy=1.8.1=py310h7612f91_0
+   - seaborn=0.11.2=hd8ed1ab_0
+   - seaborn-base=0.11.2=pyhd8ed1ab_0
+   - setuptools=59.5.0=py310hff52083_0
+   - shap=0.41.0=py310h769672d_0
+   - six=1.16.0=pyh6c4a22f_0
+   - sleef=3.5.1=h9b69904_2
+   - slicer=0.0.7=pyhd8ed1ab_0
+   - smmap=3.0.5=pyh44b312d_0
+   - snappy=1.1.9=hbd366e4_1
+   - sqlalchemy=1.4.39=py310h5764c6d_0
+   - sqlite=3.39.1=h4ff8645_0
+   - sqlparse=0.4.2=pyhd8ed1ab_0
+   - stack_data=0.3.0=pyhd8ed1ab_0
+   - statsmodels=0.13.2=py310hde88566_0
+   - svt-av1=1.1.0=h27087fc_1
+   - tabulate=0.8.10=pyhd8ed1ab_0
+   - tangled-up-in-unicode=0.2.0=pyhd8ed1ab_0
+   - tbb=2021.5.0=h924138e_1
+   - tensorboard=2.6.0=py_0
+   - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
+   - threadpoolctl=3.1.0=pyh8a188c0_0
+   - tk=8.6.12=h27826a3_0
+   - torchaudio=0.12.0=py310_cu116
+   - torchmetrics=0.9.2=pyhd8ed1ab_0
+   - torchvision=0.13.0=py310_cu116
+   - tqdm=4.64.0=pyhd8ed1ab_0
+   - traitlets=5.3.0=pyhd8ed1ab_0
+   - typing-extensions=4.3.0=hd8ed1ab_0
+   - typing_extensions=4.3.0=pyha770c72_0
+   - tzdata=2022a=h191b570_0
+   - unicodedata2=14.0.0=py310h5764c6d_1
+   - urllib3=1.26.10=pyhd8ed1ab_0
+   - visions=0.7.4=pyhd8ed1ab_0
+   - wcwidth=0.2.5=pyh9f0ad1d_2
+   - websocket-client=1.3.3=pyhd8ed1ab_0
+   - werkzeug=2.1.2=pyhd8ed1ab_1
+   - wheel=0.37.1=pyhd8ed1ab_0
+   - x264=1!161.3030=h7f98852_1
+   - x265=3.5=h924138e_3
+   - xorg-fixesproto=5.0=h7f98852_1002
+   - xorg-kbproto=1.0.7=h7f98852_1002
+   - xorg-libx11=1.7.2=h7f98852_0
+   - xorg-libxau=1.0.9=h7f98852_0
+   - xorg-libxdmcp=1.1.3=h7f98852_0
+   - xorg-libxext=1.3.4=h7f98852_1
+   - xorg-libxfixes=5.0.3=h7f98852_1004
+   - xorg-xextproto=7.3.0=h7f98852_1002
+   - xorg-xproto=7.0.31=h7f98852_1007
+   - xz=5.2.5=h516909a_1
+   - yaml=0.2.5=h7f98852_2
+   - yarl=1.7.2=py310h5764c6d_2
+   - zipp=3.8.0=pyhd8ed1ab_0
+   - zlib=1.2.12=h166bdaf_2
+   - zstd=1.5.2=h8a70e8d_2
+ prefix: /home/micha/.conda/envs/DL-US-inversion
loader.py ADDED
@@ -0,0 +1,176 @@
+ """
+ Defines a Loader class to load data from a file or file wildcard
+ """
+
+ import argparse
+
+ import h5py
+ import torch
+ import numpy as np
+ import glob
+ from typing import Tuple
+
+
+ class Loader:
+     """
+     Data loader class
+     """
+
+     def __init__(self, **kwargs):
+         parser = Loader.add_argparse_args()
+         for action in parser._actions:
+             if action.dest in kwargs:
+                 action.default = kwargs[action.dest]
+         args = parser.parse_args([])
+         self.__dict__.update(vars(args))
+
+         if isinstance(self.label_vars, str):
+             self.label_vars = [self.label_vars]
+
+     @staticmethod
+     def add_argparse_args(parent_parser=None):
+         """
+         Add argparse arguments for the data loader
+         """
+         parser = argparse.ArgumentParser(
+             prog='Loader',
+             usage=Loader.__doc__,
+             parents=[parent_parser] if parent_parser is not None else [],
+             add_help=False)
+
+         parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the input data')
+         parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data')
+
+         parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*',
+                             help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]')
+         parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*', help='Crop label data on load [x_min x_max y_min y_max]')
+         parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='Scaling factor for the labels image')
+
+         parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor')
+         parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at the farthest point in the data')
+
+         return parser
+
+     def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+         """Loads training/testing data from file(s)
+
+         Arguments:
+             test_file_pattern {str} -- testing dataset(s) pattern
+             train_file_pattern {str} -- training dataset(s) pattern
+
+         Returns:
+             (test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded
+         """
+
+         test_inputs, test_labels = self._load_data_files(test_file_pattern)
+         train_inputs, train_labels = self._load_data_files(train_file_pattern)
+
+         if train_file_pattern is not None and train_inputs is None:
+             raise ValueError('Failed to load train set')
+
+         if test_file_pattern is not None and test_inputs is None:
+             raise ValueError('Failed to load test set')
+
+         return test_inputs, test_labels, train_inputs, train_labels
+
+     def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+         """Perform the actual data loading
+
+         Args:
+             file_pattern: file name pattern
+
+         Returns:
+             inputs and labels tensors
+         """
+
+         inputs, labels = None, None
+
+         if file_pattern is None:
+             return inputs, labels
+
+         files = glob.glob(file_pattern)
+
+         if len(files) == 0:
+             raise ValueError(f'{file_pattern=} comes up empty')
+
+         # Load the first file to get the output dimensions
+         with h5py.File(files[0], 'r') as f:
+             if self.input_var not in f:
+                 raise ValueError(f'input data key not in file: {self.input_var=}')
+
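+             # Crop limits come in (min, max) pairs applied to the trailing
+             # axes in reverse order (last axis first)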
+             shape = list(f[self.input_var].shape)
+             if self.inputs_crop is not None:
+                 for i in range(len(self.inputs_crop) // 2):
+                     shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2]
+
+             shape[0] *= len(files)
+
+             inputs = np.empty(shape, np.single)
+
+             if len(self.label_vars):
+                 if not all([v in f for v in self.label_vars]):
+                     raise ValueError(f'labels data key(s) not in file: {self.label_vars=}')
+
+                 shape = list(f[self.label_vars[0]].shape)
+                 shape[1] *= len(self.label_vars)
+                 if self.labels_crop is not None:
+                     for i in range(len(self.labels_crop) // 2):
+                         shape[-i - 1] = self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2]
+
+                 shape[-1] = int(shape[-1] * self.labels_resize)
+                 shape[-2] = int(shape[-2] * self.labels_resize)
+                 shape[0] *= len(files)
+
+                 labels = np.empty(shape, np.single)
+
+         # Load the data from the files
+         pos = 0
+         for file in files:
+             with h5py.File(file, 'r') as f:
+                 tmp_inputs = np.array(f[self.input_var])
+
+                 if self.inputs_crop is not None:
+                     slc = [slice(None)] * 4
+                     for i in range(len(self.inputs_crop) // 2):
+                         slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1])
+                     tmp_inputs = tmp_inputs[tuple(slc)]
+
+                 inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs
+
+                 if len(self.label_vars):
+                     tmp_labels = []
+                     for v in self.label_vars:
+                         tmp_labels.append(np.array(f[v]))
+                     tmp_labels = np.concatenate(tmp_labels, axis=1)
+
+                     if self.labels_crop:
+                         slc = [slice(None)] * 4
+                         for i in range(len(self.labels_crop) // 2):
+                             slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1])
+                         tmp_labels = tmp_labels[tuple(slc)]
+
+                     if self.labels_resize != 1.0:
+                         tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy()
+
+                     labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels
+
+                 pos += tmp_inputs.shape[0]
+
+         inputs = inputs[:pos, ...]
+         if len(self.label_vars):
+             labels = labels[:pos, ...]
+
+         if self.data_scale != 1.0:
+             inputs *= self.data_scale
+
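+         # Depth dependent exponential gain: ramps from unity at the first
+         # sample to 10 ** data_gain at the last (deepest) sample; data_gain
+         # is specified in dB/20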
+         if self.data_gain != 0.0:
+             gain = 10.0 ** np.linspace(0, self.data_gain, inputs.shape[-1], dtype=np.single).reshape((1, 1, 1, -1))
+             inputs *= gain
+
+         # Required when inputs is non-contiguous due to a transpose
+         # TODO: Could probably use a check on strides and do a conditional copy.
+         inputs = torch.from_numpy(inputs.copy())
+         if len(self.label_vars):
+             labels = torch.from_numpy(labels)
+
+         return inputs, labels
models/embc_sos.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f16952eb255caf2c845fd8d0cbdc17e5ea8d8a63cc7669699f8003cc388b22e6
+ size 17301623
models/tbme2_attn.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6888856204f0137b3c396fd554b1e5a30b97dcc066638f6297267804f0fd44c
+ size 9633441
models/tbme2_phase_sos.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5582ded80ec25c459a2307f698fccd0ea7d2c2537914fc65aa0ef0777ee27759
+ size 9639457
models/tbme2_sos.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9636c47b19f376e7236149584b715477249190f996f9f0f44a11692b02cf28c4
+ size 9633441
models/tbme_sos.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de4663eeb6b655d021fd69381703db2e0e90ba20ceb033586a6097f33d7883d4
+ size 6100085
net.py ADDED
@@ -0,0 +1,494 @@
+ """
+ Network definition file
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from pytorch_lightning import LightningModule
+
+ import numpy as np
+ from scipy.signal import gaussian
+ from copy import deepcopy
+ import argparse
+
+
+ class Net(LightningModule):
+     """
+     Encoder-decoder network for single-sided speed of sound inversion
+     """
+     def __init__(self, **kwargs):
+         super().__init__()
+
+         parser = Net.add_model_specific_args()
+         for action in parser._actions:
+             if action.dest in kwargs:
+                 action.default = kwargs[action.dest]
+
+         args = parser.parse_args([])
+         self.hparams.update(vars(args))
+
+         if not hasattr(self, f"_init_{self.hparams.net_type}_net"):
+             raise ValueError(f"Unknown net type {self.hparams.net_type}")
+
+         self._net = getattr(self, f"_init_{self.hparams.net_type}_net")(n_inputs=self.hparams.n_inputs, n_outputs=self.hparams.n_outputs)
+
+         if self.hparams.bias is not None:
+             if hasattr(self.hparams.bias, "__iter__"):
+                 for i in range(len(self.hparams.bias)):
+                     self._net[-1].c.bias[i].data.fill_(self.hparams.bias[i])
+             else:
+                 self._net[-1].c.bias.data.fill_(self.hparams.bias)
+
+     @staticmethod
+     def _init_tbme2_net(n_inputs: int = 1, n_outputs: int = 1):
+         return nn.Sequential(
+             # Encoder
+             DownBlock(n_inputs, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
+             DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
+             DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
+             DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=True, layers=3),
+             DownBlock(32, 32, 64, 3, stride=1, pool=[2, 2], push=True, layers=3),
+             DownBlock(64, 64, 128, 3, stride=1, pool=[2, 2], push=True, layers=3),
+             DownBlock(128, 128, 512, 3, stride=1, pool=[2, 2], push=False, layers=3),
+             # Decoder
+             UpBlock(512, 128, 3, scale_factor=2, pop=False, layers=3),
+             UpBlock(256, 64, 3, scale_factor=2, pop=True, layers=3),
+             UpBlock(128, 32, 3, scale_factor=2, pop=True, layers=3),
+             UpBlock(64, 32, 3, scale_factor=2, pop=True, layers=3),
+             UpStep(32, 32, 3, scale_factor=1),
+             Compress(32, n_outputs))
+
+     @staticmethod
+     def _init_embc_net(n_inputs: int = 1, n_outputs: int = 1):
+         return nn.Sequential(
+             # Encoder
+             DownBlock(n_inputs, 32, 32, 15, [1, 2], None, layers=1),
+             DownBlock(32, 32, 32, 13, [1, 2], None, layers=1),
+             DownBlock(32, 32, 32, 11, [1, 2], None, layers=1),
+             DownBlock(32, 32, 32, 9, [1, 2], None, True, layers=1),
+             DownBlock(32, 32, 64, 7, 1, [2, 2], True, layers=1),
+             DownBlock(64, 64, 128, 5, 1, [2, 2], True, layers=1),
+             DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
+             # Decoder
+             UpBlock(512, 128, 5, 2, layers=1),
+             UpBlock(256, 64, 7, 2, True, layers=1),
+             UpBlock(128, 32, 9, 2, True, layers=1),
+             UpBlock(64, 32, 11, 2, True, layers=1),
+             UpStep(32, 32, 3, 1),
+             Compress(32, n_outputs))
+
+     @staticmethod
+     def _init_tbme_net(n_inputs: int = 1, n_outputs: int = 1):
+         return nn.Sequential(
+             # Encoder
+             DownBlock(n_inputs, 32, 32, 3, [1, 2], None, layers=1),
+             DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
+             DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
+             DownBlock(32, 32, 32, 3, [1, 2], None, True, layers=1),
+             DownBlock(32, 32, 64, 3, 1, [2, 2], True, layers=1),
+             DownBlock(64, 64, 128, 3, 1, [2, 2], True, layers=1),
+             DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
+             # Decoder
+             UpBlock(512, 128, 3, 2, layers=1),
+             UpBlock(256, 64, 3, 2, True, layers=1),
+             UpBlock(128, 32, 3, 2, True, layers=1),
+             UpBlock(64, 32, 3, 2, True, layers=1),
+             UpStep(32, 32, 3, 1),
+             Compress(32, n_outputs))
+
+     @staticmethod
+     def add_model_specific_args(parent_parser=None):
+         parser = argparse.ArgumentParser(
+             prog="Net",
+             usage=Net.__doc__,
+             parents=[parent_parser] if parent_parser is not None else [],
+             add_help=False)
+
+         parser.add_argument("--random_mirror", type=int, nargs="?", default=1, help="Randomly mirror data to increase diversity when using a flat plane wave")
+         parser.add_argument("--noise_std", type=float, nargs="*", help="Range of the std of random noise to add to the input signal: [0 val] or [min max]")
+         parser.add_argument("--quantization", type=float, nargs="?", help="Quantization noise scale to emulate an ADC (the signal is rounded to multiples of 1/quantization)")
+         parser.add_argument("--rand_drop", type=int, nargs="*", help="Randomly zero out lines: between 0 and the given value if a single value, or between two values")
+         parser.add_argument("--normalize_net", type=float, default=0.0, help="Coefficient for normalizing network weights")
+
+         parser.add_argument("--learning_rate", type=float, default=5e-3, help="Learning rate to use for the optimizer")
+         parser.add_argument("--lr_sched_step", type=int, default=15, help="Learning rate decay: update step size")
+         parser.add_argument("--lr_sched_gamma", type=float, default=0.65, help="Learning rate decay: gamma")
+
+         parser.add_argument("--net_type", default="tbme2", help="The network to use [tbme2/embc/tbme]")
+         parser.add_argument("--bias", type=float, nargs="*", help="Set the bias of the last layer; set to 1500 when training from scratch on SoS output")
+         parser.add_argument("--decimation", type=int, help="Temporal decimation (subsampling) factor applied to the input signal")
+         parser.add_argument("--phase_inv", type=int, default=0, help="Use phase for inversion")
+
+         parser.add_argument("--center_freq", type=float, default=5e6, help="Matched filter and IQ demodulation frequency")
+         parser.add_argument("--n_periods", type=float, default=5, help="Matched filter length")
+         parser.add_argument("--matched_filter", type=int, nargs="?", default=0, help="Apply matched filter, set to 1 to run during the forward pass, 2 to run during the preprocessing phase (before adding noise)")
+
+         parser.add_argument("--rand_output_crop", type=int, help="Randomly crop up to this many rows from the input (labels are cropped by twice as many)")
+         parser.add_argument("--rand_scale", type=float, nargs="*", help="Random scaling range [min max] -- (10 ** rand_scale)")
+         parser.add_argument("--rand_gain", type=float, nargs="*", help="Random gain coefficient range [min max] -- (10 ** rand_gain)")
+
+         parser.add_argument("--n_inputs", type=int, default=1, help="Number of input layers")
+         parser.add_argument("--n_outputs", type=int, default=1, help="Number of output layers")
+         parser.add_argument("--scale_losses", type=float, nargs="*", help="Scale each layer of the loss function by the given value")
+
+         return parser
+
+     def forward(self, x) -> torch.Tensor:
+         # Matched filter
+         if self.hparams.matched_filter == 1:
+             x = self._matched_filter(x)
+
+         # Compute the IQ phase if in phase_inv mode
+         if self.hparams.phase_inv:
+             x = self._phase(x)
+
+         # Decimation
+         if self.hparams.decimation is not None and self.hparams.decimation != 1:
+             x = x[..., ::self.hparams.decimation]
+
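+         # The nn.Sequential threads an (activations, skip_connections) tuple
+         # through its blocks: DownBlocks push feature maps onto the list,
+         # UpBlocks pop and concatenate them (U-Net style), and the final
+         # Compress returns a plain tensor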
+         # Apply network
+         x = self._net((x, []))
+
+         return x
+
+     def _matched_filter(self, x):
+         sampling_freq = 40e6  # Channel data sampling frequency (40 MHz)
+
+         samples_per_cycle = sampling_freq / self.hparams.center_freq
+         n_samples = int(np.ceil(samples_per_cycle * self.hparams.n_periods + 1))
+
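+         # Reference pulse: a Gaussian windowed sinusoid at the transmit center
+         # frequency, n_periods cycles long, correlated with each channel trace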
+         signal = torch.sin(torch.arange(n_samples, device=x.device) / samples_per_cycle * 2 * np.pi) * torch.from_numpy(gaussian(n_samples, (n_samples - 1) / 6).astype(np.single)).to(x.device)
+
+         return torch.nn.functional.conv1d(x.reshape(x.shape[:2] + (-1,)), signal.reshape(1, 1, -1), padding="same").reshape(x.shape)
+
+     def _phase(self, x):
+         f = self.hparams.center_freq
+         fs = 40e6
+         N = x.shape[-1]
+
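+         # IQ demodulation via the FFT: keep a Gaussian weighted band of
+         # 2n + 1 bins around the carrier bin n, shift it down to baseband,
+         # and return the phase of the resulting complex (IQ) signal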
+         n = int(round(f * N / fs))
+
+         X = torch.fft.fft(x, dim=-1)
+         X[..., (2 * n + 1):] = 0
+         X[..., :(2 * n + 1)] *= torch.from_numpy(gaussian(2 * n + 1, 2 * n / 6).astype(np.single)).to(x.device)
+         X = X.roll(-n, dims=-1)
+         x = torch.fft.ifft(X, dim=-1)
+
+         return x.angle()
+
+     def _preprocess(self, x):
+         # Matched filter
+         if self.hparams.matched_filter == 2:
+             x = self._matched_filter(x)
+
+         # Gaussian (normal) noise - random scaling, normalized to the signal STD
+         if (ns := self.hparams.noise_std) and len(ns):
+             scl = ns[0] if len(ns) == 1 else torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (ns[-1] - ns[-2]) + ns[-2]
+             scl *= x.std()
+             x += torch.empty_like(x).normal_() * scl
+
+         # Random multiplicative scaling
+         if (rs := self.hparams.rand_scale) and len(rs):
+             x *= 10 ** (torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (rs[-1] - rs[-2]) + rs[-2])
+
+         # Random exponential (depth dependent) gain, with an exponent drawn
+         # uniformly from [rand_gain[-2], rand_gain[-1]]
+         if (gs := self.hparams.rand_gain) and len(gs):
+             gain = torch.FloatTensor([10.0]).to(x.device) ** \
+                 ((torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (gs[-1] - gs[-2]) + gs[-2]) *
+                  torch.linspace(0, 1, x.shape[-1]).to(x.device).view(1, 1, 1, -1))
+             x *= gain
+
+         # Quantization noise, to emulate an ADC
+         if (quantization := self.hparams.quantization) is not None:
+             x = (x * quantization).round() * (1.0 / quantization)
+
+         # Randomly zero out some of the channels
+         if (rand_drop := self.hparams.rand_drop) and len(rand_drop):
+             if len(rand_drop) == 1:
+                 rand_drop = [0, ] + rand_drop
+
+             for i in range(x.shape[0]):
+                 lines = np.random.randint(0, x.shape[2], np.random.randint(rand_drop[0], rand_drop[1] + 1))
+                 x[i, :, lines, :] = 0.
+
+         return x
+
+     def _log_losses(self, outputs: torch.Tensor, labels: torch.Tensor, prefix: str = ""):
+         diff = torch.abs(labels.detach() - outputs.detach())
+
+         # Split the depth axis into near / mid / far thirds
+         s1 = int(diff.shape[-1] * (1.0 / 3.0))
+         s2 = int(diff.shape[-1] * (2.0 / 3.0))
+
+         for i in range(diff.shape[1]):
+             tag = f"{i}_" if diff.shape[1] > 1 else ""
+
+             losses = {
+                 f"{prefix + tag}rmse": torch.sqrt(torch.mean(diff[:, i, ...] * diff[:, i, ...])).item(),
+                 f"{prefix + tag}mean": torch.mean(diff[:, i, ...]).item(),
+                 f"{prefix + tag}short": torch.mean(diff[:, i, :, :s1]).item(),
+                 f"{prefix + tag}med": torch.mean(diff[:, i, :, s1:s2]).item(),
+                 f"{prefix + tag}long": torch.mean(diff[:, i, :, s2:]).item()}
+
+             self.log_dict(losses, prog_bar=True)
+
+     def training_step(self, batch, batch_idx):
+         if self.hparams.random_mirror:
+             mirror = np.random.randint(0, 2, batch[0].shape[0])
+
+             for b in batch:
+                 for i, m in enumerate(mirror):
+                     if not m:
+                         continue
+
+                     b[i, ...] = b[i, :, range(b.shape[-2] - 1, -1, -1), :]  # Pytorch does not handle negative steps
+
+         loss = self._common_step(batch, batch_idx, "train_")
+
+         if self.hparams.normalize_net:
+             for W in self.parameters():
+                 loss += self.hparams.normalize_net * W.norm(2)
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         return self._common_step(batch, batch_idx, "validate_")
+
+     def test_step(self, batch, batch_idx):
+         return self._common_step(batch, batch_idx, "test_")
+
+     def predict_step(self, batch, batch_idx):
+         x = batch[0]
+
+         x = self._preprocess(x)
+         z = self(x)
+
+         if isinstance(z, tuple):
+             z = z[0]
+
+         return z
+
+     def _common_step(self, batch, batch_idx, prefix):
+         x, y = batch
+
+         if self.hparams.rand_output_crop:
+             crop = np.random.randint(0, self.hparams.rand_output_crop, batch[0].shape[0])
+
+             for i, c in enumerate(crop):
+                 if not c:
+                     continue
+
+                 x[i, :, :-c, :] = x[i, :, c:, :].clone()
+                 y[i, :, :-c * 2, :] = \
+                     y[i, :, c * 2 - 1:-1, :].clone() if np.random.randint(2) else \
+                     y[i, :, c * 2:, :].clone()
+
+             x = x[..., :-self.hparams.rand_output_crop, :]
+             y = y[..., :-self.hparams.rand_output_crop * 2, :]
+
+         x = self._preprocess(x)
+         z = self(x)
+
+         outputs = z[0] if isinstance(z, (tuple, list)) else z
+         self._log_losses(outputs, y, prefix)
+
+         if self.hparams.scale_losses and len(self.hparams.scale_losses):
+             s = torch.FloatTensor(self.hparams.scale_losses).to(y.device).view(1, -1, 1, 1)
+             loss = F.mse_loss(s * outputs, s * y)
+         else:
+             loss = F.mse_loss(y, outputs)
+
+         self.log(prefix + "loss", np.sqrt(loss.item()))
+
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.lr_sched_step, self.hparams.lr_sched_gamma)
+
+         return [optimizer], [scheduler]
+
+
+ class DownStep(nn.Module):
+     """
+     Down scaling step in the encoder decoder network
+     """
+     def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, pool: tuple = None) -> None:
+         """Constructor
+
+         Arguments:
+             in_channels {int} -- Number of input channels for the 2D convolution
+             out_channels {int} -- Number of output channels for the 2D convolution
+             kernel_size {int} -- Convolution kernel size
+
+         Keyword Arguments:
+             stride {int} -- Stride of the convolution, set to 1 to disable (default: {1})
+             pool {tuple} -- max pooling size, set to None to disable (default: {None})
+         """
+         super(DownStep, self).__init__()
+
+         self.c = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
+         self.n = nn.BatchNorm2d(out_channels)
+         self.pool = pool
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Run the forward step
+
+         Arguments:
+             x {torch.Tensor} -- input tensor
+
+         Returns:
+             torch.Tensor -- output tensor
+         """
+         x = self.c(x)
+         x = F.relu(x)
+         if self.pool is not None:
+             x = F.max_pool2d(x, self.pool)
+         x = self.n(x)
+
+         return x
+
+
+ class UpStep(nn.Module):
+     """
+     Up scaling step in the encoder decoder network
+     """
+     def __init__(self, in_channels: int, out_channels: int, kernel_size: int, scale_factor: int = 2) -> None:
+         """Constructor
+
+         Arguments:
+             in_channels {int} -- Number of input channels for the 2D convolution
+             out_channels {int} -- Number of output channels for the 2D convolution
+             kernel_size {int} -- Convolution kernel size
+
+         Keyword Arguments:
+             scale_factor {int} -- Upsampling scaling factor (default: {2})
+         """
+         super(UpStep, self).__init__()
+
+         self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
+         self.n = nn.BatchNorm2d(out_channels)
+         self.scale_factor = scale_factor
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Run the forward step
+
+         Arguments:
+             x {torch.Tensor} -- input tensor
+
+         Returns:
+             torch.Tensor -- output tensor
+         """
+         if isinstance(x, tuple):
+             x = x[0]
+
+         if self.scale_factor != 1:
+             x = F.interpolate(x, scale_factor=self.scale_factor)
+
+         x = self.c(x)
+         x = F.relu(x)
+         x = self.n(x)
+
+         return x
+
+
+ class Compress(nn.Module):
+     """
+     Final convolution step, compressing the feature channels down to the
+     requested number of output channels
+     """
+     def __init__(self, in_channels: int, out_channels: int = 1, kernel_size: int = 1, scale_factor: int = 1) -> None:
+         """Constructor
+
+         Arguments:
+             in_channels {int} -- Number of input channels for the 2D convolution
+
+         Keyword Arguments:
+             out_channels {int} -- Number of output channels (default: {1})
+             kernel_size {int} -- Convolution kernel size (default: {1})
+             scale_factor {int} -- Optional output upsampling factor (default: {1})
+         """
+         super(Compress, self).__init__()
+
+         self.scale_factor = scale_factor
+
+         self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Run the forward step
+
+         Arguments:
+             x {torch.Tensor} -- input tensor (or (tensor, skip list) tuple)
+
+         Returns:
+             torch.Tensor -- output tensor
+         """
+         if isinstance(x, (tuple, list)):
+             x = x[0]
+
+         x = self.c(x)
+
+         if self.scale_factor != 1:
+             x = F.interpolate(x, scale_factor=self.scale_factor)
+
+         return x
+
+
+ class DownBlock(nn.Module):
+     def __init__(
+             self,
+             in_chan: int, inter_chan: int, out_chan: int,
+             kernel_size: int = 3, stride: int = 1, pool: tuple = None,
+             push: bool = False,
+             layers: int = 3):
+         super().__init__()
+
+         self.s = []
+         for i in range(layers):
+             self.s.append(deepcopy(DownStep(
+                 in_chan if i == 0 else inter_chan,
+                 inter_chan if i < layers - 1 else out_chan,
+                 kernel_size,
+                 1 if i < layers - 1 else stride,
+                 None if i < layers - 1 else pool)))
+         self.s = nn.Sequential(*self.s)
+
+         self.push = push
+
+     def forward(self, x: tuple) -> tuple:
+         i, s = x
+
+         i = self.s(i)
+
+         if self.push:
+             s.append(i)
+
+         return i, s
+
+
+ class UpBlock(nn.Module):
+     def __init__(
+             self,
+             in_chan: int, out_chan: int,
+             kernel_size: int, scale_factor: int = 2,
+             pop: bool = False,
+             layers: int = 3):
+         super().__init__()
+
+         self.s = []
+         for i in range(layers):
+             self.s.append(deepcopy(UpStep(
+                 in_chan if i == 0 else out_chan,
+                 out_chan,
+                 kernel_size,
+                 1 if i < layers - 1 else scale_factor)))
+         self.s = nn.Sequential(*self.s)
+
+         self.pop = pop
+
+     def forward(self, x: tuple) -> tuple:
+         i, s = x
+
+         if self.pop:
+             i = torch.cat((i, s.pop()), dim=1)
+
+         i = self.s(i)
+
+         return i, s
run_logger.py ADDED
@@ -0,0 +1,120 @@
+ """
+ Support log functions
+
+ TODO: log the model using mlflow.pytorch in parallel / in addition to checkpointing
+ """
+
+ import numpy as np
+ import h5py
+ import argparse
+
+ import torch
+ import torchvision.utils as vutils
+ import pytorch_lightning as pl
+
+
+ class ImgCB(pl.Callback):
+     """
+     Callback that logs label / output / error image grids to MLflow
+     """
+     def __init__(self, **kwargs):
+         super().__init__()
+
+         parser = ImgCB.add_argparse_args()
+         for action in parser._actions:
+             if action.dest in kwargs:
+                 action.default = kwargs[action.dest]
+         args = parser.parse_args([])
+         self.__dict__.update(vars(args))
+
+     @staticmethod
+     def add_argparse_args(parent_parser=None):
+         parser = argparse.ArgumentParser(
+             prog='ImgCB',
+             usage=ImgCB.__doc__,
+             parents=[parent_parser] if parent_parser is not None else [],
+             add_help=False)
+
+         parser.add_argument('--img_ranges', type=float, default=[1300, 1800], nargs='*', help='Scaling range for the output images, either a single pair or one pair per output layer')
+         parser.add_argument('--err_ranges', type=float, default=[0, 50], nargs='*', help='Scaling range for the error images, either a single pair or one pair per output layer')
+
+         return parser
+
+     def log_images(self, mfl_logger, y, z, prefix):
+         img_ranges = tuple(self.img_ranges)
+         err_ranges = tuple(self.err_ranges)
+
+         for i in range(y.shape[1]):
+             if y.shape[1] > 1:
+                 tag = f'_{i}_'
+
+                 if len(self.img_ranges) > 2:
+                     img_ranges = tuple(self.img_ranges[2 * i:2 * i + 2])
+                 if len(self.err_ranges) > 2:
+                     err_ranges = tuple(self.err_ranges[2 * i:2 * i + 2])
+             else:
+                 tag = ''
+
+             mfl_logger.experiment.log_image(
+                 mfl_logger.run_id,
+                 (np.array(vutils.make_grid(
+                     y[:, [i], ...].detach(),
+                     normalize=True, value_range=img_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.uint8),
+                 prefix + tag + '_labels.png')
+
+             mfl_logger.experiment.log_image(
+                 mfl_logger.run_id,
+                 (np.array(vutils.make_grid(
+                     z[:, [i], ...].detach(),
+                     normalize=True, value_range=img_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.uint8),
+                 prefix + tag + '_outputs.png')
+
+             mfl_logger.experiment.log_image(
+                 mfl_logger.run_id,
+                 (np.array(vutils.make_grid(
+                     torch.abs(y[:, [i], ...].detach() - z[:, [i], ...].detach()),
+                     normalize=True, value_range=err_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.uint8),
+                 prefix + tag + '_errors.png')
+
+     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+         if batch_idx == 0:
+             with torch.no_grad():
+                 x, y = batch
+
+                 if pl_module.hparams.rand_output_crop:
+                     x = x[..., :-pl_module.hparams.rand_output_crop, :]
+                     y = y[..., :-pl_module.hparams.rand_output_crop * 2, :]
+
+                 z = pl_module(x.to(pl_module.device))
+
+                 if isinstance(z, (tuple, list)):
+                     z = z[0]
+
+                 self.log_images(pl_module.logger, y.to(pl_module.device), z, 'train_')
+
+     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+         if batch_idx == 0:
+             with torch.no_grad():
+                 x, y = batch
+
+                 if pl_module.hparams.rand_output_crop:
+                     x = x[..., :-pl_module.hparams.rand_output_crop, :]
+                     y = y[..., :-pl_module.hparams.rand_output_crop * 2, :]
+
+                 z = pl_module(x.to(pl_module.device))
+
+                 if isinstance(z, (tuple, list)):
+                     z = z[0]
+
+                 self.log_images(pl_module.logger, y.to(pl_module.device), z, 'validate_')
+
+
+ class TestLogger(pl.Callback):
+     """
+     pytorch_lightning data saving logger for the testing output
+
+     Warning: this callback is not multi-GPU / multi-device safe -- only run it on a single GPU / device
+     """
+     def __init__(self, fname: str = 'output.h5'):
+         super().__init__()
+
+         self.fname = fname
+
+     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+         # Append one dataset per batch (plus the matching labels when present)
+         with h5py.File(self.fname, 'a') as f:
+             f[f'batch_{batch_idx:05}'] = outputs.to('cpu').numpy()
+             if len(batch) > 1:
+                 f[f'labels_{batch_idx:05}'] = batch[1].to('cpu').numpy()