File size: 13,913 Bytes
b9bac12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import os
import random
from collections import defaultdict

import dijkprofile_annotator.preprocessing as preprocessing
import dijkprofile_annotator.config as config
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.isotonic import IsotonicRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler


def extract_img(size, in_tensor):
    """Cut a centred window of ``size`` elements out of dimension 2 of a tensor.

    Args:
        size (int): number of elements to keep along dimension 2.
        in_tensor (tensor): tensor indexed as ``[:, :, i]``; only dimension 2
            is cropped.

    Returns:
        tensor: ``in_tensor[:, :, start:stop]`` where ``start`` is half of the
            surplus length and ``stop`` is ``size`` plus half of the surplus,
            both truncated to int.
    """
    # half of the surplus that has to be removed (may be fractional)
    margin = (in_tensor.size()[2] - size) / 2
    start = int(margin)
    stop = int(size + margin)
    return in_tensor[:, :, start:stop]


def ffill(arr):
    """Forward fill NaN values along axis 1 of a 2-D array.

    Each NaN is replaced by the closest non-NaN value to its left in the same
    row. NaNs at the start of a row have nothing to copy from and stay NaN.

    Args:
        arr (np.array): 2-D numpy array to fill

    Returns:
        np.array: new array with the same shape as ``arr``.
    """
    is_nan = np.isnan(arr)
    column_ids = np.arange(is_nan.shape[1])
    # valid cells point at their own column, NaN cells at column 0 for now
    source_col = np.where(is_nan, 0, column_ids)
    # the running maximum turns that into "last valid column seen so far"
    source_col = np.maximum.accumulate(source_col, axis=1)
    row_ids = np.arange(source_col.shape[0]).reshape(-1, 1)
    return arr[row_ids, source_col]

def train_scaler(profile_dict, scaler_type='minmax'):
    """Train a scaler given a profile dict.

    All height values of all profiles are flattened into one single feature
    column before fitting, so the scaler learns one global scaling that is
    shared by every profile.

    Args:
        profile_dict (dict): dict mapping a profile key to a dict whose
            'profile' entry holds the heights of that profile. All profiles
            must have the same length.
        scaler_type (str, optional): 'minmax' for a MinMaxScaler with range
            (-1, 1) or 'standard' for a StandardScaler. Defaults to 'minmax'.

    Raises:
        NotImplementedError: if scaler_type is not 'minmax' or 'standard'.
        ValueError: if profile_dict is empty.

    Returns:
        sklearn MinMaxScaler or StandardScaler: fitted scaler in sklearn format
    """
    if scaler_type == 'minmax':
        scaler = MinMaxScaler(feature_range=(-1, 1))  # for neural networks -1,1 is better than 0,1
    elif scaler_type == 'standard':
        scaler = StandardScaler()
    else:
        # report the requested type; `scaler` itself is not bound in this branch
        raise NotImplementedError(f"no scaler: {scaler_type}")

    if not profile_dict:
        raise ValueError("profile_dict is empty, cannot fit a scaler")

    profiles = [entry['profile'] for entry in profile_dict.values()]
    # the first profile determines the expected length of every row
    accumulator = np.zeros((len(profiles), profiles[0].shape[0]))
    for i, profile in enumerate(profiles):
        accumulator[i, :] = profile

    scaler.fit(accumulator.reshape(-1, 1))
    return scaler


def get_class_dict(class_list):
    """Look up the class mapping, its inverse and the class weights in config.

    The name is matched case-insensitively against 'regional', 'simple',
    'berm', 'sloot' and 'full'.

    Args:
        class_list (string): string representing the class mappings to use

    Raises:
        NotImplementedError: raised if no mapping exists for the given name

    Returns:
        tuple: (class mapping, inverse class mapping, class weights), taken
            from the matching CLASS_DICT_*, INVERSE_CLASS_DICT_* and
            WEIGHT_DICT_* entries of config.
    """
    # lambdas keep the config lookups lazy: only the requested set is touched
    lookups = {
        'regional': lambda: (config.CLASS_DICT_REGIONAL,
                             config.INVERSE_CLASS_DICT_REGIONAL,
                             config.WEIGHT_DICT_REGIONAL),
        'simple': lambda: (config.CLASS_DICT_SIMPLE,
                           config.INVERSE_CLASS_DICT_SIMPLE,
                           config.WEIGHT_DICT_SIMPLE),
        'berm': lambda: (config.CLASS_DICT_SIMPLE_BERM,
                         config.INVERSE_CLASS_DICT_SIMPLE_BERM,
                         config.WEIGHT_DICT_SIMPLE_BERM),
        'sloot': lambda: (config.CLASS_DICT_SIMPLE_SLOOT,
                          config.INVERSE_CLASS_DICT_SIMPLE_SLOOT,
                          config.WEIGHT_DICT_SIMPLE_SLOOT),
        'full': lambda: (config.CLASS_DICT_FULL,
                         config.INVERSE_CLASS_DICT_FULL,
                         config.WEIGHT_DICT_FULL),
    }

    class_list = class_list.lower()
    lookup = lookups.get(class_list)
    if lookup is None:
        raise NotImplementedError(f"No configs found for class list of type: {class_list}")

    class_dict, inverse_class_dict, class_weights = lookup()
    return class_dict, inverse_class_dict, class_weights


