# Provenance (captured from the web file viewer this source was copied from):
# author jgerbscheid, commit b9bac12 ("initial commit"), file size 13.9 kB.
import os
import random
from collections import defaultdict
import dijkprofile_annotator.preprocessing as preprocessing
import dijkprofile_annotator.config as config
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.isotonic import IsotonicRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
def extract_img(size, in_tensor):
    """Center-crop ``in_tensor`` along its third axis (dimension 2).

    The tensor is indexed as ``[:, :, start:stop]``, so it is expected to have
    at least three dimensions; the window is centred, with any odd remainder
    truncated towards zero by ``int``.

    Args:
        size (int): number of elements to keep along dimension 2.
        in_tensor (tensor): tensor to be cropped.

    Returns:
        tensor: the cropped view of ``in_tensor``.
    """
    length = in_tensor.size()[2]
    margin = (length - size) / 2
    return in_tensor[:, :, int(margin):int(size + margin)]
def ffill(arr):
    """Forward-fill NaNs along each row of a 2-D array.

    Every NaN is replaced by the closest non-NaN value to its left in the same
    row. NaNs before the first valid value of a row are left as NaN.

    Args:
        arr (np.array): 2-D float array to fill.

    Returns:
        np.array: a new, filled array (``arr`` itself is not modified).
    """
    nan_mask = np.isnan(arr)
    column_ids = np.arange(nan_mask.shape[1])
    # column index of the most recent valid value seen so far in each row
    source_cols = np.maximum.accumulate(np.where(nan_mask, 0, column_ids), axis=1)
    row_ids = np.arange(source_cols.shape[0]).reshape(-1, 1)
    return arr[row_ids, source_cols]
def train_scaler(profile_dict, scaler_type='minmax'):
    """Fit a scaler on the height values of all profiles in ``profile_dict``.

    All height values of all profiles are pooled into a single column, so the
    fitted scaler holds one set of statistics shared by every height point.

    Args:
        profile_dict (dict): maps a profile key to a dict that has at least a
            'profile' entry with that profile's heights.
        scaler_type (str, optional): 'minmax' (MinMaxScaler with range -1..1)
            or 'standard' (StandardScaler). Defaults to 'minmax'.

    Raises:
        NotImplementedError: if ``scaler_type`` is not a known scaler.
        ValueError: if ``profile_dict`` is empty.

    Returns:
        sklearn MinMaxScaler or StandardScaler: fitted scaler in sklearn format
    """
    if scaler_type == 'minmax':
        scaler = MinMaxScaler(feature_range=(-1, 1))  # for neural networks -1,1 is better than 0,1
    elif scaler_type == 'standard':
        scaler = StandardScaler()
    else:
        # report the requested type; `scaler` is unbound on this path
        raise NotImplementedError(f"no scaler: {scaler_type}")
    if not profile_dict:
        raise ValueError("profile_dict is empty; nothing to fit the scaler on")
    # Pool every height value (in key order) into one column. No random key is
    # needed to size a buffer, so the global `random` state is left untouched.
    heights = np.concatenate(
        [np.ravel(np.asarray(sample['profile'], dtype=float)) for sample in profile_dict.values()]
    )
    scaler.fit(heights.reshape(-1, 1))
    return scaler
def get_class_dict(class_list):
    """Look up a class mapping, its inverse and its class weights in ``config``.

    Args:
        class_list (string): case-insensitive name of the mapping; one of
            'regional', 'simple', 'berm', 'sloot' or 'full'.

    Raises:
        NotImplementedError: if ``class_list`` names no known mapping.

    Returns:
        (dict, dict, object): the class mapping, its inverse, and the weights
        stored in the matching ``config.WEIGHT_DICT_*`` entry.
    """
    # Lambdas keep the lookup lazy: only the selected config entries are read.
    loaders = {
        'regional': lambda: (config.CLASS_DICT_REGIONAL,
                             config.INVERSE_CLASS_DICT_REGIONAL,
                             config.WEIGHT_DICT_REGIONAL),
        'simple': lambda: (config.CLASS_DICT_SIMPLE,
                           config.INVERSE_CLASS_DICT_SIMPLE,
                           config.WEIGHT_DICT_SIMPLE),
        'berm': lambda: (config.CLASS_DICT_SIMPLE_BERM,
                         config.INVERSE_CLASS_DICT_SIMPLE_BERM,
                         config.WEIGHT_DICT_SIMPLE_BERM),
        'sloot': lambda: (config.CLASS_DICT_SIMPLE_SLOOT,
                          config.INVERSE_CLASS_DICT_SIMPLE_SLOOT,
                          config.WEIGHT_DICT_SIMPLE_SLOOT),
        'full': lambda: (config.CLASS_DICT_FULL,
                         config.INVERSE_CLASS_DICT_FULL,
                         config.WEIGHT_DICT_FULL),
    }
    name = class_list.lower()
    loader = loaders.get(name)
    if loader is None:
        raise NotImplementedError(f"No configs found for class list of type: {name}")
    return loader()
