"""Batch-oriented transforms.

All transforms operate on batches: the first axis is the sample index.
The VAE transforms are invertible: each exposes an ``inv`` method that undoes
its ``forward``.
"""

import pickle
from dataclasses import dataclass
from functools import partial, reduce, wraps

import numpy as np
import torch

# General helpers --------------------------------------------------------------
# Transformations are easiest to express in PyTorch.


def load(p):
    """Unpickle and return the object stored at path ``p``.

    Security: ``pickle.load`` can execute arbitrary code, so only call this on
    files from a trusted source.
    """
    with open(p, "rb") as stream:
        return pickle.load(stream)


def save(obj, p):
    """Pickle ``obj`` to path ``p``, overwriting any existing file."""
    with open(p, "wb") as stream:
        pickle.dump(obj, stream)


def sequential_function(*functions):
    """Compose ``functions`` left to right: ``fn(...f2(f1(x)))``.

    With no functions the returned callable is the identity.
    """
    return lambda x: reduce(lambda res, func: func(res), functions, x)


def np_sample(func):
    """Wrap the batch transform ``func`` so it takes and returns one numpy sample.

    The sample is converted to a float32 tensor, given a leading batch axis of
    size 1, passed through ``func``, and the single result is returned as a
    numpy array. The output is detached and moved to the CPU first, because
    ``Tensor.numpy()`` refuses tensors that require grad or live on a GPU.
    """
    return sequential_function(
        lambda x: torch.from_numpy(x).float(),
        lambda x: torch.unsqueeze(x, 0),
        func,
        lambda x: x[0].detach().cpu().numpy(),
    )


# Invertible transforms --------------------------------------------------------


class SequentialInversable(torch.nn.Sequential):
    """``torch.nn.Sequential`` whose ``inv`` undoes the whole chain.

    Every module passed in must provide an ``inv`` method. ``inv`` applies
    them in reverse order, so ``inv(forward(x))`` round-trips whenever each
    module's ``inv`` inverts its own ``forward``.
    """

    def __init__(self, *functions):
        super().__init__(*functions)
        # Bound ``inv`` methods, last module first.
        self.inv_funcs = [f.inv for f in reversed(functions)]

    def inv(self, x):
        """Apply the modules' inverses in reverse order."""
        return sequential_function(*self.inv_funcs)(x)


class LatentSelector(torch.nn.Module):
    """Keep the first ``selectdim`` of ``ldim`` latent dimensions.

    Works on torch tensors and numpy arrays of shape (batch, dims).
    ``inv`` pads the dropped dimensions back with zeros.
    """

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim
        self.selectdim = selectdim

    def forward(self, x: "torch.Tensor | np.ndarray"):
        """(batch, ldim) -> (batch, selectdim)."""
        return x[:, : self.selectdim]

    def inv(self, x: "torch.Tensor | np.ndarray"):
        """(batch, k) -> (batch, ldim), zero-filling the missing columns.

        The padding uses ``x``'s dtype (and device, for tensors), so the
        result keeps the input's precision.
        """
        missing = self.ldim - x.shape[1]
        if isinstance(x, np.ndarray):
            pad = np.zeros((x.shape[0], missing), dtype=x.dtype)
            return np.concatenate([x, pad], axis=1)
        pad = x.new_zeros((x.shape[0], missing))
        return torch.cat([x, pad], dim=1)


