kangaroo / eva_clip /transform.py
WEBing's picture
upload
388bd8b
raw
history blame
3.38 kB
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
class ResizeMaxSize(nn.Module):
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == 'min' else min
self.fill = fill
def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
return img
def _convert_to_rgb(image):
return image.convert('RGB')
# class CatGen(nn.Module):
# def __init__(self, num=4):
# self.num = num
# def mixgen_batch(image, text):
# batch_size = image.shape[0]
# index = np.random.permutation(batch_size)
# cat_images = []
# for i in range(batch_size):
# # image mixup
# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
# # text concat
# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
# text = torch.stack(text)
# return image, text
def image_transform(
image_size: int,
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]
normalize = Normalize(mean=mean, std=std)
if is_train:
return Compose([
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
])
else:
if resize_longest_max:
transforms = [
ResizeMaxSize(image_size, fill=fill_color)
]
else:
transforms = [
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)