# magma / magma / transforms.py
# Last commit by stellaathena: "This should work" (bb5cd12)
from torchvision import transforms as T
import torch.nn.functional as F
from PIL import ImageOps
import PIL
import random
def pad_to_size(x, size=256):
    """Pad a PIL image so that both its width and height are at least ``size``.

    Padding is split evenly between the two sides of each axis. When the
    amount is odd, the extra pixel goes on the right / bottom. Padding uses
    ``ImageOps.expand``'s default fill. An axis that is already ``size`` or
    larger gets no padding, so the image is never cropped.

    Args:
        x: a PIL image (``x.size`` is ``(width, height)``).
        size: minimum side length in pixels.

    Returns:
        The padded image.
    """
    # Clamp at zero. A negative border passed to ImageOps.expand does not
    # raise. It silently crops the image instead, which turned this "pad"
    # into a center crop for images larger than ``size``.
    delta_w = max(size - x.size[0], 0)
    delta_h = max(size - x.size[1], 0)
    padding = (
        delta_w // 2,
        delta_h // 2,
        delta_w - (delta_w // 2),
        delta_h - (delta_h // 2),
    )
    return ImageOps.expand(x, padding)
def pad_to_size_tensor(x, size=256):
    """Zero-pad the last two dims of a tensor so each is at least ``size``.

    For a 3-D input this pads ``x.shape[1]`` and ``x.shape[2]`` (presumably
    H and W of a C, H, W image). Leading dims are left untouched, so batched
    inputs work too. When a pad amount is odd, the extra element goes on the
    leading (top / left) side. Dims already ``>= size`` are not padded.

    Args:
        x: tensor with at least two dimensions.
        size: minimum length for each of the last two dims.

    Returns:
        The padded tensor.
    """
    # Clamp BEFORE the parity split. Clamping only the halves (as before)
    # still padded by one when the dim exceeded ``size`` by an odd amount.
    pad_rows = max(size - x.shape[-2], 0)
    pad_cols = max(size - x.shape[-1], 0)
    # Leading side gets ceil(n / 2), trailing side gets floor(n / 2).
    pad_tuple_rows = (pad_rows - pad_rows // 2, pad_rows // 2)
    pad_tuple_cols = (pad_cols - pad_cols // 2, pad_cols // 2)
    # F.pad takes pairs starting from the LAST dim: (last_lo, last_hi, 2nd_lo, 2nd_hi).
    return F.pad(x, pad=(*pad_tuple_cols, *pad_tuple_rows))
class RandCropResize(object):
    """
    Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092

    Steps, for a PIL image:
      1. ``pad_to_size`` up to ``target_size``.
      2. Take a random square crop whose side is the shorter image side ``d_min``.
      3. Resize to a random side ``t`` drawn uniformly from
         ``[min(d_min, 9/8 * target), min(d_min, 12/8 * target)]``.
      4. Take a random ``target_size`` x ``target_size`` crop.
    """

    def __init__(self, target_size):
        # Side length, in pixels, of the square output.
        self.target_size = target_size

    def __call__(self, img):
        img = pad_to_size(img, self.target_size)
        d_min = min(img.size)
        img = T.RandomCrop(size=d_min)(img)
        t_min = min(d_min, round(9 / 8 * self.target_size))
        t_max = min(d_min, round(12 / 8 * self.target_size))
        # random.randint includes BOTH endpoints. The previous ``t_max + 1``
        # (carried over from an exclusive-upper-bound API) could pick
        # t_max + 1 and upsample past d_min.
        t = random.randint(t_min, t_max)
        img = T.Resize(t)(img)
        # Safety net so the final crop always fits. This was hard-coded to 256,
        # which needlessly upscaled images when target_size < 256.
        if min(img.size) < self.target_size:
            img = T.Resize(self.target_size)(img)
        return T.RandomCrop(size=self.target_size)(img)
def get_transforms(
    image_size, encoder_name, input_resolution=None, use_extra_transforms=False
):
    """Build the image preprocessing pipeline for an encoder.

    If ``encoder_name`` contains ``"clip"``, the result is
    ``clip_preprocess(input_resolution)``, and ``input_resolution`` must be
    given. ``image_size`` and ``use_extra_transforms`` are ignored on that path.

    Otherwise the result is a ``T.Compose`` that:
      - converts to RGB,
      - applies ``RandCropResize(image_size)``,
      - flips horizontally with p=0.5,
      - optionally applies a mild ``ColorJitter``,
      - converts to a tensor with a leading batch dim.
    """
    if "clip" in encoder_name:
        assert input_resolution is not None
        return clip_preprocess(input_resolution)

    # Only call convert() when the mode actually differs from RGB.
    ensure_rgb = T.Lambda(
        lambda img: img if img.mode == "RGB" else img.convert("RGB")
    )
    pipeline = [ensure_rgb, RandCropResize(image_size), T.RandomHorizontalFlip(p=0.5)]
    if use_extra_transforms:
        pipeline.append(T.ColorJitter(0.1, 0.1, 0.1, 0.05))
    pipeline.extend([T.ToTensor(), maybe_add_batch_dim])
    return T.Compose(pipeline)
def maybe_add_batch_dim(t):
    """Prepend a size-1 leading axis to 3-D tensors; return any other tensor unchanged."""
    return t.unsqueeze(0) if t.ndim == 3 else t
def pad_img(desired_size):
    """Return a function that letterboxes a PIL image into a square RGB canvas.

    The returned ``fn(im)`` scales ``im`` so its longer side equals
    ``desired_size``, preserving aspect ratio, with Lanczos resampling. It then
    pastes the result centered on a new ``desired_size`` x ``desired_size``
    RGB image created with ``PIL.Image.new``'s default (black) fill.

    Args:
        desired_size: side length, in pixels, of the square output.

    Returns:
        A callable mapping a PIL image to the padded PIL image.
    """
    # ``Image.ANTIALIAS`` (an alias of LANCZOS) was removed in Pillow 10.
    # ``Image.Resampling`` exists from Pillow 9.1; older versions expose
    # LANCZOS directly on the module.
    resample = getattr(PIL.Image, "Resampling", PIL.Image).LANCZOS

    def fn(im):
        old_w, old_h = im.size  # PIL sizes are (width, height)
        ratio = float(desired_size) / max(old_w, old_h)
        # round() rather than int(): float error can land just below the
        # target (e.g. 223.999... for 224), which int() would truncate and
        # leave a 1px border. max(1, ...) stops very thin images from
        # collapsing to a zero-size dimension.
        new_w = max(1, round(old_w * ratio))
        new_h = max(1, round(old_h * ratio))
        im = im.resize((new_w, new_h), resample)
        # Create the square canvas and paste the resized image centered on it.
        new_im = PIL.Image.new("RGB", (desired_size, desired_size))
        new_im.paste(im, ((desired_size - new_w) // 2, (desired_size - new_h) // 2))
        return new_im

    return fn
def crop_or_pad(n_px, pad=False):
    """Pick the transform that makes an image ``n_px`` x ``n_px``.

    With ``pad=True``, returns ``pad_img(n_px)`` (scale, then letterbox).
    Otherwise returns ``T.CenterCrop(n_px)``.
    """
    return pad_img(n_px) if pad else T.CenterCrop(n_px)
def clip_preprocess(n_px, use_pad=False):
    """Build the deterministic preprocessing pipeline used for CLIP encoders.

    The pipeline runs these steps in order:
      1. Bicubic resize of the shorter side to ``n_px``.
      2. Center crop (or letterbox when ``use_pad`` is true) to ``n_px`` x ``n_px``.
      3. Convert to RGB, then to a tensor.
      4. Add a leading batch dim.
      5. Normalize per channel with the constants below.
    """
    # Per-channel (R, G, B) mean and std handed to T.Normalize.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    resize = T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC)
    square = crop_or_pad(n_px, pad=use_pad)
    return T.Compose(
        [
            resize,
            square,
            lambda image: image.convert("RGB"),
            T.ToTensor(),
            maybe_add_batch_dim,
            T.Normalize(channel_mean, channel_std),
        ]
    )