# HEAT / datasets/data_utils.py
# Author: Egrt — commit 424188c ("init"), 1.67 kB.
from PIL import ImageFilter
from torchvision import transforms
import numpy as np
from utils.nn_utils import positional_encoding_2d
from torch.utils.data.dataloader import default_collate
def RandomBlur(radius=2., p=.3):
    """Return a transform that Gaussian-blurs its input with probability ``p``.

    Args:
        radius: Blur radius forwarded to :class:`GaussianBlur`.
        p: Probability that the blur is applied on a given call. Defaults to
            0.3, the value that was previously hard-coded here.

    Returns:
        A ``torchvision.transforms.RandomApply`` wrapping a single
        :class:`GaussianBlur` transform.
    """
    blur = GaussianBlur(radius=radius)
    return transforms.RandomApply([blur], p=p)
class ImageFilterTransform:
    """Base callable that applies ``self.filter`` to an image via ``img.filter``.

    This class is not meant to be instantiated directly: its constructor
    raises ``NotImplementedError``. Subclasses provide their own ``__init__``
    that assigns ``self.filter``.
    """

    def __init__(self):
        # Abstract constructor: concrete subclasses must override it.
        raise NotImplementedError

    def __call__(self, img):
        """Return ``img`` filtered with the filter stored on this instance."""
        active_filter = self.filter
        filtered = img.filter(active_filter)
        return filtered
class GaussianBlur(ImageFilterTransform):
    """Image transform that applies ``PIL.ImageFilter.GaussianBlur``.

    Args:
        radius: Blur radius forwarded to ``ImageFilter.GaussianBlur``.
    """

    def __init__(self, radius=2.):
        # Build the PIL filter once; the inherited ``__call__`` applies it.
        kernel = ImageFilter.GaussianBlur(radius=radius)
        self.filter = kernel
def collate_fn(data):
    """Merge a list of per-sample dicts into a single batched dict.

    The field names are taken from the first sample, in its key order.

    - ``'annot'`` and ``'rec_mat'`` are left un-collated: each becomes a plain
      Python list with one entry per sample.
    - Every other field is passed through ``default_collate``.
    - ``'pixel_features'``, ``'pixel_labels'`` and ``'gauss_labels'`` are
      additionally cast with ``.float()`` (assumes they collate to tensors).

    Raises:
        IndexError: if ``data`` is empty (the keys are read from ``data[0]``).
    """
    list_fields = ('annot', 'rec_mat')
    float_fields = ('pixel_features', 'pixel_labels', 'gauss_labels')

    def _batch_field(name):
        # Gather this field from every sample, then collate it as needed.
        per_sample = [sample[name] for sample in data]
        if name in list_fields:
            return per_sample
        collated = default_collate(per_sample)
        return collated.float() if name in float_fields else collated

    return {name: _batch_field(name) for name in data[0]}
def get_pixel_features(image_size, d_pe=128):
    """Build a square pixel-coordinate grid and gather positional encodings for it.

    Args:
        image_size: Side length of the square grid, in pixels.
        d_pe: Encoding dimension forwarded to ``positional_encoding_2d``.

    Returns:
        Tuple ``(pixels, pixel_features)``:

        - ``pixels``: integer array of shape ``(image_size, image_size, 2)``
          where ``pixels[r, c] == (c, r)``, i.e. ``(x, y)`` ordering.
        - ``pixel_features``: the encodings gathered as ``all_pe[:, y, x]`` at
          every pixel, then permuted so the encoding axis (axis 0 of
          ``all_pe``) comes last. Assumes ``positional_encoding_2d`` returns a
          channel-first ``(d_pe, H, W)`` torch tensor — TODO confirm.
    """
    all_pe = positional_encoding_2d(d_pe, image_size, image_size)

    coords = np.arange(0, image_size)
    # Default 'xy' meshgrid indexing: xv[r, c] == c and yv[r, c] == r.
    xv, yv = np.meshgrid(coords, coords)
    # Stack both full grids at once; this replaces a per-row Python loop that
    # also failed (ValueError: nothing to stack) when image_size == 0.
    pixels = np.stack([xv, yv], axis=-1)

    # Index rows with y and columns with x to gather the encoding per pixel.
    pixel_features = all_pe[:, pixels[:, :, 1], pixels[:, :, 0]]
    pixel_features = pixel_features.permute(1, 2, 0)
    return pixels, pixel_features