"""
Network definition file
"""
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

try:  # scipy >= 1.1 exposes window functions in scipy.signal.windows
    from scipy.signal.windows import gaussian
except ImportError:  # fall back for older scipy versions
    from scipy.signal import gaussian
class Net(LightningModule):
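    """
    Encoder-decoder (U-Net style) network for ultrasound inversion, implemented
    as a PyTorch Lightning module. Hyperparameters are declared in
    add_model_specific_args() and can be overridden through keyword arguments.
    """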
def __init__(self, **kwargs):
super().__init__()
parser = Net.add_model_specific_args()
for action in parser._actions:
if action.dest in kwargs:
action.default = kwargs[action.dest]
args = parser.parse_args([])
self.hparams.update(vars(args))
        if not hasattr(self, f"_init_{self.hparams.net_type}_net"):
            raise ValueError(f"Unknown net type {self.hparams.net_type}")
        # Look up the constructor with getattr rather than eval
        init_net = getattr(self, f"_init_{self.hparams.net_type}_net")
        self._net = init_net(n_inputs=self.hparams.n_inputs, n_outputs=self.hparams.n_outputs)
if self.hparams.bias is not None:
if hasattr(self.hparams.bias, "__iter__"):
for i in range(len(self.hparams.bias)):
self._net[-1].c.bias[i].data.fill_(self.hparams.bias[i])
else:
self._net[-1].c.bias.data.fill_(self.hparams.bias)
@staticmethod
def _init_tbme2_net(n_inputs: int = 1, n_outputs: int = 1):
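        """Encoder-decoder with three convolutions per block. DownBlocks with
        push=True store their outputs on a stack; UpBlocks with pop=True
        concatenate the matching entries back in as skip connections."""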
return nn.Sequential(
# Encoder
DownBlock(n_inputs, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=True, layers=3),
DownBlock(32, 32, 64, 3, stride=1, pool=[2, 2], push=True, layers=3),
DownBlock(64, 64, 128, 3, stride=1, pool=[2, 2], push=True, layers=3),
DownBlock(128, 128, 512, 3, stride=1, pool=[2, 2], push=False, layers=3),
# Decoder
UpBlock(512, 128, 3, scale_factor=2, pop=False, layers=3),
UpBlock(256, 64, 3, scale_factor=2, pop=True, layers=3),
UpBlock(128, 32, 3, scale_factor=2, pop=True, layers=3),
UpBlock(64, 32, 3, scale_factor=2, pop=True, layers=3),
UpStep(32, 32, 3, scale_factor=1),
Compress(32, n_outputs))
@staticmethod
def _init_embc_net(n_inputs: int = 1, n_outputs: int = 1):
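        """Single-convolution-per-block variant whose kernels shrink with depth
        (15 down to 3 in the encoder, 5 up to 11 in the decoder)."""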
return nn.Sequential(
# Encoder
DownBlock(n_inputs, 32, 32, 15, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 13, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 11, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 9, [1, 2], None, True, layers=1),
DownBlock(32, 32, 64, 7, 1, [2, 2], True, layers=1),
DownBlock(64, 64, 128, 5, 1, [2, 2], True, layers=1),
DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
# Decoder
UpBlock(512, 128, 5, 2, layers=1),
UpBlock(256, 64, 7, 2, True, layers=1),
UpBlock(128, 32, 9, 2, True, layers=1),
UpBlock(64, 32, 11, 2, True, layers=1),
UpStep(32, 32, 3, 1),
Compress(32, n_outputs))
@staticmethod
def _init_tbme_net(n_inputs: int = 1, n_outputs: int = 1):
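        """Single-convolution-per-block variant with fixed 3x3 kernels."""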
return nn.Sequential(
# Encoder
DownBlock(n_inputs, 32, 32, 3, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 3, [1, 2], None, True, layers=1),
DownBlock(32, 32, 64, 3, 1, [2, 2], True, layers=1),
DownBlock(64, 64, 128, 3, 1, [2, 2], True, layers=1),
DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
# Decoder
UpBlock(512, 128, 3, 2, layers=1),
UpBlock(256, 64, 3, 2, True, layers=1),
UpBlock(128, 32, 3, 2, True, layers=1),
UpBlock(64, 32, 3, 2, True, layers=1),
UpStep(32, 32, 3, 1),
Compress(32, n_outputs))
@staticmethod
def add_model_specific_args(parent_parser=None):
parser = argparse.ArgumentParser(
prog="Net",
usage=Net.__doc__,
parents=[parent_parser] if parent_parser is not None else [],
add_help=False)
parser.add_argument("--random_mirror", type=int, nargs="?", default=1, help="Randomly mirror data to increase diversity when using flat plate wave")
parser.add_argument("--noise_std", type=float, nargs="*", help="range of std of random noise to add to the input signal [0 val] or [min max]")
parser.add_argument("--quantization", type=float, nargs="?", help="Quantization noise")
parser.add_argument("--rand_drop", type=int, nargs="*", help="Random drop lines, between 0 and value lines if single value, or between two values")
parser.add_argument("--normalize_net", type=float, default=0.0, help="Coefficient for normalizing network weights")
parser.add_argument("--learning_rate", type=float, default=5e-3, help="Learning rate to use for optimizer")
parser.add_argument("--lr_sched_step", type=int, default=15, help="Learning decay, update step size")
parser.add_argument("--lr_sched_gamma", type=float, default=0.65, help="Learning decay gamma")
parser.add_argument("--net_type", default="tbme2", help="The network to use [tbme2/embc/tbme]")
parser.add_argument("--bias", type=float, nargs="*", help="Set bias on last layer, set to 1500 when training from scratch on SoS output")
parser.add_argument("--decimation", type=int, help="Subsample phase signal")
parser.add_argument("--phase_inv", type=int, default=0, help="Use phase for inversion")
parser.add_argument("--center_freq", type=float, default=5e6, help="Matched filter and IQ demodulation frequency")
parser.add_argument("--n_periods", type=float, default=5, help="Matched filter length")
parser.add_argument("--matched_filter", type=int, nargs="?", default=0, help="Apply matched filter, set to 1 to run during forward pass, 2 to run during preprocessing phase (before adding noise)")
parser.add_argument("--rand_output_crop", type=int, help="Subsample phase signal")
parser.add_argument("--rand_scale", type=float, nargs="*", help="Random scaling range [min max] -- (10 ** rand_scale)")
parser.add_argument("--rand_gain", type=float, nargs="*", help="Random gain coefficient range [min max] -- (10 ** rand_gain)")
parser.add_argument("--n_inputs", type=int, default=1, help="Number of input layers")
parser.add_argument("--n_outputs", type=int, default=1, help="Number of output layers")
parser.add_argument("--scale_losses", type=float, nargs="*", help="Scale each layer of the loss function by given value")
return parser
def forward(self, x) -> torch.Tensor:
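        """Optionally matched-filter, convert to IQ phase, and decimate the
        input, then run the encoder-decoder with an empty skip stack."""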
# Matched filter
if self.hparams.matched_filter == 1:
x = self._matched_filter(x)
# compute IQ phase if in phase_inv mode
if self.hparams.phase_inv:
x = self._phase(x)
        # Decimation (skipped when decimation is unset or 1)
        if self.hparams.decimation and self.hparams.decimation != 1:
            x = x[..., ::self.hparams.decimation]
# Apply network
x = self._net((x, []))
return x
def _matched_filter(self, x):
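        """Correlate the signal with a Gaussian-windowed sinusoid at
        center_freq (assuming a 40 MHz sampling rate) via conv1d."""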
sampling_freq = 40e6
samples_per_cycle = sampling_freq / self.hparams.center_freq
        n_samples = int(np.ceil(samples_per_cycle * self.hparams.n_periods + 1))
        window = torch.from_numpy(gaussian(n_samples, (n_samples - 1) / 6).astype(np.single)).to(x.device)
        signal = torch.sin(torch.arange(n_samples, device=x.device) / samples_per_cycle * 2 * np.pi) * window
        return F.conv1d(x.reshape(x.shape[:2] + (-1,)), signal.reshape(1, 1, -1), padding="same").reshape(x.shape)
def _phase(self, x):
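        """IQ demodulation in the frequency domain: keep a Gaussian-weighted
        band around center_freq, shift it to baseband, and return the
        instantaneous phase."""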
        f = self.hparams.center_freq
        fs = 40e6  # sampling frequency; renamed to avoid shadowing the F alias for torch.nn.functional
        N = x.shape[-1]
        n = int(round(f * N / fs))
X = torch.fft.fft(x, dim=-1)
X[..., (2 * n + 1):] = 0
X[..., :(2 * n + 1)] *= torch.from_numpy(gaussian(2 * n + 1, 2 * n / 6).astype(np.single)).to(x.device)
X = X.roll(-n, dims=-1)
x = torch.fft.ifft(X, dim=-1)
return x.angle()
def _preprocess(self, x):
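        """Augmentation chain: optional matched filter, additive Gaussian
        noise, random scaling, depth-dependent gain, quantization, and random
        line dropout."""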
# Matched filter
if self.hparams.matched_filter == 2:
x = self._matched_filter(x)
# Gaussian (normal) noise - random scaling, normalized to signal STD
if (ns := self.hparams.noise_std) and len(ns):
scl = ns[0] if len(ns) == 1 else torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (ns[-1] - ns[-2]) + ns[-2]
scl *= x.std()
x += torch.empty_like(x).normal_() * scl
# Random multiplicative scaling
if (rs := self.hparams.rand_scale) and len(rs):
x *= 10 ** (torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (rs[-1] - rs[-2]) + rs[-2])
        # Random exponential gain, increasing along the last axis (TGC-like)
        if (gs := self.hparams.rand_gain) and len(gs):
            # The exponent coefficient is drawn uniformly from [min, max],
            # with the parentheses matching the rand_scale pattern above
            gain = torch.FloatTensor([10.0]).to(x.device) ** \
                ((torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (gs[-1] - gs[-2]) + gs[-2]) *
                 torch.linspace(0, 1, x.shape[-1]).to(x.device).view(1, 1, 1, -1))
            x *= gain
        # Quantization noise, to emulate an ADC
if (quantization := self.hparams.quantization) is not None:
x = (x * quantization).round() * (1.0 / quantization)
# Randomly zero out some of the channels
if (rand_drop := self.hparams.rand_drop) and len(rand_drop):
if len(rand_drop) == 1:
rand_drop = [0, ] + rand_drop
for i in range(x.shape[0]):
lines = np.random.randint(0, x.shape[2], np.random.randint(rand_drop[0], rand_drop[1] + 1))
x[i, :, lines, :] = 0.
return x
def _log_losses(self, outputs: torch.Tensor, labels: torch.Tensor, prefix: str = ""):
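        """Log per-channel absolute-error metrics: RMSE, mean, and the mean
        over the short/medium/long thirds of the last (depth) axis."""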
diff = torch.abs(labels.detach() - outputs.detach())
s1 = int(diff.shape[-1] * (1.0 / 3.0))
s2 = int(diff.shape[-1] * (2.0 / 3.0))
for i in range(diff.shape[1]):
tag = f"{i}_" if diff.shape[1] > 1 else ""
losses = {
f"{prefix + tag}rmse": torch.sqrt(torch.mean(diff[:, i, ...] * diff[:, i, ...])).item(),
f"{prefix + tag}mean": torch.mean(diff[:, i, ...]).item(),
f"{prefix + tag}short": torch.mean(diff[:, i, :, :s1]).item(),
f"{prefix + tag}med": torch.mean(diff[:, i, :, s1:s2]).item(),
f"{prefix + tag}long": torch.mean(diff[:, i, :, s2:]).item()}
self.log_dict(losses, prog_bar=True)
def training_step(self, batch, batch_idx):
if self.hparams.random_mirror:
mirror = np.random.randint(0, 2, batch[0].shape[0])
for b in batch:
for i, m in enumerate(mirror):
if not m:
continue
                    b[i, ...] = b[i].flip(-2)  # torch.flip, since slicing with negative steps is unsupported
loss = self._common_step(batch, batch_idx, "train_")
if self.hparams.normalize_net:
for W in self.parameters():
loss += self.hparams.normalize_net * W.norm(2)
return loss
def validation_step(self, batch, batch_idx):
return self._common_step(batch, batch_idx, "validate_")
def test_step(self, batch, batch_idx):
return self._common_step(batch, batch_idx, "test_")
def predict_step(self, batch, batch_idx):
x = batch[0]
x = self._preprocess(x)
z = self(x)
if isinstance(z, tuple):
z = z[0]
return z
def _common_step(self, batch, batch_idx, prefix):
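        """Shared train/val/test logic: optional random output cropping,
        preprocessing, forward pass, metric logging, and (optionally
        channel-scaled) MSE loss."""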
x, y = batch
if self.hparams.rand_output_crop:
crop = np.random.randint(0, self.hparams.rand_output_crop, batch[0].shape[0])
for i, c in enumerate(crop):
if not c:
continue
x[i, :, :-c, :] = x[i, :, c:, :].clone()
y[i, :, :-c*2, :] = \
y[i, :, c*2-1:-1, :].clone() if np.random.randint(2) else \
y[i, :, c*2:, :].clone()
x = x[..., :-self.hparams.rand_output_crop, :]
y = y[..., :-self.hparams.rand_output_crop*2, :]
x = self._preprocess(x)
z = self(x)
        outputs = z[0] if isinstance(z, (tuple, list)) else z
self._log_losses(outputs, y, prefix)
        if self.hparams.scale_losses and len(self.hparams.scale_losses):
            s = torch.FloatTensor(self.hparams.scale_losses).to(y.device).view(1, -1, 1, 1)
            loss = F.mse_loss(s * outputs, s * y)  # use outputs, not the raw (possibly tuple) z
        else:
            loss = F.mse_loss(outputs, y)
        self.log(prefix + "loss", np.sqrt(loss.item()))
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.lr_sched_step, self.hparams.lr_sched_gamma)
return [optimizer], [scheduler]
class DownStep(nn.Module):
"""
    Downscaling step in the encoder-decoder network
"""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, pool: tuple = None) -> None:
"""Constructor
Arguments:
in_channels {int} -- Number of input channels for 2D convolution
out_channels {int} -- Number of output channels for 2D convolution
            kernel_size {int} -- Convolution kernel size
        Keyword Arguments:
            stride {int} -- Stride of the convolution, set to 1 to disable (default: {1})
            pool {tuple} -- Max pooling size, set to None to disable (default: {None})
"""
        super().__init__()
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
self.n = nn.BatchNorm2d(out_channels)
self.pool = pool
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the forward step
Arguments:
x {torch.tensor} -- input tensor
Returns:
torch.tensor -- output tensor
"""
x = self.c(x)
x = F.relu(x)
if self.pool is not None:
x = F.max_pool2d(x, self.pool)
x = self.n(x)
return x
class UpStep(nn.Module):
"""
    Upscaling step in the encoder-decoder network
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, scale_factor: int = 2) -> None:
"""Constructor
Arguments:
in_channels {int} -- Number of input channels for 2D convolution
out_channels {int} -- Number of output channels for 2D convolution
kernel_size {int} -- Convolution kernel size
Keyword Arguments:
scale_factor {int} -- Upsampling scaling factor (default: {2})
"""
        super().__init__()
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.n = nn.BatchNorm2d(out_channels)
self.scale_factor = scale_factor
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the forward step
Arguments:
x {torch.tensor} -- input tensor
Returns:
torch.tensor -- output tensor
"""
if isinstance(x, tuple):
x = x[0]
if self.scale_factor != 1:
x = F.interpolate(x, scale_factor=self.scale_factor)
x = self.c(x)
x = F.relu(x)
x = self.n(x)
return x
class Compress(nn.Module):
"""
Up scaling step in the encoder decoder network
"""
def __init__(self, in_channels: int, out_channels: int = 1, kernel_size: int = 1, scale_factor: int = 1) -> None:
"""Constructor
Arguments:
            in_channels {int} -- Number of input channels for the 2D convolution
        Keyword Arguments:
            out_channels {int} -- Number of output channels (default: {1})
            kernel_size {int} -- Convolution kernel size (default: {1})
            scale_factor {int} -- Output upsampling factor, 1 to disable (default: {1})
        """
        super().__init__()
self.scale_factor = scale_factor
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the forward step
Arguments:
x {torch.tensor} -- input tensor
Returns:
torch.tensor -- output tensor
"""
        if isinstance(x, (tuple, list)):
x = x[0]
x = self.c(x)
if self.scale_factor != 1:
x = F.interpolate(x, scale_factor=self.scale_factor)
return x
class DownBlock(nn.Module):
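    """
    A stack of DownStep layers in which only the last one strides or pools.
    Operates on a (tensor, skip-stack) pair; when push=True the block appends
    its output to the stack for a matching UpBlock to pop.
    """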
def __init__(
self,
in_chan: int, inter_chan: int, out_chan: int,
kernel_size: int = 3, stride: int = 1, pool: tuple = None,
push: bool = False,
layers: int = 3):
super().__init__()
        steps = []
        for i in range(layers):
            steps.append(DownStep(
                in_chan if i == 0 else inter_chan,
                inter_chan if i < layers - 1 else out_chan,
                kernel_size,
                1 if i < layers - 1 else stride,      # only the last layer strides
                None if i < layers - 1 else pool))    # only the last layer pools
        self.s = nn.Sequential(*steps)
self.push = push
    def forward(self, x: tuple) -> tuple:
i, s = x
i = self.s(i)
if self.push:
s.append(i)
return i, s
class UpBlock(nn.Module):
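    """
    A stack of UpStep layers in which only the last one upsamples. Operates on
    a (tensor, skip-stack) pair; when pop=True the block concatenates the top
    of the stack onto its input along the channel axis.
    """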
def __init__(
self,
in_chan: int, out_chan: int,
kernel_size: int, scale_factor: int = 2,
pop: bool = False,
layers: int = 3):
super().__init__()
        steps = []
        for i in range(layers):
            steps.append(UpStep(
                in_chan if i == 0 else out_chan,
                out_chan,
                kernel_size,
                1 if i < layers - 1 else scale_factor))  # only the last layer upsamples
        self.s = nn.Sequential(*steps)
self.pop = pop
    def forward(self, x: tuple) -> tuple:
i, s = x
if self.pop:
i = torch.cat((i, s.pop()), dim=1)
i = self.s(i)
return i, s
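
# Minimal smoke test -- a sketch, not part of the original training pipeline.
# It builds a Net with the default hyperparameters (the tbme2 variant) and runs
# a dummy forward pass; the input shape is an assumption chosen so the height is
# divisible by 8 and the width by 128, matching the down/upsampling factors.
if __name__ == "__main__":
    net = Net()
    x = torch.randn(2, net.hparams.n_inputs, 64, 1024)
    with torch.no_grad():
        y = net(x)
    # The decoder doubles the height and reduces the width by 8 overall
    print(tuple(x.shape), "->", tuple(y.shape))  # (2, 1, 64, 1024) -> (2, 1, 128, 128)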