|
from PIL import ImageFilter |
|
from torchvision import transforms |
|
import numpy as np |
|
from utils.nn_utils import positional_encoding_2d |
|
from torch.utils.data.dataloader import default_collate |
|
|
|
|
|
def RandomBlur(radius=2., p=.3):
    """Build a transform that applies a Gaussian blur with probability ``p``.

    Args:
        radius: Blur radius forwarded to ``GaussianBlur``.
        p: Probability passed to ``transforms.RandomApply``. Defaults to
            0.3, the value that was previously hard-coded, so existing
            callers are unaffected.

    Returns:
        A ``transforms.RandomApply`` wrapping a single ``GaussianBlur``.
    """
    blur = GaussianBlur(radius=radius)
    return transforms.RandomApply([blur], p=p)
|
|
|
|
|
class ImageFilterTransform(object):
    """Callable base class that applies ``self.filter`` to an image.

    Subclasses are expected to provide their own ``__init__`` that sets
    ``self.filter``; the base initializer deliberately refuses to run.
    """

    def __init__(self):
        # The base class has no filter of its own to install.
        raise NotImplementedError()

    def __call__(self, img):
        """Return the result of ``img.filter`` called with ``self.filter``."""
        active_filter = self.filter
        return img.filter(active_filter)
|
|
|
|
|
class GaussianBlur(ImageFilterTransform):
    """Image transform that applies PIL's ``ImageFilter.GaussianBlur``."""

    def __init__(self, radius=2.):
        """Store a ``ImageFilter.GaussianBlur`` built with ``radius``.

        The parent initializer is intentionally not called: it raises
        ``NotImplementedError``.
        """
        gaussian = ImageFilter.GaussianBlur(radius=radius)
        self.filter = gaussian
|
|
|
|
|
def collate_fn(data):
    """Collate a list of per-sample dicts into a single batch dict.

    The field names are taken from the first sample. Fields ``'annot'`` and
    ``'rec_mat'`` are kept as plain Python lists (one entry per sample);
    every other field goes through ``default_collate``. The collated
    ``'pixel_features'``, ``'pixel_labels'`` and ``'gauss_labels'`` fields
    are additionally converted with ``.float()``.

    Args:
        data: Non-empty sequence of dicts, one per sample.

    Returns:
        Dict mapping each field name of the first sample to its batched value.
    """
    list_fields = ('annot', 'rec_mat')
    float_fields = ('pixel_features', 'pixel_labels', 'gauss_labels')

    batch = {}
    for name in data[0]:
        per_sample = [sample[name] for sample in data]
        if name in list_fields:
            # Kept un-collated: one raw entry per sample.
            batch[name] = per_sample
            continue
        collated = default_collate(per_sample)
        batch[name] = collated.float() if name in float_fields else collated

    return batch
|
|
|
|
|
def get_pixel_features(image_size, d_pe=128):
    """Return the pixel coordinate grid and the positional encoding per pixel.

    Args:
        image_size: Side length of the square pixel grid.
        d_pe: Encoding dimension forwarded to ``positional_encoding_2d``.

    Returns:
        A tuple ``(pixels, pixel_features)``:

        * ``pixels`` is an integer numpy array of shape
          ``(image_size, image_size, 2)`` where ``pixels[row, col]`` holds
          ``(x, y) == (col, row)``.
        * ``pixel_features`` is the encoding gathered at every pixel and
          permuted so the encoding axis comes last.
    """
    # NOTE(review): presumably a tensor of shape (d_pe, image_size, image_size);
    # confirm against utils.nn_utils.positional_encoding_2d.
    all_pe = positional_encoding_2d(d_pe, image_size, image_size)

    coords = np.arange(0, image_size)
    # Default 'xy' indexing: xv varies along columns, yv along rows.
    xv, yv = np.meshgrid(coords, coords)
    # One vectorized stack replaces the former per-row Python loop.
    pixels = np.stack([xv, yv], axis=-1)

    # Gather the encoding at (y, x) for every pixel, then move the encoding
    # axis to the end.
    pixel_features = all_pe[:, pixels[:, :, 1], pixels[:, :, 0]]
    pixel_features = pixel_features.permute(1, 2, 0)
    return pixels, pixel_features