SunderAli17 committed on
Commit
c4c35ae
1 Parent(s): 8f49c43

Create transform.py

Files changed (1)
  1. eva_clip/transform.py +103 -0
eva_clip/transform.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Optional, Sequence, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms.functional as F
+
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
+     CenterCrop
+
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+
+
+ class ResizeMaxSize(nn.Module):
+
+     def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
+         super().__init__()
+         if not isinstance(max_size, int):
+             raise TypeError(f"Size should be int. Got {type(max_size)}")
+         self.max_size = max_size
+         self.interpolation = interpolation
+         self.fn = min if fn == 'min' else max
+         self.fill = fill
+
+     def forward(self, img):
+         if isinstance(img, torch.Tensor):
+             # Tensor images are laid out as (..., C, H, W); the spatial dims are the last two.
+             height, width = img.shape[-2:]
+         else:
+             width, height = img.size
+         scale = self.max_size / float(max(height, width))
+         new_size = tuple(round(dim * scale) for dim in (height, width))
+         if scale != 1.0:
+             img = F.resize(img, new_size, self.interpolation)
+         # Pad the shorter side so the output is always max_size x max_size.
+         pad_h = self.max_size - new_size[0]
+         pad_w = self.max_size - new_size[1]
+         if pad_h > 0 or pad_w > 0:
+             img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
+         return img
+
+
+ def _convert_to_rgb(image):
+     return image.convert('RGB')
+
+
+ # class CatGen(nn.Module):
+ #     def __init__(self, num=4):
+ #         self.num = num
+ #     def mixgen_batch(image, text):
+ #         batch_size = image.shape[0]
+ #         index = np.random.permutation(batch_size)
+
+ #         cat_images = []
+ #         for i in range(batch_size):
+ #             # image mixup
+ #             image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
+ #             # text concat
+ #             text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
+ #         text = torch.stack(text)
+ #         return image, text
+
+
+ def image_transform(
+         image_size: Union[int, Sequence[int]],
+         is_train: bool,
+         mean: Optional[Tuple[float, ...]] = None,
+         std: Optional[Tuple[float, ...]] = None,
+         resize_longest_max: bool = False,
+         fill_color: int = 0,
+ ):
+     mean = mean or OPENAI_DATASET_MEAN
+     if not isinstance(mean, (list, tuple)):
+         mean = (mean,) * 3
+
+     std = std or OPENAI_DATASET_STD
+     if not isinstance(std, (list, tuple)):
+         std = (std,) * 3
+
+     if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+         # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+         image_size = image_size[0]
+
+     normalize = Normalize(mean=mean, std=std)
+     if is_train:
+         return Compose([
+             RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
+             _convert_to_rgb,
+             ToTensor(),
+             normalize,
+         ])
+     else:
+         if resize_longest_max:
+             transforms = [
+                 ResizeMaxSize(image_size, fill=fill_color)
+             ]
+         else:
+             transforms = [
+                 Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+                 CenterCrop(image_size),
+             ]
+         transforms.extend([
+             _convert_to_rgb,
+             ToTensor(),
+             normalize,
+         ])
+         return Compose(transforms)
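
For reference, a minimal usage sketch of the two pipelines this file builds. This is not part of the commit: the 224-pixel size, the placeholder image, and the `eva_clip.transform` import path are assumptions for illustration.

from PIL import Image

from eva_clip.transform import image_transform

# Training pipeline: RandomResizedCrop -> RGB conversion -> ToTensor -> Normalize.
train_preprocess = image_transform(image_size=224, is_train=True)

# Eval pipeline that letterboxes (resize longest side to 224, pad to square)
# instead of the default shortest-edge resize + center crop.
eval_preprocess = image_transform(
    image_size=224,
    is_train=False,
    resize_longest_max=True,
    fill_color=0,  # pad with black
)

img = Image.new('RGB', (640, 480))  # placeholder for a real image
x_train = train_preprocess(img)  # tensor of shape (3, 224, 224)
x_eval = eval_preprocess(img)    # aspect preserved, padded to (3, 224, 224)

The `resize_longest_max` path trades the default center crop for padding: no image content is cropped away at eval time, at the cost of constant-valued border pixels entering the normalized tensor.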