def force_sequential_predictions(predictions, method='isotonic'):
    """Force the classes in each sample to only go up from left to right.

    This makes sense because a higher class could never be left of a lower
    class in the representation chosen here. Two methods are available:
    isotonic regression ('isotonic') and a group-first method ('first').
    Isotonic regression is the recommended one.

    The input tensor is never modified; the result is a separate CPU tensor.

    Args:
        predictions (torch.Tensor): Tensor output of the model in shape (batch_size, channel_size, sample_size)
        method (str, optional): method to use for enforcing the sequentiality. Defaults to 'isotonic'.

    Raises:
        NotImplementedError: if the given method is not implemented

    Returns:
        torch.Tensor: Tensor in the same shape as the input, one-hot encoded
            along the channel dimension, with only increasing classes from
            left to right.
    """
    if method not in ('first', 'isotonic'):
        raise NotImplementedError(f"Unknown method: {method}")

    detached = predictions.detach()
    if detached.device.type == 'cpu':
        # detach() shares storage with the input and cpu() is a no-op here,
        # so copy explicitly to keep the in-place writes below off the input
        predictions = detached.clone()
    else:
        predictions = detached.cpu()
    n_classes = predictions.shape[1]  # 1 is the channel dimension

    # loop over batch
    for i in range(predictions.shape[0]):
        pred = torch.argmax(predictions[i], dim=0)

        if method == 'first':
            new_pred = _first_group_sequence(pred.tolist())
        else:
            x = np.arange(0, len(pred))
            iso_reg = IsotonicRegression().fit(x, pred)
            new_pred = np.round(iso_reg.predict(x))

        # encode back to one-hot tensor, channel first
        new_pred = torch.Tensor(new_pred).to(torch.int64)
        predictions[i] = F.one_hot(new_pred, num_classes=n_classes).permute(1, 0)

    return predictions


def _first_group_sequence(classes):
    """Make a list of class indices non-decreasing with the group-first rule.

    Classes are visited in ascending order; for each class the first run that
    starts at or after the end of the previously kept run is kept, all other
    runs are discarded and filled with the value to their left.
    """
    seq_len = len(classes)
    if seq_len == 0:
        return []

    # construct dict of (start, end) runs per class, end exclusive
    groups = defaultdict(list)
    group_start_idx = 0
    for i in range(1, seq_len):
        if classes[i] != classes[group_start_idx]:
            groups[classes[group_start_idx]].append((group_start_idx, i))
            group_start_idx = i
    # the run that reaches the end of the profile is never closed by a class
    # change, so it has to be recorded explicitly
    groups[classes[group_start_idx]].append((group_start_idx, seq_len))

    # if the class occurs again later in the profile
    # discard that occurrence of it
    new_classes = [0] * seq_len
    last_index = 0
    for class_n, runs in sorted(groups.items()):
        for start, end in runs:
            if start >= last_index:
                new_classes[start:end] = [class_n] * (end - start)
                last_index = end
                break

    # simple forward fill: 0 marks positions that were not assigned above
    for i in range(1, seq_len):
        if new_classes[i] == 0:
            new_classes[i] = new_classes[i - 1]

    return new_classes



