|
from collections import OrderedDict |
|
|
|
import PIL |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
import kornia |
|
|
|
import torch_lydorn.kornia |
|
import torch_lydorn.torchvision |
|
|
|
from lydorn_utils import print_utils |
|
|
|
|
|
class Print(object):
    """Debug transform: print a summary of every entry in a sample dict.

    For arrays/tensors it prints shape, dtype and value range; for PIL images
    size, mode and pixel range; for lists the element type and length.
    The sample is returned unchanged so this can be inserted anywhere in a
    torchvision.transforms.Compose pipeline.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        """Print a per-key summary of ``sample`` and return it unchanged."""
        print("\n")
        print(sample.keys())
        for key, item in sample.items():
            # isinstance instead of type(...) ==: also matches subclasses
            # (e.g. memory-mapped arrays, tensor subclasses).
            if isinstance(item, (np.ndarray, torch.Tensor)):
                if len(item.shape):
                    print(key, type(item), item.shape, item.dtype, item.min(), item.max())
                else:
                    # 0-d array/tensor: print the scalar value itself instead of a shape.
                    print(key, type(item), item, item.dtype, item.min(), item.max())
            # isinstance is required here: images loaded from disk are instances of
            # subclasses such as PIL.PngImagePlugin.PngImageFile, so the original
            # type(item) == PIL.Image.Image test never matched them.
            elif isinstance(item, PIL.Image.Image):
                print(key, type(item), item.size, item.mode, np.array(item).min(), np.array(item).max())
            elif isinstance(item, list):
                if item:
                    print(key, type(item[0]), len(item))
                else:
                    # Guard: item[0] on an empty list would raise IndexError.
                    print(key, list, 0)
        return sample
|
|
|
|
|
class CudaDataAugmentation(object):
    """Batched data augmentation meant to run on an already-collated batch of tensors
    (typically on GPU): random affine warp (rotation/translation/optional downscaling),
    center crop, random vertical flip and color jitter.

    All spatial tensors of the batch are concatenated along the channel dimension so
    that each geometric warp is performed in a single kernel call, then split back
    into their batch keys. Assumes every tensor is NCHW with identical H and W
    (TODO confirm against the collate function).
    """

    def __init__(self, input_patch_size: int, vflip: bool, affine: bool, scaling: list, color_jitter: bool):
        # Side of the square patch cropped after warping; None disables cropping.
        self.input_patch_size = input_patch_size
        self.vflip = vflip
        self.affine = affine
        # (min, max) range for the random downscale factor; None disables scaling.
        self.scaling = scaling
        self.color_jitter = None
        if color_jitter:
            self.color_jitter = kornia.augmentation.ColorJitter(brightness=0.05, contrast=0.05, saturation=.5, hue=.1)
        # Keys warped with bilinear interpolation (continuous-valued maps).
        self.tensor_keys_bilinear = ["image", "gt_polygons_image", "distances",
                                     "valid_mask"]
        # Keys warped with nearest-neighbor interpolation (value-sensitive maps
        # whose values must not be blended between neighboring pixels).
        self.tensor_keys_nearest = ["sizes", "gt_crossfield_angle"]

    @staticmethod
    def get_slices(batch, keys, last_slice_stop=0):
        """Return an OrderedDict key -> channel slice locating each batch tensor
        inside the channel-concatenated "combined" tensor.

        ``last_slice_stop`` is the channel offset at which the first key starts
        (used to place the nearest-interp block right after the bilinear block).
        """
        slices = OrderedDict()
        for key in keys:
            s = slice(last_slice_stop, last_slice_stop + batch[key].shape[1])
            last_slice_stop += batch[key].shape[1]
            slices[key] = s
        return slices

    def __call__(self, batch):
        """Augment ``batch`` in place (tensors are replaced under the same keys) and return it.

        Adds a "valid_mask" entry marking pixels that stayed inside the original
        image after the affine warp.
        """
        with torch.no_grad():
            batch_size, im_channels, height, width = batch["image"].shape
            device = batch["image"].device
            # Start from an all-ones mask; it gets warped along with the other
            # tensors, so out-of-image regions end up < 1 after grid_sample.
            batch["valid_mask"] = torch.ones((batch_size, 1, height, width), dtype=torch.float,
                                             device=device)

            # Keep only the keys actually present in this batch.
            tensor_keys_bilinear = [key for key in self.tensor_keys_bilinear if key in batch]
            tensor_keys_nearest = [key for key in self.tensor_keys_nearest if key in batch]
            tensor_keys = tensor_keys_bilinear + tensor_keys_nearest
            # Concatenate everything channel-wise: one warp call instead of one per key.
            combined = torch.cat([batch[tensor_key] for tensor_key in tensor_keys], dim=1)
            slices_bilinear = self.get_slices(batch, tensor_keys_bilinear, last_slice_stop=0)
            slices_nearest = self.get_slices(batch, tensor_keys_nearest,
                                             last_slice_stop=slices_bilinear[tensor_keys_bilinear[-1]].stop)
            # Whole-block slices covering all bilinear / all nearest channels.
            # NOTE(review): assumes both key lists are non-empty — "image" and
            # "valid_mask" guarantee the bilinear one; "gt_crossfield_angle" is
            # also accessed unconditionally below, so it must be in the batch.
            bilinear_slice = slice(slices_bilinear[tensor_keys_bilinear[0]].start,
                                   slices_bilinear[tensor_keys_bilinear[-1]].stop)
            nearest_slice = slice(slices_nearest[tensor_keys_nearest[0]].start,
                                  slices_nearest[tensor_keys_nearest[-1]].stop)

            if self.affine:
                # Random rotation angle in degrees, one per batch element.
                angle: torch.Tensor = torch.empty(batch_size, device=device).uniform_(-180, 180)

                # Max translation so that a sqrt(2)-padded patch still covers the crop
                # — presumably; TODO confirm against get_affine_grid's conventions.
                max_offset = np.sqrt(2) - 1
                offset: torch.Tensor = torch.empty((batch_size, 2), device=device).uniform_(-max_offset, max_offset)
                downscale_factor = None
                if self.scaling is not None:
                    downscale_factor: torch.Tensor = torch.empty(batch_size, device=device).uniform_(*self.scaling)
                affine_grid = torch_lydorn.kornia.geometry.transform.get_affine_grid(combined, angle, offset,
                                                                                     downscale_factor)
                # NOTE(review): grid_sample is called without align_corners, which
                # warns and defaults to False on torch >= 1.3 — confirm intended.
                combined[:, bilinear_slice, ...] = \
                    torch.nn.functional.grid_sample(combined[:, bilinear_slice, ...],
                                                    affine_grid, mode='bilinear')

                # Same warp, nearest-neighbor, for value-sensitive channels.
                combined[:, nearest_slice, ...] = torch.nn.functional.grid_sample(combined[:, nearest_slice, ...],
                                                                                  affine_grid, mode='nearest')

                # The angle field stores directions, so its *values* must be rotated
                # too, not just the pixel positions.
                combined[:, slices_nearest["gt_crossfield_angle"],
                         ...] = torch_lydorn.torchvision.transforms.functional.rotate_anglefield(
                    combined[:, slices_nearest["gt_crossfield_angle"], ...], angle)

                if downscale_factor is not None:
                    # Sizes/distances are measured in (normalized) image units, so a
                    # downscale changes them; the value 1 appears to be a sentinel
                    # ("no object" / saturated) and is preserved — TODO confirm.
                    if "sizes" in slices_nearest:
                        size_equals_one = combined[:, slices_nearest["sizes"], ...] == 1
                        combined[:, slices_nearest["sizes"], :, :] /= downscale_factor[:, None, None, None]
                        # Basic slicing yields a view, so this masked write hits `combined`.
                        combined[:, slices_nearest["sizes"], ...][size_equals_one] = 1
                    if "distances" in slices_bilinear:
                        distance_equals_one = combined[:, slices_bilinear["distances"], ...] == 1
                        combined[:, slices_bilinear["distances"], :, :] /= downscale_factor[:, None, None, None]
                        combined[:, slices_bilinear["distances"], ...][distance_equals_one] = 1

            if self.input_patch_size is not None:
                # "Norm" here is H + W, used to rescale normalized sizes/distances
                # after the crop changes the reference image dimensions.
                prev_image_norm = combined.shape[2] + combined.shape[3]
                combined = torch_lydorn.torchvision.transforms.functional.center_crop(combined, self.input_patch_size)
                current_image_norm = combined.shape[2] + combined.shape[3]

                size_ratio = prev_image_norm / current_image_norm
                # Again the sentinel value 1 is left untouched.
                if "sizes" in slices_nearest:
                    combined[:, slices_nearest["sizes"], ...][
                        combined[:, slices_nearest["sizes"], ...] != 1] *= size_ratio
                if "distances" in slices_bilinear:
                    combined[:, slices_bilinear["distances"], ...][
                        combined[:, slices_bilinear["distances"], ...] != 1] *= size_ratio

            if self.vflip:
                # Per-sample coin flip; flip selected samples in one call.
                to_flip: torch.Tensor = torch.empty(batch_size, device=device).uniform_(0, 1) < 0.5
                combined[to_flip] = kornia.geometry.transform.vflip(combined[to_flip])
                # Angle values must be mirrored as well, not just pixel positions.
                combined[
                    to_flip, slices_nearest[
                        "gt_crossfield_angle"], ...] = torch_lydorn.torchvision.transforms.functional.vflip_anglefield(
                    combined[to_flip, slices_nearest["gt_crossfield_angle"], ...])

            # Split the combined tensor back into its batch keys.
            batch["image"] = combined[:, slices_bilinear["image"], ...]
            if "gt_polygons_image" in slices_bilinear:
                batch["gt_polygons_image"] = combined[:, slices_bilinear["gt_polygons_image"], ...]
            if "distances" in slices_bilinear:
                batch["distances"] = combined[:, slices_bilinear["distances"], ...]
            # Threshold the warped all-ones mask into a bool mask: pixels that were
            # interpolated with out-of-image zeros drop below 0.99.
            batch["valid_mask"] = 0.99 < combined[:, slices_bilinear["valid_mask"],
                                                  ...]

            if "sizes" in slices_nearest:
                batch["sizes"] = combined[:, slices_nearest["sizes"], ...]
            batch["gt_crossfield_angle"] = combined[:, slices_nearest["gt_crossfield_angle"], ...]

            # Color jitter only applies to 3-channel (RGB) images.
            if self.color_jitter is not None and batch["image"].shape[1] == 3:
                batch["image"] = self.color_jitter(batch["image"])

            # Reset sizes/distances to the sentinel value outside the valid region.
            # NOTE(review): relies on the bool valid_mask (1 channel) broadcasting
            # against possibly multi-channel sizes/distances — confirm shapes match.
            if self.affine:
                if "sizes" in slices_nearest:
                    batch["sizes"][~batch["valid_mask"]] = 1
                if "distances" in slices_bilinear:
                    batch["distances"][~batch["valid_mask"]] = 1

        return batch
|
|
|
|
|
class CudaCrop(object):
    """Center-crop every spatial tensor of a batch to a fixed square patch size.

    Batch entries not listed in ``tensor_keys`` (or absent from the batch) are
    left untouched; the batch dict is modified in place and returned.
    """

    def __init__(self, input_patch_size: int):
        self.input_patch_size = input_patch_size
        # Spatial tensors that get cropped when present in the batch:
        self.tensor_keys = ["image", "gt_polygons_image", "distances", "valid_mask", "sizes", "gt_crossfield_angle"]

    def __call__(self, batch):
        """Crop each known spatial tensor of ``batch`` and return the batch."""
        for key in self.tensor_keys:
            if key not in batch:
                continue
            batch[key] = torch_lydorn.torchvision.transforms.functional.center_crop(
                batch[key], self.input_patch_size)
        return batch
|
|
|
|
|
def get_offline_transform(config, augmentations=False, to_patches=True):
    """Build the offline (pre-compute) transform pipeline.

    Cleans up ground-truth polygons, optionally splits samples into patches,
    then rasterizes the polygons and initializes the ground-truth angle field.
    When augmentations are enabled, patches use the (larger) data_patch_size so
    later random crops stay inside valid data.
    """
    transforms = torch_lydorn.torchvision.transforms
    dataset_params = config["dataset_params"]
    if augmentations:
        data_patch_size = dataset_params["data_patch_size"]
    else:
        data_patch_size = dataset_params["input_patch_size"]

    # Polygon cleanup: drop duplicate vertices, degenerate polygons, then simplify.
    polygon_cleanup = transforms.TransformByKey(
        transform=torchvision.transforms.Compose([
            transforms.RemoveDoubles(epsilon=0.01),
            transforms.FilterPolyVertexCount(min=3),
            transforms.ApproximatePolygon(tolerance=0.01),
            transforms.FilterPolyVertexCount(min=3)
        ]),
        key="gt_polygons")
    transform_list = [
        transforms.Map(transform=polygon_cleanup),
        transforms.FilterEmptyPolygons(key="gt_polygons"),
    ]

    if to_patches:
        transform_list.append(transforms.ToPatches(stride=dataset_params["input_patch_size"],
                                                   size=data_patch_size))
        # Patching can leave patches with no polygons at all; drop those too.
        transform_list.append(transforms.FilterEmptyPolygons(key="gt_polygons"))

    # Rasterize polygons into supervision maps and initialize the angle field.
    rasterize = transforms.TransformByKey(
        transform=transforms.Rasterize(fill=True, edges=True, vertices=True,
                                       line_width=4, antialiasing=True),
        key=["image", "gt_polygons"], outkey="gt_polygons_image")
    angle_field_init = transforms.TransformByKey(
        transform=transforms.AngleFieldInit(line_width=6),
        key=["image", "gt_polygons"],
        outkey="gt_crossfield_angle")
    transform_list.append(
        transforms.Map(transform=torchvision.transforms.Compose([rasterize, angle_field_init])))

    return torchvision.transforms.Compose(transform_list)
|
|
|
|
|
def get_offline_transform_patch(raster: bool = True, fill: bool = True, edges: bool = True, vertices: bool = True,
                                distances: bool = True, sizes: bool = True, angle_field: bool = True):
    """Build the offline transform applied to individual patches.

    Optionally rasterizes the ground-truth polygons (with or without the extra
    distances/sizes maps) and optionally initializes the angle field.

    Raises:
        NotImplementedError: when rasterizing with exactly one of
            distances/sizes enabled — only "neither" or "both" is wired up.
    """
    if raster and distances != sizes:
        raise NotImplementedError
    transforms = torch_lydorn.torchvision.transforms
    transform_list = []
    if raster:
        # distances == sizes at this point; both extras on, or both off.
        with_extras = distances and sizes
        outkey = ["gt_polygons_image", "distances", "sizes"] if with_extras else "gt_polygons_image"
        transform_list.append(transforms.TransformByKey(
            transform=transforms.Rasterize(fill=fill, edges=edges, vertices=vertices,
                                           line_width=4, antialiasing=True,
                                           return_distances=with_extras,
                                           return_sizes=with_extras),
            key=["image", "gt_polygons"], outkey=outkey))
    if angle_field:
        transform_list.append(transforms.TransformByKey(
            transform=transforms.AngleFieldInit(line_width=6),
            key=["image", "gt_polygons"],
            outkey="gt_crossfield_angle"))

    return torchvision.transforms.Compose(transform_list)
|
|
|
|
|
def get_online_cpu_transform(config, augmentations=False):
    """Build the CPU-side online transform (run in dataloader workers).

    Converts raw sample entries (numpy / uint8 images) to tensors and, when no
    GPU augmentation is configured, applies PIL conversion plus center-cropping
    on CPU. Requesting CPU augmentations raises NotImplementedError.

    NOTE(review): every branch below guarded by
    ``augmentations and config["data_aug_params"]["device"] == "cpu"`` is
    unreachable — that exact combination raises at the top of this function.
    Those blocks are dead code kept, presumably, as a reference for what a CPU
    augmentation path used to look like.
    """
    if augmentations and config["data_aug_params"]["device"] == "cpu":
        print_utils.print_error("ERROR: CPU augmentations is not supported anymore. "
                                "Look at CudaDataAugmentation to see what additional augs would need to be implemented.")
        raise NotImplementedError
    online_transform_list = []

    # Effective condition: not augmentations (second disjunct would have raised above).
    if not augmentations \
            or (augmentations and config["data_aug_params"]["device"] == "cpu"):
        online_transform_list.extend([
            torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.ToPILImage(),
                                                               key="image"),
            torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.ToPILImage(),
                                                               key="gt_polygons_image"),
            torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.ToPILImage(),
                                                               key="gt_crossfield_angle"),
        ])

    # Dead code (see docstring NOTE): CPU random rotation.
    if augmentations and config["data_aug_params"]["device"] == "cpu" and \
            config["data_aug_params"]["affine"]:
        online_transform_list.extend([
            torch_lydorn.torchvision.transforms.TransformByKey(
                transform=torch_lydorn.torchvision.transforms.SampleUniform(-180, 180),
                outkey="rand_angle"),
            torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.functional.rotate,
                                                               key=["image", "rand_angle"], outkey="image",
                                                               resample=PIL.Image.BILINEAR),
            torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.functional.rotate,
                                                               key=["gt_polygons_image", "rand_angle"],
                                                               outkey="gt_polygons_image",
                                                               resample=PIL.Image.BILINEAR),
            torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.functional.rotate,
                                                               key=["gt_crossfield_angle", "rand_angle"],
                                                               outkey="gt_crossfield_angle",
                                                               resample=PIL.Image.NEAREST),
        ])

    # Effective condition: not augmentations (see above). CPU center-crop of all maps.
    if not augmentations \
            or (augmentations and config["data_aug_params"]["device"] == "cpu"):
        if "input_patch_size" in config["dataset_params"]:
            online_transform_list.extend([
                torch_lydorn.torchvision.transforms.TransformByKey(
                    transform=torchvision.transforms.CenterCrop(config["dataset_params"]["input_patch_size"]),
                    key="image"),
                torch_lydorn.torchvision.transforms.TransformByKey(
                    transform=torchvision.transforms.CenterCrop(config["dataset_params"]["input_patch_size"]),
                    key="gt_polygons_image"),
                torch_lydorn.torchvision.transforms.TransformByKey(
                    transform=torchvision.transforms.CenterCrop(config["dataset_params"]["input_patch_size"]),
                    key="gt_crossfield_angle"),
                # distances/sizes are plain arrays, hence the torch_lydorn CenterCrop.
                torch_lydorn.torchvision.transforms.TransformByKey(
                    transform=torch_lydorn.torchvision.transforms.CenterCrop(
                        config["dataset_params"]["input_patch_size"]),
                    key="distances"),
                torch_lydorn.torchvision.transforms.TransformByKey(
                    transform=torch_lydorn.torchvision.transforms.CenterCrop(
                        config["dataset_params"]["input_patch_size"]),
                    key="sizes"),
            ])

    # Dead code (see docstring NOTE): CPU random vertical flip.
    if augmentations and config["data_aug_params"]["device"] == "cpu" and \
            config["data_aug_params"]["vflip"]:
        online_transform_list.extend([
            torch_lydorn.torchvision.transforms.TransformByKey(
                transform=torch_lydorn.torchvision.transforms.RandomBool(p=0.5),
                outkey="rand_flip"),
            torch_lydorn.torchvision.transforms.TransformByKey(
                transform=torch_lydorn.torchvision.transforms.ConditionApply(
                    transform=torchvision.transforms.functional.vflip),
                key=["image", "rand_flip"], outkey="image"),
            torch_lydorn.torchvision.transforms.TransformByKey(
                transform=torch_lydorn.torchvision.transforms.ConditionApply(
                    transform=torchvision.transforms.functional.vflip),
                key=["gt_polygons_image", "rand_flip"], outkey="gt_polygons_image"),
            torch_lydorn.torchvision.transforms.TransformByKey(
                transform=torch_lydorn.torchvision.transforms.ConditionApply(
                    transform=torchvision.transforms.functional.vflip),
                key=["gt_crossfield_angle", "rand_flip"], outkey="gt_crossfield_angle"),
        ])

    # Dead code (see docstring NOTE): CPU color jitter.
    if augmentations and config["data_aug_params"]["device"] == "cpu" and \
            config["data_aug_params"]["color_jitter"]:
        online_transform_list.append(
            torch_lydorn.torchvision.transforms.TransformByKey(
                transform=torchvision.transforms.ColorJitter(brightness=0.05, contrast=0.05,
                                                             saturation=.5, hue=.1),
                key="image")
        )

    # Always executed: convert everything to tensors (missing keys are ignored).
    online_transform_list.extend([
        torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="image"),
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torchvision.transforms.Lambda(torch.from_numpy),
            key="image_mean"),
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torchvision.transforms.Lambda(torch.from_numpy),
            key="image_std"),
        torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="gt_polygons_image", ignore_key_error=True),
        torch_lydorn.torchvision.transforms.TransformByKey(torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="gt_crossfield_angle", ignore_key_error=True),
        torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="distances", ignore_key_error=True),
        torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="sizes", ignore_key_error=True),
    ])

    # Raw polygons are not needed downstream (and are not collatable); drop them.
    online_transform_list.append(
        torch_lydorn.torchvision.transforms.RemoveKeys(keys=["gt_polygons"])
    )

    online_transform = torchvision.transforms.Compose(online_transform_list)
    return online_transform
|
|
|
|
|
def get_eval_online_cpu_transform():
    """Build the CPU-side online transform used at evaluation time.

    Converts every sample entry to a tensor (no augmentation, no cropping) and
    drops the raw "gt_polygons" entry, which is not needed downstream.
    """
    by_key = torch_lydorn.torchvision.transforms.TransformByKey
    to_tensor = torch_lydorn.torchvision.transforms.ToTensor

    transform_list = [
        by_key(transform=to_tensor(), key="image"),
        # image_mean / image_std are plain numpy arrays; a bare from_numpy suffices.
        by_key(transform=torchvision.transforms.Lambda(torch.from_numpy), key="image_mean"),
        by_key(transform=torchvision.transforms.Lambda(torch.from_numpy), key="image_std"),
        by_key(transform=to_tensor(), key="gt_polygons_image"),
        by_key(to_tensor(), key="gt_crossfield_angle"),
        by_key(transform=to_tensor(), key="distances"),
        by_key(transform=to_tensor(), key="sizes"),
        torch_lydorn.torchvision.transforms.RemoveKeys(keys=["gt_polygons"]),
    ]
    return torchvision.transforms.Compose(transform_list)
|
|
|
|
|
def get_online_cuda_transform(config, augmentations=False):
    """Build the GPU-side online transform applied to a collated batch.

    Scales uint8 maps to floats, optionally applies the CUDA augmentations (or
    just a center crop when not augmenting), then normalizes the image with its
    per-sample mean/std.
    """
    by_key = torch_lydorn.torchvision.transforms.TransformByKey
    lambda_t = torchvision.transforms.Lambda

    device_transform_list = [
        # uint8 [0, 255] -> float [0, 1]
        by_key(transform=torchvision.transforms.Compose([
            lambda_t(lambda tensor: tensor.float().div(255))
        ]), key="image"),
        by_key(transform=lambda_t(lambda tensor: tensor.float().div(255)),
               key="gt_polygons_image"),
        # Angle field was stored as uint8; map [0, 255] back to [0, pi].
        by_key(transform=lambda_t(lambda tensor: np.pi * tensor.float().div(255)),
               key="gt_crossfield_angle"),
        by_key(transform=lambda_t(lambda tensor: tensor.float()),
               key="distances", ignore_key_error=True),
        by_key(transform=lambda_t(lambda tensor: tensor.float()),
               key="sizes", ignore_key_error=True),
        by_key(transform=lambda_t(lambda tensor: tensor.float()),
               key="class_freq", ignore_key_error=True),
    ]

    aug_params = config["data_aug_params"]
    # Follow-up to the (legacy) CPU augmentation path: angle *values* still need
    # adjusting after a spatial rotate/flip performed on CPU.
    if augmentations and aug_params["device"] == "cpu":
        if aug_params["affine"]:
            device_transform_list.append(
                by_key(transform=torch_lydorn.torchvision.transforms.functional.rotate_anglefield,
                       key=["gt_crossfield_angle", "rand_angle"],
                       outkey="gt_crossfield_angle"))
        if aug_params["vflip"]:
            device_transform_list.append(
                by_key(transform=torch_lydorn.torchvision.transforms.functional.vflip_anglefield,
                       key=["gt_crossfield_angle", "rand_flip"],
                       outkey="gt_crossfield_angle"))

    if aug_params["device"] == "cuda":
        input_patch_size = config["dataset_params"].get("input_patch_size")
        if augmentations:
            device_transform_list.append(CudaDataAugmentation(input_patch_size,
                                                              aug_params["vflip"],
                                                              aug_params["affine"],
                                                              aug_params["scaling"],
                                                              aug_params["color_jitter"]))
        elif input_patch_size is not None:
            device_transform_list.append(CudaCrop(input_patch_size))

    # Per-sample normalization with precomputed image mean/std.
    device_transform_list.append(
        by_key(transform=torch_lydorn.torchvision.transforms.functional.batch_normalize,
               key=["image", "image_mean", "image_std"],
               outkey="image"))

    return torchvision.transforms.Compose(device_transform_list)
|
|
|
|
|
def get_eval_online_cuda_transform(config):
    """Build the GPU-side online transform used at evaluation time.

    Scales the uint8 image to float [0, 1] and normalizes it with the
    per-sample mean/std. ``config`` is currently unused; kept for API symmetry
    with get_online_cuda_transform.
    """
    steps = [
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Lambda(lambda tensor: tensor.float().div(255))
            ]),
            key="image"),
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torch_lydorn.torchvision.transforms.functional.batch_normalize,
            key=["image", "image_mean", "image_std"],
            outkey="image"),
    ]
    return torchvision.transforms.Compose(steps)
|
|