def _first_sequential_labels(pred):
    """Return monotone class labels for one sample using the 'first' strategy."""
    pred = pred.tolist()
    n_points = len(pred)
    # group consecutive runs: class -> list of half-open (start, end) ranges
    groups = defaultdict(list)
    group_start_idx = 0
    for i in range(1, n_points):
        if pred[i] != pred[i - 1]:
            groups[pred[i - 1]].append((group_start_idx, i))
            group_start_idx = i
    if n_points:
        # close the final run as well, otherwise the last class is silently lost
        groups[pred[-1]].append((group_start_idx, n_points))

    # Walk classes in increasing order and keep, per class, the first run that
    # starts after the previously kept run; later re-occurrences are discarded.
    new_pred = [0] * n_points
    last_index = 0
    for class_n, group_tuples in sorted(groups.items()):
        for start, end in group_tuples:
            if start >= last_index:
                new_pred[start:end] = [class_n] * (end - start)
                last_index = end
                break

    # simple forward fill of positions left at 0 by discarded runs
    for i in range(1, n_points):
        if new_pred[i] == 0:
            new_pred[i] = new_pred[i - 1]
    return torch.tensor(new_pred, dtype=torch.int64)


def _isotonic_sequential_labels(pred):
    """Return monotone class labels for one sample via isotonic regression."""
    x = np.arange(len(pred))
    fitted = IsotonicRegression().fit(x, pred.numpy()).predict(x)
    return torch.as_tensor(np.round(fitted)).to(torch.int64)


def force_sequential_predictions(predictions, method='isotonic'):
    """Force the classes in each sample to never decrease from left to right.

    This makes sense because in the representation chosen here a higher class
    can never lie to the left of a lower class. Two methods are available:
    isotonic regression on the argmax labels ('isotonic', recommended) and a
    group-first method ('first') that keeps, per class, the first run of that
    class that starts after the previous class's run.

    Args:
        predictions (torch.Tensor): model output in shape
            (batch_size, channel_size, sample_size).
        method (str, optional): method used to enforce the ordering.
            Defaults to 'isotonic'.

    Raises:
        NotImplementedError: if the given method is not implemented.

    Returns:
        torch.Tensor: a new CPU tensor with the same shape and dtype as the
        input, holding a one-hot encoding of the monotone class sequence. The
        input tensor is not modified.
    """
    if method == 'first':
        to_labels = _first_sequential_labels
    elif method == 'isotonic':
        to_labels = _isotonic_sequential_labels
    else:
        raise NotImplementedError(f"Unknown method: {method}")

    # .cpu() returns the very same tensor for CPU inputs, so clone to avoid
    # writing the one-hot results back into the caller's tensor.
    predictions = predictions.detach().cpu().clone()
    n_classes = predictions.shape[1]  # 1 is the channel dimension
    for j in range(predictions.shape[0]):
        pred = torch.argmax(predictions[j], dim=0)
        new_pred = to_labels(pred)
        # encode back to a channel-first one-hot tensor (cast to input dtype on assignment)
        predictions[j] = F.one_hot(new_pred, num_classes=n_classes).permute(1, 0)
    return predictions
def visualize_prediction(heights, prediction, labels, location_name, class_list):
    """Plot a profile together with its ground-truth labels and the prediction.

    Ground-truth class changes are drawn as loosely dashed vertical lines and
    predicted classes as dotted lines, each annotated with the class name from
    the inverse class mapping. Class 0 is never annotated.

    Args:
        heights (array-like): height values of the profile, one per point.
        prediction (torch.Tensor): model output for the profile: class indices
            (1-D), channel-first scores (channels, points), or a batch of one
            (1, channels, points).
        labels (array-like or torch.Tensor): integer class label per point.
        location_name (str): name of the profile, used as the plot title.
        class_list (str): class mapping to use, determines which labels are visualized.
    """
    class_dict, inverse_class_dict, _ = get_class_dict(class_list)
    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # reduce one-hot / score tensors to one class index per point
    if prediction.dim() == 3:
        # NOTE(review): only a leading batch axis of size 1 is squeezed; a larger
        # batch would be argmax-ed over the batch axis. Confirm callers pass one sample.
        prediction = torch.argmax(torch.squeeze(prediction, dim=0), dim=0)
    if prediction.dim() == 2:
        # assuming channel first representation
        prediction = torch.argmax(prediction, dim=0)
    prediction = prediction.detach().cpu().tolist()

    # torch.Tensor hashes by identity, so indexing inverse_class_dict with a
    # tensor element raises KeyError; work with plain ints instead.
    labels = [int(c) for c in (labels.tolist() if hasattr(labels, 'tolist') else labels)]

    label_height = np.min(heights)
    n_labels = len(np.unique(labels))
    label_height_distance = (np.max(heights) - np.min(heights)) / (n_labels * 2)
    cmap = sns.color_palette("Set2", len(set(class_dict.values())))

    # ground truth: annotate each non-zero label that differs from the last annotated one
    prev_class_n = 999
    for index, class_n in enumerate(labels):
        if class_n == 0 or class_n == prev_class_n:
            continue
        # NOTE(review): ymin=0, ymax=5 in axes fraction; possibly meant 0.5 — confirm
        plt.axvline(index, 0, 5, color=cmap[class_n], linestyle=(0, (5, 10)))  # loose dashes
        plt.text(index, label_height, inverse_class_dict[class_n], rotation=0)
        label_height += label_height_distance
        prev_class_n = class_n

    # prediction: annotate only the first occurrence of each non-zero class
    annotated = set()
    for index, class_n in enumerate(prediction):
        if class_n == 0 or class_n in annotated:
            continue
        plt.axvline(index, 0, 5, color=cmap[class_n], linestyle=(0, (1, 1)))  # small dots
        plt.text(index, label_height, "predicted " + inverse_class_dict[class_n], rotation=0)
        label_height += label_height_distance
        annotated.add(class_n)
    plt.show()