class MinMaxScaler(torch.nn.Module):
    """Affine map of ``[_min, _max]`` onto ``[min_norm, max_norm]``.

    Be careful with broadcasting when scaling several signals: ``_min`` and
    ``_max`` are broadcast against the input as-is.

    A feature whose range is zero (``_max == _min``) uses a range of 1 instead
    of dividing by zero. Inputs equal to that constant therefore map to
    ``min_norm`` instead of NaN, and ``inv`` maps them back to ``_min``. This
    is the same convention as scikit-learn's ``MinMaxScaler``.

    NOTE(review): ``_min``/``_max`` are stored as plain attributes, not as
    registered buffers, so ``Module.to(device)`` does not move them.
    """

    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        self._min = _min
        self._max = _max
        self.min_norm = min_norm
        self.max_norm = max_norm

    @staticmethod
    def _safe_range(lo, hi):
        """Return ``hi - lo`` with zero entries replaced by 1 (avoids 0/0)."""
        rng = hi - lo
        if isinstance(rng, torch.Tensor):
            return torch.where(rng == 0, torch.ones_like(rng), rng)
        if isinstance(rng, np.ndarray):
            return np.where(rng == 0, np.ones_like(rng), rng)
        return rng if rng != 0 else 1

    def forward(self, ts):
        """Scale ``ts`` into ``[min_norm, max_norm]``.

        Documented input shape: (None, no_signals), i.e. (batch, n_signals).
        """
        std = (ts - self._min) / self._safe_range(self._min, self._max)
        return std * (self.max_norm - self.min_norm) + self.min_norm

    def inv(self, ts):
        """Map values from ``[min_norm, max_norm]`` back to ``[_min, _max]``."""
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        return std * self._safe_range(self._min, self._max) + self._min

    @classmethod
    def from_array(cls, arr: torch.Tensor):
        """Build a scaler from the per-column min/max of ``arr`` along axis 0."""
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values
        return cls(_min, _max)


class LatentSorter(torch.nn.Module):
    """Reorder latent columns into the key order of ``kl_dict``.

    The keys of ``kl_dict`` are latent column indices, and its values are
    shown as KL values in ``names``. ``forward`` moves the columns into the
    dict's iteration order, and ``inv`` restores the original order. ``inv``
    is an exact inverse only when the keys form a permutation of all columns.
    """

    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """unsorted -> sorted; latent: (None, latent_dim)."""
        return latent[:, list(self.kl_dict.keys())]

    def inv(self, latent):
        """sorted -> unsorted (the argsort of a permutation is its inverse)."""
        order = np.argsort(np.array(list(self.kl_dict.keys())))
        return latent[:, order.tolist()]

    @property
    def names(self):
        """Labels of the sorted columns, e.g. ``"3 KL0.42"``."""
        return ["{} KL{:.2f}".format(k, v) for k, v in self.kl_dict.items()]


def apply_along_axis(function, x, axis: int = 0):
    """Apply ``function`` to each slice of ``x`` along ``axis`` and restack.

    ``function`` must return equally shaped tensors. ``torch.stack`` rejects
    an empty list, so an empty ``axis`` raises.
    """
    return torch.stack([function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis)


# Input shapes are left as they are.
class SumField(torch.nn.Module):
    """Pairwise-mean image of a time series, invertible via the diagonal.

    time series: [idx, time_step, signal]
    image:       [idx, signal, time_step, time_step]

    For every sample and signal, ``image[i, j] = (ts[i] + ts[j]) / 2``. The
    diagonal therefore equals the series itself, which is what ``inv`` reads
    back.
    """

    def forward(self, ts: torch.Tensor):
        """ts2img: (samples, time, signals) -> (samples, signals, time, time)."""
        if ts.dim() != 3:
            raise ValueError(
                "SumField expects a (samples, time, signals) tensor, "
                "got shape {}".format(tuple(ts.shape))
            )
        series = torch.swapaxes(ts, 1, 2)  # time axis last: (samples, signals, time)
        # One broadcasted op builds all fields at once. This avoids a Python
        # loop per (sample, signal) row and also works for an empty batch.
        return self._mtf_forward(series)

    def inv(self, img: torch.Tensor):
        """img2ts: (samples, signals, time, time) -> (samples, time, signals)."""
        rtn = torch.diagonal(img, dim1=2, dim2=3)
        rtn = torch.swapaxes(rtn, 1, 2)  # swap signal and time axes back
        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """Return the field ``out[..., i, j] = (ts[..., i] + ts[..., j]) / 2``.

        Accepts a 1-D series or any batch of series along leading axes.
        """
        return (ts.unsqueeze(-1) + ts.unsqueeze(-2)) / 2