def visualize_prediction(heights, prediction, labels, location_name, class_list):
    """visualize a profile plus labels and prediction

    For every class other than 0 a vertical line and a text label are drawn at
    the first point of each labelled segment (loose dashes) and at the first
    point of each predicted class (small dots).

    Args:
        heights (array-like): heights data of the profile; passed to
            np.min/np.max and plt.plot
        prediction (tensor): predicted data of the profile, either a list of
            class indices (1 dim), channel-first scores (2 dims) or a batch of
            exactly one sample of channel-first scores (3 dims)
        labels (array-like): class index for each height point in heights
        location_name (str): name of the profile, just for visualization
        class_list (str): class mapping to use, determines which labels are visualized

    Raises:
        ValueError: if prediction has 3 dims and holds more than one sample
    """
    # change one-hot batched format to list of classes; done before any
    # plotting so invalid input does not leave an empty figure behind
    if prediction.dim() == 3:
        if prediction.shape[0] != 1:
            raise ValueError(
                f"expected a batch with a single prediction, got {prediction.shape[0]} samples")
        prediction = prediction[0]
    if prediction.dim() == 2:
        # assuming channel first representation
        prediction = torch.argmax(prediction, dim=0)
    prediction = prediction.detach().cpu().numpy()

    class_dict, inverse_class_dict, _ = get_class_dict(class_list)
    plt.subplots(figsize=(20,11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # text labels start at the lowest point and move up one step per label
    label_height = np.min(heights)
    n_labels = len(np.unique(labels))
    label_height_distance = (np.max(heights) - np.min(heights))/(n_labels*2)

    cmap = sns.color_palette("Set2", len(set(class_dict.values())))

    # plot actual labels
    prev_class_n = 999
    for index, class_n in enumerate(labels):
        if class_n == 0:
            continue
        if class_n != prev_class_n:
            plt.axvline(index, 0,5, color=cmap[class_n], linestyle=(0,(5,10))) # loose dashes
            plt.text(index, label_height, inverse_class_dict[class_n], rotation=0)
            label_height += label_height_distance
            prev_class_n = class_n

    # plot predicted points
    # a class is marked as used once another class has followed it, so only
    # its first occurrence gets drawn
    used_classes = set()
    prev_class_n = 999
    for index, class_n in enumerate(prediction):
        if class_n == 0 or class_n in used_classes:
            continue
        if class_n != prev_class_n:
            plt.axvline(index, 0,5, color=cmap[class_n], linestyle=(0,(1,1))) # small dots
            plt.text(index, label_height, "predicted " + inverse_class_dict[class_n], rotation=0)
            label_height += label_height_distance
            used_classes.add(prev_class_n)
            prev_class_n = class_n

    plt.show()


def visualize_sample(heights, labels, location_name, class_list):
    """Plot a profile and mark where each labelled segment starts.

    Points with class 0 are never marked. For every other class a dashed
    vertical line and the class name are drawn at the first point of each
    consecutive run of that class.

    Args:
        heights (tensor): tensor containing the heights data of the profile
        labels (tensor): tensor containing the labels for each height point in heights
        location_name (str): name of the profile, just for visualization
        class_list (str): class mapping to use, determines which labels are visualized
    """
    class_dict, inverse_class_dict, _ = get_class_dict(class_list)
    plt.subplots(figsize=(20,11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # class names are stacked upwards from the lowest point of the profile
    text_y = np.min(heights)
    n_distinct = len(np.unique(labels))
    text_step = (np.max(heights) - np.min(heights))/(n_distinct*2)

    palette = sns.color_palette("Set2", len(set(class_dict.values())))

    # plot actual labels
    previous = 999
    for position, label in enumerate(labels):
        if label != 0 and label != previous:
            plt.axvline(position, 0,5, color=palette[label], linestyle=(0,(5,10))) # loose dashes
            plt.text(position, text_y, inverse_class_dict[label], rotation=0)
            text_y += text_step
            previous = label

    plt.show()
    
def visualize_files(linesfp, pointsfp, max_profile_size=512, class_list='simple', location_index=0, return_dict=False):
    """visualize profile lines and points filepaths.

    The files are converted to a profile dict with
    preprocessing.filepath_pair_to_labeled_sample and the selected profile is
    then drawn with visualize_dict.

    Args:
        linesfp (str): path to surfacelines file.
        pointsfp (str): path to points file.
        max_profile_size (int, optional): cutoff size of the profile, can leave on default here. Defaults to 512.
        class_list (str, optional): class mapping to use. Defaults to 'simple'.
        location_index (int, optional): index of profile to visualize. Defaults to 0.
        return_dict (bool, optional): return the profile dict for faster visualization. Defaults to False.

    Returns:
        [dict, optional]: profile dict containing the profiles of the given
            files if return_dict is True, otherwise None
    """
    profile_label_dict = preprocessing.filepath_pair_to_labeled_sample(
        linesfp,
        pointsfp,
        max_profile_size=max_profile_size,
        class_list=class_list,
    )

    # plotting a loaded dict is exactly what visualize_dict does
    visualize_dict(profile_label_dict, class_list=class_list, location_index=location_index)

    if return_dict:
        return profile_label_dict
    return None

def visualize_dict(profile_label_dict, class_list='simple', location_index=0):
    """Plot one profile of a profile dict together with its labels.

    The profile is picked by its position in the key order of the dict. Class
    0 is never marked; every other class gets a dashed vertical line and its
    name at the first point of each consecutive run of that class.

    Args:
        profile_label_dict (dict): dict containing profiles and labels, each
            entry holding a 'profile' and a 'label' item
        class_list (str, optional): class_mapping to use for visualization. Defaults to 'simple'.
        location_index (int, optional): specifies the index of the profile to visualize. Defaults to 0.
    """
    location_name = list(profile_label_dict.keys())[location_index]
    sample = profile_label_dict[location_name]
    heights = sample['profile']
    labels = sample['label']

    class_dict, inverse_class_dict, _ = get_class_dict(class_list)
    plt.subplots(figsize=(20,11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # class names are stacked upwards from the lowest point of the profile
    text_y = np.min(heights)
    n_distinct = len(np.unique(labels))
    text_step = (np.max(heights) - np.min(heights))/(n_distinct)

    colors = sns.color_palette("Set2", len(set(class_dict.values())))

    # plot actual labels
    last_marked = 999
    for x_pos, class_id in enumerate(labels):
        if class_id != 0 and class_id != last_marked:
            plt.axvline(x_pos, 0,5, color=colors[class_id], linestyle=(0,(5,10))) # loose dashes
            plt.text(x_pos, text_y, inverse_class_dict[class_id], rotation=0)
            text_y += text_step
            last_marked = class_id

    plt.show()