def visualize_sample(heights, labels, location_name, class_list):
    """Plot a profile together with its ground-truth labels.

    Each point where the (non-zero) label differs from the last annotated
    label gets a loosely dashed vertical line and the class name from the
    inverse class mapping. Class 0 is never annotated.

    Args:
        heights (array-like): height values of the profile, one per point.
        labels (array-like or torch.Tensor): integer class label per point.
        location_name (str): name of the profile, used as the plot title.
        class_list (str): class mapping to use, determines which labels are visualized.
    """
    class_dict, inverse_class_dict, _ = get_class_dict(class_list)
    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # torch.Tensor hashes by identity, so indexing inverse_class_dict with a
    # tensor element raises KeyError; work with plain ints instead.
    labels = [int(c) for c in (labels.tolist() if hasattr(labels, 'tolist') else labels)]

    label_height = np.min(heights)
    n_labels = len(np.unique(labels))
    label_height_distance = (np.max(heights) - np.min(heights)) / (n_labels * 2)
    cmap = sns.color_palette("Set2", len(set(class_dict.values())))

    prev_class_n = 999
    for index, class_n in enumerate(labels):
        if class_n == 0 or class_n == prev_class_n:
            continue
        # NOTE(review): ymin=0, ymax=5 in axes fraction; possibly meant 0.5 — confirm
        plt.axvline(index, 0, 5, color=cmap[class_n], linestyle=(0, (5, 10)))  # loose dashes
        plt.text(index, label_height, inverse_class_dict[class_n], rotation=0)
        label_height += label_height_distance
        prev_class_n = class_n
    plt.show()
def visualize_files(linesfp, pointsfp, max_profile_size=512, class_list='simple', location_index=0, return_dict=False):
    """Load a surfacelines/points file pair and plot one labelled profile.

    Args:
        linesfp (str): path to surfacelines file.
        pointsfp (str): path to points file.
        max_profile_size (int, optional): cutoff size of the profile, can leave on default here. Defaults to 512.
        class_list (str, optional): class mapping to use. Defaults to 'simple'.
        location_index (int, optional): index of the profile to visualize. Defaults to 0.
        return_dict (bool, optional): return the profile dict so later plots can
            use ``visualize_dict`` without reloading the files. Defaults to False.

    Returns:
        dict or None: profile dict containing the profiles of the given files
        when ``return_dict`` is True, otherwise None.
    """
    profile_label_dict = preprocessing.filepath_pair_to_labeled_sample(linesfp,
                                                                       pointsfp,
                                                                       max_profile_size=max_profile_size,
                                                                       class_list=class_list)
    # The plotting logic was a verbatim copy of visualize_dict; delegate to it
    # so the two views cannot drift apart.
    visualize_dict(profile_label_dict, class_list=class_list, location_index=location_index)
    if return_dict:
        return profile_label_dict
    return None
def visualize_dict(profile_label_dict, class_list='simple', location_index=0):
    """Plot one profile from a profile dict together with its labels.

    The profile is chosen by position in the dict's key order. Each point where
    the (non-zero) label differs from the last annotated label gets a loosely
    dashed vertical line and the class name from the inverse class mapping.

    Args:
        profile_label_dict (dict): maps a profile name to a dict with
            'profile' (heights) and 'label' (class per point) entries.
        class_list (str, optional): class mapping to use for visualization. Defaults to 'simple'.
        location_index (int, optional): position of the profile to visualize. Defaults to 0.
    """
    location_name = list(profile_label_dict.keys())[location_index]
    sample = profile_label_dict[location_name]
    heights = sample['profile']
    labels = sample['label']

    class_dict, inverse_class_dict, _ = get_class_dict(class_list)
    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    text_y = np.min(heights)
    text_step = (np.max(heights) - np.min(heights)) / len(np.unique(labels))
    palette = sns.color_palette("Set2", len(set(class_dict.values())))

    last_drawn = 999
    for position, class_n in enumerate(labels):
        # skip class 0 and repeats of the most recently annotated class
        if class_n == 0 or class_n == last_drawn:
            continue
        # NOTE(review): ymin=0, ymax=5 in axes fraction; possibly meant 0.5 — confirm
        plt.axvline(position, 0, 5, color=palette[class_n], linestyle=(0, (5, 10)))  # loose dashes
        plt.text(position, text_y, inverse_class_dict[class_n], rotation=0)
        text_y += text_step
        last_drawn = class_n
    plt.show()