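"""Data transforms for frame field learning.

Offline transforms clean up and rasterize ground-truth polygons (segmentation
masks, distances, sizes, angle field); online CPU/CUDA transforms convert
samples to tensors, normalize them and optionally apply data augmentation.
"""
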
from collections import OrderedDict
import PIL
import numpy as np
import torch
import torchvision
import kornia
import torch_lydorn.kornia
import torch_lydorn.torchvision
from lydorn_utils import print_utils


class Print(object):
    """Debug transform: print the type, shape, dtype and value range of every entry of a sample."""
def __init__(self):
pass
def __call__(self, sample):
print("\n")
print(sample.keys())
        for key, item in sample.items():
            if isinstance(item, (np.ndarray, torch.Tensor)):
                if len(item.shape):
                    print(key, type(item), item.shape, item.dtype, item.min(), item.max())
                else:
                    print(key, type(item), item, item.dtype, item.min(), item.max())
            elif isinstance(item, PIL.Image.Image):
                print(key, type(item), item.size, item.mode, np.array(item).min(), np.array(item).max())
            elif isinstance(item, list):
                print(key, type(item[0]) if item else None, len(item))  # Guard against empty lists
return sample
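

# Usage sketch (debugging only): drop a Print() instance anywhere in a Compose
# pipeline to dump what each sample looks like at that point, e.g.:
#   transform = torchvision.transforms.Compose([..., Print(), ...])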


class CudaDataAugmentation(object):
    """GPU batch augmentation: random affine transform (rotation, translation,
    optional scaling), random vertical flip and color jitter, applied to all
    ground-truth tensors of a batch at once."""

    def __init__(self, input_patch_size: int, vflip: bool, affine: bool, scaling: list, color_jitter: bool):
self.input_patch_size = input_patch_size
self.vflip = vflip
self.affine = affine
self.scaling = scaling
self.color_jitter = None
if color_jitter:
self.color_jitter = kornia.augmentation.ColorJitter(brightness=0.05, contrast=0.05, saturation=.5, hue=.1)
self.tensor_keys_bilinear = ["image", "gt_polygons_image", "distances",
"valid_mask"] # Affine transform applied with bilinear sampling
self.tensor_keys_nearest = ["sizes", "gt_crossfield_angle"] # Affine transform applied with nearest sampling
@staticmethod
def get_slices(batch, keys, last_slice_stop=0):
slices = OrderedDict()
for key in keys:
s = slice(last_slice_stop, last_slice_stop + batch[key].shape[1])
last_slice_stop += batch[key].shape[1]
slices[key] = s
return slices
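
    # Worked example: with batch["image"] of shape (B, 3, H, W) followed by the
    # 1-channel batch["valid_mask"], get_slices(batch, ["image", "valid_mask"])
    # returns OrderedDict(image=slice(0, 3), valid_mask=slice(3, 4)).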
def __call__(self, batch):
with torch.no_grad():
batch_size, im_channels, height, width = batch["image"].shape
device = batch["image"].device
batch["valid_mask"] = torch.ones((batch_size, 1, height, width), dtype=torch.float,
device=device) # Apply losses only when valid_mask is True (pad with 0 when rotating)
# Combine all images into one for faster/easier processing (store slices to separate them later on):
tensor_keys_bilinear = [key for key in self.tensor_keys_bilinear if key in batch]
tensor_keys_nearest = [key for key in self.tensor_keys_nearest if key in batch]
tensor_keys = tensor_keys_bilinear + tensor_keys_nearest
combined = torch.cat([batch[tensor_key] for tensor_key in tensor_keys], dim=1)
slices_bilinear = self.get_slices(batch, tensor_keys_bilinear, last_slice_stop=0)
slices_nearest = self.get_slices(batch, tensor_keys_nearest,
last_slice_stop=slices_bilinear[tensor_keys_bilinear[-1]].stop)
bilinear_slice = slice(slices_bilinear[tensor_keys_bilinear[0]].start,
slices_bilinear[tensor_keys_bilinear[-1]].stop)
nearest_slice = slice(slices_nearest[tensor_keys_nearest[0]].start,
slices_nearest[tensor_keys_nearest[-1]].stop)
# Rotation (and translation)
if self.affine:
angle: torch.Tensor = torch.empty(batch_size, device=device).uniform_(-180, 180)
# To include corner pixels if angle=45 (coords are between -1 and 1 for grid_sample):
max_offset = np.sqrt(2) - 1
offset: torch.Tensor = torch.empty((batch_size, 2), device=device).uniform_(-max_offset, max_offset)
downscale_factor = None
if self.scaling is not None:
downscale_factor: torch.Tensor = torch.empty(batch_size, device=device).uniform_(*self.scaling)
affine_grid = torch_lydorn.kornia.geometry.transform.get_affine_grid(combined, angle, offset,
downscale_factor)
combined[:, bilinear_slice, ...] = \
torch.nn.functional.grid_sample(combined[:, bilinear_slice, ...],
affine_grid, mode='bilinear')
# Rotate sizes and anglefield with mode='nearest'
# because it makes no sense to interpolate size values and angle values:
combined[:, nearest_slice, ...] = torch.nn.functional.grid_sample(combined[:, nearest_slice, ...],
affine_grid, mode='nearest')
# Additionally the angle field's values themselves have to be rotated:
combined[:, slices_nearest["gt_crossfield_angle"],
...] = torch_lydorn.torchvision.transforms.functional.rotate_anglefield(
combined[:, slices_nearest["gt_crossfield_angle"], ...], angle)
# The sizes and distances should be adjusted as well because of the scaling.
if downscale_factor is not None:
if "sizes" in slices_nearest:
size_equals_one = combined[:, slices_nearest["sizes"], ...] == 1
combined[:, slices_nearest["sizes"], :, :] /= downscale_factor[:, None, None, None]
combined[:, slices_nearest["sizes"], ...][size_equals_one] = 1
if "distances" in slices_bilinear:
distance_equals_one = combined[:, slices_bilinear["distances"], ...] == 1
combined[:, slices_bilinear["distances"], :, :] /= downscale_factor[:, None, None, None]
combined[:, slices_bilinear["distances"], ...][distance_equals_one] = 1
# Center crop
if self.input_patch_size is not None:
prev_image_norm = combined.shape[2] + combined.shape[3]
combined = torch_lydorn.torchvision.transforms.functional.center_crop(combined, self.input_patch_size)
current_image_norm = combined.shape[2] + combined.shape[3]
# Sizes and distances are affected by this because they are relative to the image's size.
# All non-one pixels have to be renormalized:
size_ratio = prev_image_norm / current_image_norm
if "sizes" in slices_nearest:
combined[:, slices_nearest["sizes"], ...][
combined[:, slices_nearest["sizes"], ...] != 1] *= size_ratio
if "distances" in slices_bilinear:
combined[:, slices_bilinear["distances"], ...][
combined[:, slices_bilinear["distances"], ...] != 1] *= size_ratio
# vflip
if self.vflip:
to_flip: torch.Tensor = torch.empty(batch_size, device=device).uniform_(0, 1) < 0.5
combined[to_flip] = kornia.geometry.transform.vflip(combined[to_flip])
combined[
to_flip, slices_nearest[
"gt_crossfield_angle"], ...] = torch_lydorn.torchvision.transforms.functional.vflip_anglefield(
combined[to_flip, slices_nearest["gt_crossfield_angle"], ...])
# Split data:
batch["image"] = combined[:, slices_bilinear["image"], ...]
if "gt_polygons_image" in slices_bilinear:
batch["gt_polygons_image"] = combined[:, slices_bilinear["gt_polygons_image"], ...]
if "distances" in slices_bilinear:
batch["distances"] = combined[:, slices_bilinear["distances"], ...]
batch["valid_mask"] = 0.99 < combined[:, slices_bilinear["valid_mask"],
...] # Take a very high threshold to remove fuzzy pixels
if "sizes" in slices_nearest:
batch["sizes"] = combined[:, slices_nearest["sizes"], ...]
batch["gt_crossfield_angle"] = combined[:, slices_nearest["gt_crossfield_angle"], ...]
# Color jitter
if self.color_jitter is not None and batch["image"].shape[1] == 3:
batch["image"] = self.color_jitter(batch["image"])
            # --- Zero-padding is not correct for sizes and distances; they should be padded with ones:
if self.affine:
if "sizes" in slices_nearest:
batch["sizes"][~batch["valid_mask"]] = 1
if "distances" in slices_bilinear:
batch["distances"][~batch["valid_mask"]] = 1
return batch
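

# Usage sketch (illustrative values; in practice the arguments come from
# config["data_aug_params"], as in get_online_cuda_transform below):
#   augment = CudaDataAugmentation(input_patch_size=512, vflip=True, affine=True,
#                                  scaling=[0.75, 1.25], color_jitter=True)
#   batch = augment(batch)  # The batch's tensors are expected to already be on the GPU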


class CudaCrop(object):
    """Center-crop every known tensor of a batch to the given patch size (no augmentation)."""

    def __init__(self, input_patch_size: int):
self.input_patch_size = input_patch_size
self.tensor_keys = ["image", "gt_polygons_image", "distances", "valid_mask", "sizes", "gt_crossfield_angle"]
def __call__(self, batch):
for tensor_key in self.tensor_keys:
if tensor_key in batch:
batch[tensor_key] = torch_lydorn.torchvision.transforms.functional.center_crop(batch[tensor_key],
self.input_patch_size)
return batch


def get_offline_transform(config, augmentations=False, to_patches=True):
    """Build the offline (pre-processing) transform: polygon cleanup, optional
    patch extraction, rasterization and angle-field initialization."""
data_patch_size = config["dataset_params"]["data_patch_size"] if augmentations else config["dataset_params"][
"input_patch_size"]
transform_list = [
torch_lydorn.torchvision.transforms.Map(
transform=torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Compose([
torch_lydorn.torchvision.transforms.RemoveDoubles(epsilon=0.01),
torch_lydorn.torchvision.transforms.FilterPolyVertexCount(min=3),
torch_lydorn.torchvision.transforms.ApproximatePolygon(tolerance=0.01),
torch_lydorn.torchvision.transforms.FilterPolyVertexCount(min=3)
]), key="gt_polygons")),
torch_lydorn.torchvision.transforms.FilterEmptyPolygons(key="gt_polygons"),
]
if to_patches:
transform_list.extend([
torch_lydorn.torchvision.transforms.ToPatches(stride=config["dataset_params"]["input_patch_size"],
size=data_patch_size),
torch_lydorn.torchvision.transforms.FilterEmptyPolygons(key="gt_polygons"),
])
transform_list.extend([
torch_lydorn.torchvision.transforms.Map(
transform=torchvision.transforms.Compose([
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.Rasterize(fill=True, edges=True, vertices=True,
line_width=4, antialiasing=True),
key=["image", "gt_polygons"], outkey="gt_polygons_image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.AngleFieldInit(line_width=6),
key=["image", "gt_polygons"],
outkey="gt_crossfield_angle")
])),
])
offline_transform = torchvision.transforms.Compose(transform_list)
return offline_transform
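

# Usage sketch (assumed config keys as read above):
#   offline_transform = get_offline_transform(config, augmentations=True)
#   sample = offline_transform(sample)  # sample holds "image" and "gt_polygons" entries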


def get_offline_transform_patch(raster: bool = True, fill: bool = True, edges: bool = True, vertices: bool = True,
                                distances: bool = True, sizes: bool = True, angle_field: bool = True):
    """Build the per-patch offline transform: rasterization and/or angle-field initialization."""
transform_list = []
if raster:
if not distances and not sizes:
rasterize_transform = torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.Rasterize(fill=fill, edges=edges, vertices=vertices,
line_width=4, antialiasing=True,
return_distances=False,
return_sizes=False),
key=["image", "gt_polygons"], outkey="gt_polygons_image")
elif distances and sizes:
rasterize_transform = torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.Rasterize(fill=fill, edges=edges, vertices=vertices,
line_width=4, antialiasing=True,
return_distances=True,
return_sizes=True),
key=["image", "gt_polygons"], outkey=["gt_polygons_image", "distances", "sizes"])
        else:
            raise NotImplementedError("Rasterize is only set up for distances and sizes together, or neither.")
transform_list.append(rasterize_transform)
if angle_field:
transform_list.append(
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.AngleFieldInit(line_width=6),
key=["image", "gt_polygons"],
outkey="gt_crossfield_angle")
)
return torchvision.transforms.Compose(transform_list)
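

# Usage sketch: with the defaults this builds the full ground-truth stack
# (polygon raster + distances + sizes + angle field) for a single sample dict:
#   patch_transform = get_offline_transform_patch()
#   sample = patch_transform(sample)  # needs "image" and "gt_polygons" entries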


def get_online_cpu_transform(config, augmentations=False):
    """Build the per-sample CPU transform: PIL conversion, optional (legacy) CPU
    augmentations, center crop and conversion to PyTorch tensors."""
    if augmentations and config["data_aug_params"]["device"] == "cpu":
        print_utils.print_error("ERROR: CPU augmentations are not supported anymore. "
                                "Look at CudaDataAugmentation to see what additional augs would need to be implemented.")
raise NotImplementedError
online_transform_list = []
# Convert to PIL images
if not augmentations \
or (augmentations and config["data_aug_params"]["device"] == "cpu"):
online_transform_list.extend([
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.ToPILImage(),
key="image"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.ToPILImage(),
key="gt_polygons_image"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.ToPILImage(),
key="gt_crossfield_angle"),
])
# Add rotation data augmentation:
if augmentations and config["data_aug_params"]["device"] == "cpu" and \
config["data_aug_params"]["affine"]:
online_transform_list.extend([
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.SampleUniform(-180, 180),
outkey="rand_angle"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.functional.rotate,
key=["image", "rand_angle"], outkey="image",
resample=PIL.Image.BILINEAR),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.functional.rotate,
key=["gt_polygons_image", "rand_angle"],
outkey="gt_polygons_image",
resample=PIL.Image.BILINEAR),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.functional.rotate,
key=["gt_crossfield_angle", "rand_angle"],
outkey="gt_crossfield_angle",
resample=PIL.Image.NEAREST),
])
# Crop to final size
if not augmentations \
or (augmentations and config["data_aug_params"]["device"] == "cpu"):
if "input_patch_size" in config["dataset_params"]:
online_transform_list.extend([
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.CenterCrop(config["dataset_params"]["input_patch_size"]),
key="image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.CenterCrop(config["dataset_params"]["input_patch_size"]),
key="gt_polygons_image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.CenterCrop(config["dataset_params"]["input_patch_size"]),
key="gt_crossfield_angle"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.CenterCrop(
config["dataset_params"]["input_patch_size"]),
key="distances"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.CenterCrop(
config["dataset_params"]["input_patch_size"]),
key="sizes"),
])
    # Random vertical flip:
if augmentations and config["data_aug_params"]["device"] == "cpu" and \
config["data_aug_params"]["vflip"]:
online_transform_list.extend([
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.RandomBool(p=0.5),
outkey="rand_flip"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.ConditionApply(
transform=torchvision.transforms.functional.vflip),
key=["image", "rand_flip"], outkey="image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.ConditionApply(
transform=torchvision.transforms.functional.vflip),
key=["gt_polygons_image", "rand_flip"], outkey="gt_polygons_image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.ConditionApply(
transform=torchvision.transforms.functional.vflip),
key=["gt_crossfield_angle", "rand_flip"], outkey="gt_crossfield_angle"),
])
    # Color jitter:
if augmentations and config["data_aug_params"]["device"] == "cpu" and \
config["data_aug_params"]["color_jitter"]:
online_transform_list.append(
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.ColorJitter(brightness=0.05, contrast=0.05,
saturation=.5, hue=.1),
key="image")
)
# Convert to PyTorch tensors:
online_transform_list.extend([
# Print(),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Lambda(torch.from_numpy),
key="image_mean"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Lambda(torch.from_numpy),
key="image_std"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="gt_polygons_image", ignore_key_error=True),
        torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="gt_crossfield_angle", ignore_key_error=True),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="distances", ignore_key_error=True),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="sizes", ignore_key_error=True),
])
online_transform_list.append(
torch_lydorn.torchvision.transforms.RemoveKeys(keys=["gt_polygons"])
)
online_transform = torchvision.transforms.Compose(online_transform_list)
return online_transform
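

# Usage sketch: the CPU transform runs per sample, typically inside DataLoader
# workers; pair it with get_online_cuda_transform for the batched GPU stage:
#   cpu_transform = get_online_cpu_transform(config, augmentations=True)
#   sample = cpu_transform(sample)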


def get_eval_online_cpu_transform():
    """Tensor-conversion-only CPU transform for evaluation (no augmentation, no crop)."""
online_transform = torchvision.transforms.Compose([
# Print(),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="image"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.Lambda(
torch.from_numpy),
key="image_mean"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torchvision.transforms.Lambda(
torch.from_numpy),
key="image_std"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="gt_polygons_image"),
        torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
                                                           key="gt_crossfield_angle"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="distances"),
torch_lydorn.torchvision.transforms.TransformByKey(transform=torch_lydorn.torchvision.transforms.ToTensor(),
key="sizes"),
torch_lydorn.torchvision.transforms.RemoveKeys(keys=["gt_polygons"])
])
return online_transform


def get_online_cuda_transform(config, augmentations=False):
    """Build the batched GPU transform: uint8 -> float scaling, optional CUDA
    augmentation (or plain crop), then normalization with per-image mean/std."""
    device_transform_list = [
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torchvision.transforms.Lambda(lambda tensor: tensor.float().div(255)),
            key="image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Lambda(lambda tensor: tensor.float().div(255)),
key="gt_polygons_image"),
        torch_lydorn.torchvision.transforms.TransformByKey(
            # The angle field is stored as uint8: map it back to radians in [0, pi]:
            transform=torchvision.transforms.Lambda(lambda tensor: np.pi * tensor.float().div(255)),
            key="gt_crossfield_angle"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Lambda(lambda tensor: tensor.float()),
key="distances", ignore_key_error=True),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Lambda(lambda tensor: tensor.float()),
key="sizes", ignore_key_error=True),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torchvision.transforms.Lambda(lambda tensor: tensor.float()),
key="class_freq", ignore_key_error=True),
]
if augmentations and config["data_aug_params"]["device"] == "cpu":
if config["data_aug_params"]["affine"]:
# Rotate angle field
device_transform_list.append(
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.functional.rotate_anglefield,
key=["gt_crossfield_angle", "rand_angle"],
outkey="gt_crossfield_angle"))
if config["data_aug_params"]["vflip"]:
device_transform_list.append(
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.functional.vflip_anglefield,
key=["gt_crossfield_angle", "rand_flip"],
outkey="gt_crossfield_angle"))
if config["data_aug_params"]["device"] == "cuda":
input_patch_size = config["dataset_params"]["input_patch_size"] if "input_patch_size" in config[
"dataset_params"] else None # No crop if None
if augmentations:
device_transform_list.append(CudaDataAugmentation(input_patch_size,
config["data_aug_params"]["vflip"],
config["data_aug_params"]["affine"],
config["data_aug_params"]["scaling"],
config["data_aug_params"]["color_jitter"]))
elif input_patch_size is not None:
device_transform_list.append(CudaCrop(input_patch_size))
    device_transform_list.append(
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torch_lydorn.torchvision.transforms.functional.batch_normalize,
            key=["image", "image_mean", "image_std"],
            outkey="image"))
device_transform = torchvision.transforms.Compose(device_transform_list)
return device_transform
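

# Usage sketch: apply to a collated batch after moving it to the GPU, downstream
# of the per-sample CPU transform:
#   cuda_transform = get_online_cuda_transform(config, augmentations=True)
#   batch = cuda_transform({key: value.to("cuda") if isinstance(value, torch.Tensor) else value
#                           for key, value in batch.items()})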


def get_eval_online_cuda_transform(config):
    """Minimal GPU transform for evaluation: image scaling and normalization only.
    The config argument is kept for interface symmetry but is currently unused."""
    device_transform_list = [
        torch_lydorn.torchvision.transforms.TransformByKey(
            transform=torchvision.transforms.Lambda(lambda tensor: tensor.float().div(255)),
            key="image"),
torch_lydorn.torchvision.transforms.TransformByKey(
transform=torch_lydorn.torchvision.transforms.functional.batch_normalize,
key=["image", "image_mean", "image_std"],
outkey="image")
]
device_transform = torchvision.transforms.Compose(device_transform_list)
return device_transform