""" Network definition file """ import torch import torch.nn as nn import torch.nn.functional as F from torchaudio.functional import lfilter from pytorch_lightning import LightningModule import numpy as np from scipy.signal import butter, gaussian from copy import deepcopy import argparse class Net(LightningModule): def __init__(self, **kwargs): super().__init__() parser = Net.add_model_specific_args() for action in parser._actions: if action.dest in kwargs: action.default = kwargs[action.dest] args = parser.parse_args([]) self.hparams.update(vars(args)) if not hasattr(self, f"_init_{self.hparams.net_type}_net"): raise ValueError(f"Unknown net type {self.hparams.net_type}") self._net = eval(f"self._init_{self.hparams.net_type}_net(n_inputs={self.hparams.n_inputs}, n_outputs={self.hparams.n_outputs})") if self.hparams.bias is not None: if hasattr(self.hparams.bias, "__iter__"): for i in range(len(self.hparams.bias)): self._net[-1].c.bias[i].data.fill_(self.hparams.bias[i]) else: self._net[-1].c.bias.data.fill_(self.hparams.bias) @staticmethod def _init_tbme2_net(n_inputs: int = 1, n_outputs: int = 1): return nn.Sequential( # Encoder DownBlock(n_inputs, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3), DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3), DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3), DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=True, layers=3), DownBlock(32, 32, 64, 3, stride=1, pool=[2, 2], push=True, layers=3), DownBlock(64, 64, 128, 3, stride=1, pool=[2, 2], push=True, layers=3), DownBlock(128, 128, 512, 3, stride=1, pool=[2, 2], push=False, layers=3), # Decoder UpBlock(512, 128, 3, scale_factor=2, pop=False, layers=3), UpBlock(256, 64, 3, scale_factor=2, pop=True, layers=3), UpBlock(128, 32, 3, scale_factor=2, pop=True, layers=3), UpBlock(64, 32, 3, scale_factor=2, pop=True, layers=3), UpStep(32, 32, 3, scale_factor=1), Compress(32, n_outputs)) @staticmethod def _init_embc_net(n_inputs: int = 1, n_outputs: int = 1): return nn.Sequential( # Encoder DownBlock(n_inputs, 32, 32, 15, [1, 2], None, layers=1), DownBlock(32, 32, 32, 13, [1, 2], None, layers=1), DownBlock(32, 32, 32, 11, [1, 2], None, layers=1), DownBlock(32, 32, 32, 9, [1, 2], None, True, layers=1), DownBlock(32, 32, 64, 7, 1, [2, 2], True, layers=1), DownBlock(64, 64, 128, 5, 1, [2, 2], True, layers=1), DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1), # Decoder UpBlock(512, 128, 5, 2, layers=1), UpBlock(256, 64, 7, 2, True, layers=1), UpBlock(128, 32, 9, 2, True, layers=1), UpBlock(64, 32, 11, 2, True, layers=1), UpStep(32, 32, 3, 1), Compress(32, n_outputs)) @staticmethod def _init_tbme_net(n_inputs: int = 1, n_outputs: int = 1): return nn.Sequential( # Encoder DownBlock(n_inputs, 32, 32, 3, [1, 2], None, layers=1), DownBlock(32, 32, 32, 3, [1, 2], None, layers=1), DownBlock(32, 32, 32, 3, [1, 2], None, layers=1), DownBlock(32, 32, 32, 3, [1, 2], None, True, layers=1), DownBlock(32, 32, 64, 3, 1, [2, 2], True, layers=1), DownBlock(64, 64, 128, 3, 1, [2, 2], True, layers=1), DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1), # Decoder UpBlock(512, 128, 3, 2, layers=1), UpBlock(256, 64, 3, 2, True, layers=1), UpBlock(128, 32, 3, 2, True, layers=1), UpBlock(64, 32, 3, 2, True, layers=1), UpStep(32, 32, 3, 1), Compress(32, n_outputs)) @staticmethod def add_model_specific_args(parent_parser=None): parser = argparse.ArgumentParser( prog="Net", usage=Net.__doc__, parents=[parent_parser] if parent_parser is not None else [], add_help=False) parser.add_argument("--random_mirror", type=int, nargs="?", default=1, help="Randomly mirror data to increase diversity when using flat plate wave") parser.add_argument("--noise_std", type=float, nargs="*", help="range of std of random noise to add to the input signal [0 val] or [min max]") parser.add_argument("--quantization", type=float, nargs="?", help="Quantization noise") parser.add_argument("--rand_drop", type=int, nargs="*", help="Random drop lines, between 0 and value lines if single value, or between two values") parser.add_argument("--normalize_net", type=float, default=0.0, help="Coefficient for normalizing network weights") parser.add_argument("--learning_rate", type=float, default=5e-3, help="Learning rate to use for optimizer") parser.add_argument("--lr_sched_step", type=int, default=15, help="Learning decay, update step size") parser.add_argument("--lr_sched_gamma", type=float, default=0.65, help="Learning decay gamma") parser.add_argument("--net_type", default="tbme2", help="The network to use [tbme2/embc/tbme]") parser.add_argument("--bias", type=float, nargs="*", help="Set bias on last layer, set to 1500 when training from scratch on SoS output") parser.add_argument("--decimation", type=int, help="Subsample phase signal") parser.add_argument("--phase_inv", type=int, default=0, help="Use phase for inversion") parser.add_argument("--center_freq", type=float, default=5e6, help="Matched filter and IQ demodulation frequency") parser.add_argument("--n_periods", type=float, default=5, help="Matched filter length") parser.add_argument("--matched_filter", type=int, nargs="?", default=0, help="Apply matched filter, set to 1 to run during forward pass, 2 to run during preprocessing phase (before adding noise)") parser.add_argument("--rand_output_crop", type=int, help="Subsample phase signal") parser.add_argument("--rand_scale", type=float, nargs="*", help="Random scaling range [min max] -- (10 ** rand_scale)") parser.add_argument("--rand_gain", type=float, nargs="*", help="Random gain coefficient range [min max] -- (10 ** rand_gain)") parser.add_argument("--n_inputs", type=int, default=1, help="Number of input layers") parser.add_argument("--n_outputs", type=int, default=1, help="Number of output layers") parser.add_argument("--scale_losses", type=float, nargs="*", help="Scale each layer of the loss function by given value") return parser def forward(self, x) -> torch.Tensor: # Matched filter if self.hparams.matched_filter == 1: x = self._matched_filter(x) # compute IQ phase if in phase_inv mode if self.hparams.phase_inv: x = self._phase(x) # Decimation if self.hparams.decimation != 1: x = x[..., ::self.hparams.decimation] # Apply network x = self._net((x, [])) return x def _matched_filter(self, x): sampling_freq = 40e6 samples_per_cycle = sampling_freq / self.hparams.center_freq n_samples = np.ceil(samples_per_cycle * self.hparams.n_periods + 1) signal = torch.sin(torch.arange(n_samples, device=x.device) / samples_per_cycle * 2 * np.pi) * torch.from_numpy(gaussian(n_samples, (n_samples - 1) / 6).astype(np.single)).to(x.device) return torch.nn.functional.conv1d(x.reshape(x.shape[:2] + (-1,)), signal.reshape(1, 1, -1), padding="same").reshape(x.shape) def _phase(self, x): f = self.hparams.center_freq F = 40e6 N = x.shape[-1] n = int(round(f * N / F)) X = torch.fft.fft(x, dim=-1) X[..., (2 * n + 1):] = 0 X[..., :(2 * n + 1)] *= torch.from_numpy(gaussian(2 * n + 1, 2 * n / 6).astype(np.single)).to(x.device) X = X.roll(-n, dims=-1) x = torch.fft.ifft(X, dim=-1) return x.angle() def _preprocess(self, x): # Matched filter if self.hparams.matched_filter == 2: x = self._matched_filter(x) # Gaussian (normal) noise - random scaling, normalized to signal STD if (ns := self.hparams.noise_std) and len(ns): scl = ns[0] if len(ns) == 1 else torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (ns[-1] - ns[-2]) + ns[-2] scl *= x.std() x += torch.empty_like(x).normal_() * scl # Random multiplicative scaling if (rs := self.hparams.rand_scale) and len(rs): x *= 10 ** (torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (rs[-1] - rs[-2]) + rs[-2]) # Random exponential gain if (gs := self.hparams.rand_gain) and len(gs): gain = torch.FloatTensor([10.0]).to(x.device) ** \ (torch.rand([x.shape[0]] + [1] * 3).to(x.device) * ((gs[-1] - gs[-2]) + gs[-2]) * torch.linspace(0, 1, x.shape[-1]).to(x.device).view(1, 1, 1, -1)) x *= gain # Quantization noise, to emulated ADC if (quantization := self.hparams.quantization) is not None: x = (x * quantization).round() * (1.0 / quantization) # Randomly zero out some of the channels if (rand_drop := self.hparams.rand_drop) and len(rand_drop): if len(rand_drop) == 1: rand_drop = [0, ] + rand_drop for i in range(x.shape[0]): lines = np.random.randint(0, x.shape[2], np.random.randint(rand_drop[0], rand_drop[1] + 1)) x[i, :, lines, :] = 0. return x def _log_losses(self, outputs: torch.Tensor, labels: torch.Tensor, prefix: str = ""): diff = torch.abs(labels.detach() - outputs.detach()) s1 = int(diff.shape[-1] * (1.0 / 3.0)) s2 = int(diff.shape[-1] * (2.0 / 3.0)) for i in range(diff.shape[1]): tag = f"{i}_" if diff.shape[1] > 1 else "" losses = { f"{prefix + tag}rmse": torch.sqrt(torch.mean(diff[:, i, ...] * diff[:, i, ...])).item(), f"{prefix + tag}mean": torch.mean(diff[:, i, ...]).item(), f"{prefix + tag}short": torch.mean(diff[:, i, :, :s1]).item(), f"{prefix + tag}med": torch.mean(diff[:, i, :, s1:s2]).item(), f"{prefix + tag}long": torch.mean(diff[:, i, :, s2:]).item()} self.log_dict(losses, prog_bar=True) def training_step(self, batch, batch_idx): if self.hparams.random_mirror: mirror = np.random.randint(0, 2, batch[0].shape[0]) for b in batch: for i, m in enumerate(mirror): if not m: continue b[i, ...] = b[i, :, range(b.shape[-2] - 1, -1, -1), :] # Pytorch does not handle negative steps loss = self._common_step(batch, batch_idx, "train_") if self.hparams.normalize_net: for W in self.parameters(): loss += self.hparams.normalize_net * W.norm(2) return loss def validation_step(self, batch, batch_idx): return self._common_step(batch, batch_idx, "validate_") def test_step(self, batch, batch_idx): return self._common_step(batch, batch_idx, "test_") def predict_step(self, batch, batch_idx): x = batch[0] x = self._preprocess(x) z = self(x) if isinstance(z, tuple): z = z[0] return z def _common_step(self, batch, batch_idx, prefix): x, y = batch if self.hparams.rand_output_crop: crop = np.random.randint(0, self.hparams.rand_output_crop, batch[0].shape[0]) for i, c in enumerate(crop): if not c: continue x[i, :, :-c, :] = x[i, :, c:, :].clone() y[i, :, :-c*2, :] = \ y[i, :, c*2-1:-1, :].clone() if np.random.randint(2) else \ y[i, :, c*2:, :].clone() x = x[..., :-self.hparams.rand_output_crop, :] y = y[..., :-self.hparams.rand_output_crop*2, :] x = self._preprocess(x) z = self(x) outputs = z[0] if isinstance(z, tuple) or isinstance(z, list) else z self._log_losses(outputs, y, prefix) if (self.hparams.scale_losses) and len(self.hparams.scale_losses): s = torch.FloatTensor(self.hparams.scale_losses).to(y.device).view(1, -1, 1, 1) loss = F.mse_loss(s * z, s * y) else: loss = F.mse_loss(y, outputs) self.log(prefix + "loss", np.sqrt(loss.item())) return loss def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.lr_sched_step, self.hparams.lr_sched_gamma) return [optimizer], [scheduler] class DownStep(nn.Module): """ Down scaling step in the encoder decoder network """ def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: int = 1, pool: tuple = None) -> None: """Constructor Arguments: in_channels {int} -- Number of input channels for 2D convolution out_channels {int} -- Number of output channels for 2D convolution kernel_size {tuple} -- Convolution kernel size Keyword Arguments: stride {int} -- Stride of convolution, set to 1 to disable (default: {1}) pool {tuple} -- max pulling size, set to None to disable (default: {None}) """ super(DownStep, self).__init__() self.c = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2) self.n = nn.BatchNorm2d(out_channels) self.pool = pool def forward(self, x: torch.tensor) -> torch.tensor: """Run the forward step Arguments: x {torch.tensor} -- input tensor Returns: torch.tensor -- output tensor """ x = self.c(x) x = F.relu(x) if self.pool is not None: x = F.max_pool2d(x, self.pool) x = self.n(x) return x class UpStep(nn.Module): """ Up scaling step in the encoder decoder network """ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, scale_factor: int = 2) -> None: """Constructor Arguments: in_channels {int} -- Number of input channels for 2D convolution out_channels {int} -- Number of output channels for 2D convolution kernel_size {int} -- Convolution kernel size Keyword Arguments: scale_factor {int} -- Upsampling scaling factor (default: {2}) """ super(UpStep, self).__init__() self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2) self.n = nn.BatchNorm2d(out_channels) self.scale_factor = scale_factor def forward(self, x: torch.tensor) -> torch.tensor: """Run the forward step Arguments: x {torch.tensor} -- input tensor Returns: torch.tensor -- output tensor """ if isinstance(x, tuple): x = x[0] if self.scale_factor != 1: x = F.interpolate(x, scale_factor=self.scale_factor) x = self.c(x) x = F.relu(x) x = self.n(x) return x class Compress(nn.Module): """ Up scaling step in the encoder decoder network """ def __init__(self, in_channels: int, out_channels: int = 1, kernel_size: int = 1, scale_factor: int = 1) -> None: """Constructor Arguments: in_channels {int} -- [description] Keyword Arguments: out_channels {int} -- [description] (default: {1}) kernel_size {int} -- [description] (default: {1}) """ super(Compress, self).__init__() self.scale_factor = scale_factor self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2) def forward(self, x: torch.tensor) -> torch.tensor: """Run the forward step Arguments: x {torch.tensor} -- input tensor Returns: torch.tensor -- output tensor """ if isinstance(x, tuple) or isinstance(x, list): x = x[0] x = self.c(x) if self.scale_factor != 1: x = F.interpolate(x, scale_factor=self.scale_factor) return x class DownBlock(nn.Module): def __init__( self, in_chan: int, inter_chan: int, out_chan: int, kernel_size: int = 3, stride: int = 1, pool: tuple = None, push: bool = False, layers: int = 3): super().__init__() self.s = [] for i in range(layers): self.s.append(deepcopy(DownStep( in_chan if i == 0 else inter_chan, inter_chan if i < layers - 1 else out_chan, kernel_size, 1 if i < layers - 1 else stride, None if i < layers - 1 else pool))) self.s = nn.Sequential(*self.s) self.push = push def forward(self, x: torch.tensor) -> torch.tensor: i, s = x i = self.s(i) if self.push: s.append(i) return i, s class UpBlock(nn.Module): def __init__( self, in_chan: int, out_chan: int, kernel_size: int, scale_factor: int = 2, pop: bool = False, layers: int = 3): super().__init__() self.s = [] for i in range(layers): self.s.append(deepcopy(UpStep( in_chan if i == 0 else out_chan, out_chan, kernel_size, 1 if i < layers - 1 else scale_factor))) self.s = nn.Sequential(*self.s) self.pop = pop def forward(self, x: torch.tensor) -> torch.tensor: i, s = x if self.pop: i = torch.cat((i, s.pop()), dim=1) i = self.s(i) return i, s