"""
Alle transforms sind grundsätzlich auf batches bezogen!
Vae transforms sind invertierbar
"""
import pickle
from functools import reduce

import numpy as np
import torch

# General functions ------------------------------------------------------------------
# Transformations are easiest to implement in PyTorch.


def load(p):
    with open(p, "rb") as stream:
        return pickle.load(stream)


def save(obj, p):
    with open(p, "wb") as stream:
        pickle.dump(obj, stream)


def sequential_function(*functions):
    return lambda x: reduce(lambda res, func: func(res), functions, x)
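
# Usage sketch (illustrative): functions are applied left to right,
# so sequential_function(f, g)(x) == g(f(x)).
#   double_then_inc = sequential_function(lambda x: 2 * x, lambda x: x + 1)
#   double_then_inc(3)  # -> 7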


def np_sample(func):
    rtn = sequential_function(
        lambda x: torch.from_numpy(x).float(),
        lambda x: torch.unsqueeze(x, 0),
        func,
        lambda x: x[0].numpy(),
    )
    return rtn
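
# Usage sketch (illustrative): lift a batched torch transform to a single,
# unbatched numpy sample. `mms` is a hypothetical MinMaxScaler instance:
#   scale_one = np_sample(mms.forward)
#   scale_one(np.array([1.0, 2.0]))  # numpy in -> numpy out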


# Invertible transforms
class SequentialInversable(torch.nn.Sequential):
    def __init__(self, *functions):
        super().__init__(*functions)

        self.inv_funcs = [f.inv for f in functions]
        self.inv_funcs.reverse()

    # forward is inherited from torch.nn.Sequential and applies the modules in order.

    def inv(self, x):
        return sequential_function(*self.inv_funcs)(x)
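
# Usage sketch (illustrative): compose invertible transforms defined below.
#   pipe = SequentialInversable(LatentSorter(kl_dict), LatentSelector(8, 2))
#   z = pipe(latent)          # forward: sort, then select
#   latent_hat = pipe.inv(z)  # inverse: zero-pad, then unsort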


class LatentSelector(torch.nn.Module):
    """Verarbeitet Tensoren und numpy arrays"""

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim
        self.selectdim = selectdim

    def forward(self, x: torch.Tensor):
        return x[:, : self.selectdim]

    def inv(self, x: torch.Tensor):
        rtn = torch.cat(
            [x, torch.zeros((x.shape[0], self.ldim - x.shape[1]), device=x.device)],
            dim=1,
        )
        return rtn
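
# Usage sketch (illustrative): keep the first 2 of 8 latent dimensions;
# the inverse zero-pads the dropped dimensions back:
#   sel = LatentSelector(ldim=8, selectdim=2)
#   z = sel(torch.randn(4, 8))  # -> shape (4, 2)
#   sel.inv(z)                  # -> shape (4, 8), dims 2..7 zeroed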


class MinMaxScaler(torch.nn.Module):
    #! With multiple signals, be careful with broadcasting.
    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        self._min = _min
        self._max = _max
        self.min_norm = min_norm
        self.max_norm = max_norm

    def forward(self, ts):
        """None, no_signals"""
        std = (ts - self._min) / (self._max - self._min)
        rtn = std * (self.max_norm - self.min_norm) + self.min_norm
        return rtn

    def inv(self, ts):
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        rtn = std * (self._max - self._min) + self._min
        return rtn

    @classmethod
    def from_array(cls, arr: torch.Tensor):
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values

        return cls(_min, _max)
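
# Usage sketch (illustrative): per-signal scaling to [0, 1] with an exact
# round trip (up to float error):
#   data = torch.rand(100, 3)
#   mms = MinMaxScaler.from_array(data)
#   scaled = mms(data)   # each column mapped to [0, 1]
#   mms.inv(scaled)      # ~ data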


class LatentSorter(torch.nn.Module):
    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """
        unsorted -> sorted
        latent: (None, latent_dim)
        """
        return latent[:, list(self.kl_dict.keys())]

    def inv(self, latent):
        keys = np.array(list(self.kl_dict.keys()))
        return latent[:, torch.from_numpy(keys.argsort())]

    @property
    def names(self):
        rtn = ["{} KL{:.2f}".format(k, v) for k, v in self.kl_dict.items()]
        return rtn
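
# Usage sketch (illustrative): kl_dict maps original latent index -> KL
# divergence, typically in descending KL order:
#   sorter = LatentSorter({2: 1.50, 0: 0.70, 1: 0.01})
#   sorter(latent)   # columns reordered to (2, 0, 1)
#   sorter.inv(...)  # restores the original column order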


def apply_along_axis(function, x, axis: int = 0):
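    """Torch analogue of np.apply_along_axis: apply `function` to each slice of `x` along `axis` and restack."""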
    return torch.stack([function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis)


# Input shapes are preserved!
class SumField(torch.nn.Module):
    """
    time series: [idx, time_step, signal]
    image: [idx, signal, time_step, time_step]
    """

    def forward(self, ts: torch.Tensor):
        """ts2img"""

        samples = ts.shape[0]
        time = ts.shape[1]
        channels = ts.shape[2]

        ts = torch.swapaxes(ts, 1, 2)  # move the time axis to the end
        ts = torch.reshape(
            ts, (samples * channels, time)
        )  # merge sample and channel dimensions
        #! TODO: replace the Python loop with a vectorized implementation
        rtn = apply_along_axis(self._mtf_forward, ts, 0)
        rtn = torch.reshape(rtn, (samples, channels, time, time))

        return rtn

    def inv(self, img: torch.Tensor):
        """img2ts"""
        rtn = torch.diagonal(img, dim1=2, dim2=3)
        rtn = torch.swapaxes(rtn, 1, 2)  # swap channel and time axes

        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """For one dimensional time series ts"""
        return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2
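

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module):
    # round-trip a random batch through SumField. The diagonal of the
    # pairwise-mean image recovers the series, since (ts_i + ts_i) / 2 == ts_i.
    ts = torch.rand(4, 16, 3)  # [idx, time_step, signal]
    field = SumField()
    img = field(ts)  # -> (4, 3, 16, 16)
    assert torch.allclose(field.inv(img), ts)
    print("SumField round trip OK:", tuple(img.shape))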