jwyang commited on
Commit
0b36c03
1 Parent(s): 7a45a62

first commit

Browse files
Files changed (40) hide show
  1. __pycache__/config.cpython-39.pyc +0 -0
  2. app.py +140 -0
  3. config.py +245 -0
  4. configs/unicl_focalnet_giant.yaml +16 -0
  5. configs/unicl_swin_base.yaml +16 -0
  6. configs/unicl_swin_tiny.yaml +16 -0
  7. model/__init__.py +1 -0
  8. model/__pycache__/__init__.cpython-39.pyc +0 -0
  9. model/__pycache__/model.cpython-39.pyc +0 -0
  10. model/__pycache__/templates.cpython-39.pyc +0 -0
  11. model/image_encoder/__init__.py +1 -0
  12. model/image_encoder/__pycache__/__init__.cpython-38.pyc +0 -0
  13. model/image_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  14. model/image_encoder/__pycache__/build.cpython-38.pyc +0 -0
  15. model/image_encoder/__pycache__/build.cpython-39.pyc +0 -0
  16. model/image_encoder/__pycache__/focalnet.cpython-38.pyc +0 -0
  17. model/image_encoder/__pycache__/focalnet.cpython-39.pyc +0 -0
  18. model/image_encoder/__pycache__/swin_transformer.cpython-38.pyc +0 -0
  19. model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc +0 -0
  20. model/image_encoder/build.py +59 -0
  21. model/image_encoder/focalnet.py +649 -0
  22. model/image_encoder/swin_transformer.py +586 -0
  23. model/model.py +204 -0
  24. model/templates.py +83 -0
  25. model/text_encoder/__init__.py +9 -0
  26. model/text_encoder/__pycache__/__init__.cpython-38.pyc +0 -0
  27. model/text_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  28. model/text_encoder/__pycache__/build.cpython-38.pyc +0 -0
  29. model/text_encoder/__pycache__/build.cpython-39.pyc +0 -0
  30. model/text_encoder/__pycache__/hf_model.cpython-38.pyc +0 -0
  31. model/text_encoder/__pycache__/hf_model.cpython-39.pyc +0 -0
  32. model/text_encoder/__pycache__/registry.cpython-38.pyc +0 -0
  33. model/text_encoder/__pycache__/registry.cpython-39.pyc +0 -0
  34. model/text_encoder/__pycache__/transformer.cpython-38.pyc +0 -0
  35. model/text_encoder/__pycache__/transformer.cpython-39.pyc +0 -0
  36. model/text_encoder/build.py +31 -0
  37. model/text_encoder/hf_model.py +27 -0
  38. model/text_encoder/registry.py +18 -0
  39. model/text_encoder/transformer.py +194 -0
  40. requirements.txt +5 -0
__pycache__/config.cpython-39.pyc ADDED
Binary file (3.73 kB). View file
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import requests
3
+ import gradio as gr
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.data import create_transform
12
+ from config import get_config
13
+ from model import build_model
14
+
15
+ # Download human-readable labels for ImageNet.
16
+ response = requests.get("https://git.io/JJkYN")
17
+ labels = response.text.split("\n")
18
+
19
+ def parse_option():
20
+ parser = argparse.ArgumentParser('UniCL demo script', add_help=False)
21
+ parser.add_argument('--cfg', type=str, default="configs/unicl_swin_base.yaml", metavar="FILE", help='path to config file', )
22
+ args, unparsed = parser.parse_known_args()
23
+
24
+ config = get_config(args)
25
+
26
+ return args, config
27
+
28
+ def build_transforms(img_size, center_crop=True):
29
+ t = [transforms.ToPILImage()]
30
+ if center_crop:
31
+ size = int((256 / 224) * img_size)
32
+ t.append(
33
+ transforms.Resize(size)
34
+ )
35
+ t.append(
36
+ transforms.CenterCrop(img_size)
37
+ )
38
+ else:
39
+ t.append(
40
+ transforms.Resize(img_size)
41
+ )
42
+ t.append(transforms.ToTensor())
43
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
44
+ return transforms.Compose(t)
45
+
46
+ def build_transforms4display(img_size, center_crop=True):
47
+ t = [transforms.ToPILImage()]
48
+ if center_crop:
49
+ size = int((256 / 224) * img_size)
50
+ t.append(
51
+ transforms.Resize(size)
52
+ )
53
+ t.append(
54
+ transforms.CenterCrop(img_size)
55
+ )
56
+ else:
57
+ t.append(
58
+ transforms.Resize(img_size)
59
+ )
60
+ t.append(transforms.ToTensor())
61
+ return transforms.Compose(t)
62
+
63
+ args, config = parse_option()
64
+
65
+ '''
66
+ build model
67
+ '''
68
+ model = build_model(config)
69
+
70
+ url = 'https://projects4jw.blob.core.windows.net/unicl/release/in21k_yfcc14m_gcc15m_swin_base.pth'
71
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
72
+ model.load_state_dict(checkpoint["model"])
73
+ model.eval()
74
+
75
+ '''
76
+ build data transform
77
+ '''
78
+ eval_transforms = build_transforms(224, center_crop=True)
79
+ display_transforms = build_transforms4display(224, center_crop=True)
80
+
81
+ '''
82
+ build upsampler
83
+ '''
84
+ # upsampler = nn.Upsample(scale_factor=16, mode='bilinear')
85
+
86
+ '''
87
+ borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
88
+ '''
89
+ def show_cam_on_image(img: np.ndarray,
90
+ mask: np.ndarray,
91
+ use_rgb: bool = False,
92
+ colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
93
+ """ This function overlays the cam mask on the image as an heatmap.
94
+ By default the heatmap is in BGR format.
95
+ :param img: The base image in RGB or BGR format.
96
+ :param mask: The cam mask.
97
+ :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
98
+ :param colormap: The OpenCV colormap to be used.
99
+ :returns: The default image with the cam overlay.
100
+ """
101
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
102
+ if use_rgb:
103
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
104
+ heatmap = np.float32(heatmap) / 255
105
+
106
+ if np.max(img) > 1:
107
+ raise Exception(
108
+ "The input image should np.float32 in the range [0, 1]")
109
+
110
+ cam = 0.7*heatmap + 0.3*img
111
+ # cam = cam / np.max(cam)
112
+ return np.uint8(255 * cam)
113
+
114
+ def recognize_image(image, texts):
115
+ print(texts)
116
+ img_t = eval_transforms(image)
117
+ img_d = display_transforms(image).permute(1, 2, 0).numpy()
118
+
119
+ text_embeddings = model.get_text_embeddings(texts.split(';'))
120
+
121
+ # compute output
122
+ feat_img = model.encode_image(img_t.unsqueeze(0))
123
+ output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
124
+ prediction = output.softmax(-1).flatten()
125
+
126
+ return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
127
+
128
+
129
+ image = gr.inputs.Image()
130
+ label = gr.outputs.Label(num_top_classes=100)
131
+
132
+ gr.Interface(
133
+ description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
134
+ fn=recognize_image,
135
+ inputs=["image", "text"],
136
+ outputs=[
137
+ label,
138
+ ],
139
+ examples=[["./donut.png", 'a donut; several donuts; a number of donuts'], ["./horses.png"], ["./pencil.png"], ["./apple_with_ipod.jpg"]],
140
+ ).launch()
config.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Unified Contrastive Learning (UniCL)
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang (jianwyan@microsoft.com)
6
+ # Based on Swin Transformer written by Zhe Liu
7
+ # --------------------------------------------------------
8
+
9
+ import os
10
+ import yaml
11
+ from yacs.config import CfgNode as CN
12
+
13
+ _C = CN()
14
+ _C.VERBOSE = False
15
+
16
+ # Base config files
17
+ _C.BASE = ['']
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Data settings
21
+ # -----------------------------------------------------------------------------
22
+ _C.DATA = CN()
23
+ # Batch size for a single GPU, could be overwritten by command line argument
24
+ _C.DATA.BATCH_SIZE = 128
25
+ # Path to dataset, could be overwritten by command line argument
26
+ _C.DATA.DATA_PATH = ''
27
+ # Dataset name
28
+ _C.DATA.DATASET = 'imagenet'
29
+ # Input image size
30
+ _C.DATA.IMG_SIZE = 224
31
+ # Interpolation to resize image (random, bilinear, bicubic)
32
+ _C.DATA.INTERPOLATION = 'bicubic'
33
+ # Use zipped dataset instead of folder dataset
34
+ # could be overwritten by command line argument
35
+ _C.DATA.ZIP_MODE = False
36
+ # Cache Data in Memory, could be overwritten by command line argument
37
+ _C.DATA.CACHE_MODE = 'part'
38
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
39
+ _C.DATA.PIN_MEMORY = True
40
+ # Number of data loading threads
41
+ _C.DATA.NUM_WORKERS = 8
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # Model settings
45
+ # -----------------------------------------------------------------------------
46
+ _C.MODEL = CN()
47
+ # Model name
48
+ _C.MODEL.NAME = ''
49
+ # Checkpoint to resume, could be overwritten by command line argument
50
+ _C.MODEL.RESUME = ''
51
+ # Number of classes, overwritten in data preparation
52
+ _C.MODEL.NUM_CLASSES = 0
53
+ # Label Smoothing
54
+ _C.MODEL.LABEL_SMOOTHING = 0.1
55
+ # Whether load pretrained model
56
+ _C.MODEL.PRETRAINED = ''
57
+ # Projection dimension
58
+ _C.MODEL.DIM_PROJECTION = 512
59
+ # Mode specific
60
+ _C.MODEL.SPEC = CN(new_allowed=True)
61
+ # -----------------------------------------------------------------------------
62
+ # Build Image Encoder
63
+ # -----------------------------------------------------------------------------
64
+ _C.MODEL.IMAGE_ENCODER = CN()
65
+ # Image encoder type
66
+ _C.MODEL.IMAGE_ENCODER.TYPE = 'swin'
67
+ # Input image size
68
+ _C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224
69
+ # Dropout rate
70
+ _C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0
71
+ # Drop path rate
72
+ _C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1
73
+
74
+ # Swin Transformer parameters
75
+ _C.MODEL.IMAGE_ENCODER.SWIN = CN()
76
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4
77
+ _C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3
78
+ _C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96
79
+ _C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2]
80
+ _C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24]
81
+ _C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7
82
+ _C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4.
83
+ _C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True
84
+ _C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None
85
+ _C.MODEL.IMAGE_ENCODER.SWIN.APE = False
86
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True
87
+
88
+ # FocalNet parameters
89
+ _C.MODEL.IMAGE_ENCODER.FOCAL = CN()
90
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4
91
+ _C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3
92
+ _C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96
93
+ _C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2]
94
+ _C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4.
95
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True
96
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2]
97
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3]
98
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2]
99
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False
100
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False
101
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False
102
+
103
+ # -----------------------------------------------------------------------------
104
+ # Build Text Encoder
105
+ # -----------------------------------------------------------------------------
106
+ _C.MODEL.TEXT_ENCODER = CN()
107
+
108
+ _C.MODEL.TEXT_ENCODER.NAME = 'transformer'
109
+ _C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False
110
+ _C.MODEL.TEXT_ENCODER.PRETRAINED = ''
111
+ _C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip'
112
+ _C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
113
+ _C.MODEL.TEXT_ENCODER.WIDTH = 1024
114
+ _C.MODEL.TEXT_ENCODER.HEADS = 16
115
+ _C.MODEL.TEXT_ENCODER.LAYERS = 12
116
+ _C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True
117
+
118
+ # -----------------------------------------------------------------------------
119
+ # Training settings
120
+ # -----------------------------------------------------------------------------
121
+ _C.TRAIN = CN()
122
+ _C.TRAIN.START_EPOCH = 0
123
+ _C.TRAIN.EPOCHS = 32
124
+ _C.TRAIN.WARMUP_EPOCHS = 5
125
+ _C.TRAIN.WEIGHT_DECAY = 0.1
126
+ _C.TRAIN.BASE_LR = 5e-4
127
+ _C.TRAIN.WARMUP_LR = 5e-7
128
+ _C.TRAIN.MIN_LR = 5e-6
129
+ # Clip gradient norm
130
+ _C.TRAIN.CLIP_GRAD = 5.0
131
+ # Auto resume from latest checkpoint
132
+ _C.TRAIN.AUTO_RESUME = True
133
+ # Gradient accumulation steps
134
+ # could be overwritten by command line argument
135
+ _C.TRAIN.ACCUMULATION_STEPS = 0
136
+ # Whether to use gradient checkpointing to save memory
137
+ # could be overwritten by command line argument
138
+ _C.TRAIN.USE_CHECKPOINT = False
139
+
140
+ # LR scheduler
141
+ _C.TRAIN.LR_SCHEDULER = CN()
142
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
143
+ # Epoch interval to decay LR, used in StepLRScheduler
144
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
145
+ # LR decay rate, used in StepLRScheduler
146
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
147
+
148
+ # Optimizer
149
+ _C.TRAIN.OPTIMIZER = CN()
150
+ _C.TRAIN.OPTIMIZER.NAME = 'adamw'
151
+ # Optimizer Epsilon
152
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
153
+ # Optimizer Betas
154
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
155
+ # SGD momentum
156
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
157
+
158
+ # -----------------------------------------------------------------------------
159
+ # Augmentation settings
160
+ # -----------------------------------------------------------------------------
161
+ _C.AUG = CN()
162
+ # Color jitter factor
163
+ _C.AUG.COLOR_JITTER = 0.4
164
+ # Use AutoAugment policy. "v0" or "original"
165
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
166
+ # Random erase prob
167
+ _C.AUG.REPROB = 0.25
168
+ # Random erase mode
169
+ _C.AUG.REMODE = 'pixel'
170
+ # Random erase count
171
+ _C.AUG.RECOUNT = 1
172
+ # Mixup alpha, mixup enabled if > 0
173
+ _C.AUG.MIXUP = 0.8
174
+ # Cutmix alpha, cutmix enabled if > 0
175
+ _C.AUG.CUTMIX = 1.0
176
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
177
+ _C.AUG.CUTMIX_MINMAX = None
178
+ # Probability of performing mixup or cutmix when either/both is enabled
179
+ _C.AUG.MIXUP_PROB = 1.0
180
+ # Probability of switching to cutmix when both mixup and cutmix enabled
181
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
182
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
183
+ _C.AUG.MIXUP_MODE = 'batch'
184
+
185
+ # -----------------------------------------------------------------------------
186
+ # Testing settings
187
+ # -----------------------------------------------------------------------------
188
+ _C.TEST = CN()
189
+ # Whether to use center crop when testing
190
+ _C.TEST.CROP = True
191
+
192
+ # -----------------------------------------------------------------------------
193
+ # Misc
194
+ # -----------------------------------------------------------------------------
195
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
196
+ # overwritten by command line argument
197
+ _C.AMP_OPT_LEVEL = ''
198
+ # Path to output folder, overwritten by command line argument
199
+ _C.OUTPUT = ''
200
+ # Tag of experiment, overwritten by command line argument
201
+ _C.TAG = 'default'
202
+ # Frequency to save checkpoint
203
+ _C.SAVE_FREQ = 1
204
+ # Frequency to logging info
205
+ _C.PRINT_FREQ = 100
206
+ # Fixed random seed
207
+ _C.SEED = 0
208
+ # Perform evaluation only, overwritten by command line argument
209
+ _C.EVAL_MODE = False
210
+ # Test throughput only, overwritten by command line argument
211
+ _C.THROUGHPUT_MODE = False
212
+ # Debug only so that skip dataloader initialization, overwritten by command line argument
213
+ _C.DEBUG_MODE = False
214
+ # local rank for DistributedDataParallel, given by command line argument
215
+ _C.LOCAL_RANK = 0
216
+
217
+
218
+ def _update_config_from_file(config, cfg_file):
219
+ config.defrost()
220
+ with open(cfg_file, 'r') as f:
221
+ yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
222
+
223
+ for cfg in yaml_cfg.setdefault('BASE', ['']):
224
+ if cfg:
225
+ _update_config_from_file(
226
+ config, os.path.join(os.path.dirname(cfg_file), cfg)
227
+ )
228
+ print('=> merge config from {}'.format(cfg_file))
229
+ config.merge_from_file(cfg_file)
230
+ config.freeze()
231
+
232
+
233
+ def update_config(config, args):
234
+ _update_config_from_file(config, args.cfg)
235
+ config.freeze()
236
+
237
+
238
+ def get_config(args):
239
+ """Get a yacs CfgNode object with default values."""
240
+ # Return a clone so that the defaults will not be altered
241
+ # This is for the "local variable" use pattern
242
+ config = _C.clone()
243
+ update_config(config, args)
244
+
245
+ return config
configs/unicl_focalnet_giant.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NAME: unicl_focalnet_giant
3
+ DIM_PROJECTION: 1024
4
+ IMAGE_ENCODER:
5
+ TYPE: focalnet_giant_lrf
6
+ DROP_PATH_RATE: 0.5
7
+ FOCAL:
8
+ USE_POSTLN: False
9
+ USE_CONV_EMBED: False
10
+ EMBED_DIM: 512
11
+ USE_LAYERSCALE: True
12
+ TEXT_ENCODER:
13
+ NAME: 'transformer'
14
+ WIDTH: 1024
15
+ HEADS: 16
16
+ LAYERS: 16
configs/unicl_swin_base.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NAME: unicl_swin_base
3
+ DIM_PROJECTION: 512
4
+ IMAGE_ENCODER:
5
+ TYPE: swin
6
+ DROP_PATH_RATE: 0.5
7
+ SWIN:
8
+ EMBED_DIM: 128
9
+ DEPTHS: [ 2, 2, 18, 2 ]
10
+ NUM_HEADS: [ 4, 8, 16, 32 ]
11
+ WINDOW_SIZE: 7
12
+ TEXT_ENCODER:
13
+ NAME: 'transformer'
14
+ WIDTH: 512
15
+ HEADS: 8
16
+ LAYERS: 12
configs/unicl_swin_tiny.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NAME: unicl_swin_tiny
3
+ DIM_PROJECTION: 512
4
+ IMAGE_ENCODER:
5
+ TYPE: swin
6
+ DROP_PATH_RATE: 0.2
7
+ SWIN:
8
+ EMBED_DIM: 96
9
+ DEPTHS: [ 2, 2, 6, 2 ]
10
+ NUM_HEADS: [ 3, 6, 12, 24 ]
11
+ WINDOW_SIZE: 7
12
+ TEXT_ENCODER:
13
+ NAME: 'transformer'
14
+ WIDTH: 512
15
+ HEADS: 8
16
+ LAYERS: 12
model/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .model import build_unicl_model as build_model
model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (190 Bytes). View file
model/__pycache__/model.cpython-39.pyc ADDED
Binary file (6.82 kB). View file
model/__pycache__/templates.cpython-39.pyc ADDED
Binary file (1.99 kB). View file
model/image_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .build import build_model as build_image_encoder
model/image_encoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (201 Bytes). View file
model/image_encoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (206 Bytes). View file
model/image_encoder/__pycache__/build.cpython-38.pyc ADDED
Binary file (1.13 kB). View file
model/image_encoder/__pycache__/build.cpython-39.pyc ADDED
Binary file (1.36 kB). View file
model/image_encoder/__pycache__/focalnet.cpython-38.pyc ADDED
Binary file (19.6 kB). View file
model/image_encoder/__pycache__/focalnet.cpython-39.pyc ADDED
Binary file (19.8 kB). View file
model/image_encoder/__pycache__/swin_transformer.cpython-38.pyc ADDED
Binary file (19.9 kB). View file
model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc ADDED
Binary file (19.8 kB). View file
model/image_encoder/build.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from timm.models import create_model
2
+ from .swin_transformer import SwinTransformer
3
+ from . import focalnet
4
+
5
+ def build_model(config):
6
+ model_type = config.TYPE
7
+ print(f"Creating model: {model_type}")
8
+
9
+ if "swin" in model_type:
10
+ model = SwinTransformer(
11
+ num_classes=0,
12
+ img_size=config.IMG_SIZE,
13
+ patch_size=config.SWIN.PATCH_SIZE,
14
+ in_chans=config.SWIN.IN_CHANS,
15
+ embed_dim=config.SWIN.EMBED_DIM,
16
+ depths=config.SWIN.DEPTHS,
17
+ num_heads=config.SWIN.NUM_HEADS,
18
+ window_size=config.SWIN.WINDOW_SIZE,
19
+ mlp_ratio=config.SWIN.MLP_RATIO,
20
+ qkv_bias=config.SWIN.QKV_BIAS,
21
+ qk_scale=config.SWIN.QK_SCALE,
22
+ drop_rate=config.DROP_RATE,
23
+ drop_path_rate=config.DROP_PATH_RATE,
24
+ ape=config.SWIN.APE,
25
+ patch_norm=config.SWIN.PATCH_NORM,
26
+ use_checkpoint=False
27
+ )
28
+ elif "focal" in model_type:
29
+ model = create_model(
30
+ model_type,
31
+ pretrained=False,
32
+ img_size=config.IMG_SIZE,
33
+ num_classes=0,
34
+ drop_path_rate=config.DROP_PATH_RATE,
35
+ use_conv_embed=config.FOCAL.USE_CONV_EMBED,
36
+ use_layerscale=config.FOCAL.USE_LAYERSCALE,
37
+ use_postln=config.FOCAL.USE_POSTLN
38
+ )
39
+
40
+ elif "vit" in model_type:
41
+ model = create_model(
42
+ model_type,
43
+ pretrained=is_pretrained,
44
+ img_size=config.DATA.IMG_SIZE,
45
+ num_classes=config.MODEL.NUM_CLASSES,
46
+ )
47
+ elif "resnet" in model_type:
48
+ model = create_model(
49
+ model_type,
50
+ pretrained=is_pretrained,
51
+ num_classes=config.MODEL.NUM_CLASSES
52
+ )
53
+ else:
54
+ model = create_model(
55
+ model_type,
56
+ pretrained=is_pretrained,
57
+ num_classes=config.MODEL.NUM_CLASSES
58
+ )
59
+ return model
model/image_encoder/focalnet.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # FocalNets -- Focal Modulation Networks
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang (jianwyan@microsoft.com)
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
+ from timm.models.registry import register_model
14
+
15
+ from torchvision import transforms
16
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
17
+ from timm.data import create_transform
18
+ from timm.data.transforms import _pil_interp
19
+
20
+ class Mlp(nn.Module):
21
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
22
+ super().__init__()
23
+ out_features = out_features or in_features
24
+ hidden_features = hidden_features or in_features
25
+ self.fc1 = nn.Linear(in_features, hidden_features)
26
+ self.act = act_layer()
27
+ self.fc2 = nn.Linear(hidden_features, out_features)
28
+ self.drop = nn.Dropout(drop)
29
+
30
+ def forward(self, x):
31
+ x = self.fc1(x)
32
+ x = self.act(x)
33
+ x = self.drop(x)
34
+ x = self.fc2(x)
35
+ x = self.drop(x)
36
+ return x
37
+
38
+ class FocalModulation(nn.Module):
39
+ def __init__(self, dim, focal_window, focal_level, focal_factor=2, bias=True, proj_drop=0.):
40
+ super().__init__()
41
+
42
+ self.dim = dim
43
+ self.focal_window = focal_window
44
+ self.focal_level = focal_level
45
+ self.focal_factor = focal_factor
46
+
47
+ self.f = nn.Linear(dim, 2*dim + (self.focal_level+1), bias=bias)
48
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
49
+
50
+ self.act = nn.GELU()
51
+ self.proj = nn.Linear(dim, dim)
52
+ self.proj_drop = nn.Dropout(proj_drop)
53
+ self.focal_layers = nn.ModuleList()
54
+
55
+ self.kernel_sizes = []
56
+ for k in range(self.focal_level):
57
+ kernel_size = self.focal_factor*k + self.focal_window
58
+ self.focal_layers.append(
59
+ nn.Sequential(
60
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1,
61
+ groups=dim, padding=kernel_size//2, bias=False),
62
+ nn.GELU(),
63
+ )
64
+ )
65
+ self.kernel_sizes.append(kernel_size)
66
+ def forward(self, x):
67
+ """
68
+ Args:
69
+ x: input features with shape of (B, H, W, C)
70
+ """
71
+ C = x.shape[-1]
72
+
73
+ # pre linear projection
74
+ x = self.f(x).permute(0, 3, 1, 2).contiguous()
75
+ q, ctx, self.gates = torch.split(x, (C, C, self.focal_level+1), 1)
76
+
77
+ # context aggreation
78
+ ctx_all = 0
79
+ for l in range(self.focal_level):
80
+ ctx = self.focal_layers[l](ctx)
81
+ ctx_all = ctx_all + ctx*self.gates[:, l:l+1]
82
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
83
+ ctx_all = ctx_all + ctx_global*self.gates[:,self.focal_level:]
84
+
85
+ # focal modulation
86
+ self.modulator = self.h(ctx_all)
87
+ x_out = q*self.modulator
88
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
89
+
90
+ # post linear porjection
91
+ x_out = self.proj(x_out)
92
+ x_out = self.proj_drop(x_out)
93
+ return x_out
94
+
95
+ def extra_repr(self) -> str:
96
+ return f'dim={self.dim}'
97
+
98
+ def flops(self, N):
99
+ # calculate flops for 1 window with token length of N
100
+ flops = 0
101
+
102
+ flops += N * self.dim * (self.dim * 2 + (self.focal_level+1))
103
+
104
+ # focal convolution
105
+ for k in range(self.focal_level):
106
+ flops += N * (self.kernel_sizes[k]**2+1) * self.dim
107
+
108
+ # global gating
109
+ flops += N * 1 * self.dim
110
+
111
+ # self.linear
112
+ flops += N * self.dim * (self.dim + 1)
113
+
114
+ # x = self.proj(x)
115
+ flops += N * self.dim * self.dim
116
+ return flops
117
+
118
+ class FocalNetBlock(nn.Module):
119
+ r""" Focal Modulation Network Block.
120
+
121
+ Args:
122
+ dim (int): Number of input channels.
123
+ input_resolution (tuple[int]): Input resulotion.
124
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
125
+ drop (float, optional): Dropout rate. Default: 0.0
126
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
127
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
128
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
129
+ focal_level (int): Number of focal levels.
130
+ focal_window (int): Focal window size at first focal level
131
+ use_layerscale (bool): Whether use layerscale
132
+ layerscale_value (float): Initial layerscale value
133
+ use_postln (bool): Whether use layernorm after modulation
134
+ """
135
+
136
+ def __init__(self, dim, input_resolution, mlp_ratio=4., drop=0., drop_path=0.,
137
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
138
+ focal_level=1, focal_window=3,
139
+ use_layerscale=False, layerscale_value=1e-4,
140
+ use_postln=False):
141
+ super().__init__()
142
+ self.dim = dim
143
+ self.input_resolution = input_resolution
144
+ self.mlp_ratio = mlp_ratio
145
+
146
+ self.focal_window = focal_window
147
+ self.focal_level = focal_level
148
+ self.use_postln = use_postln
149
+
150
+ self.norm1 = norm_layer(dim)
151
+ self.modulation = FocalModulation(dim, proj_drop=drop, focal_window=focal_window, focal_level=self.focal_level)
152
+
153
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
154
+ self.norm2 = norm_layer(dim)
155
+ mlp_hidden_dim = int(dim * mlp_ratio)
156
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
157
+
158
+ self.alpha = 3.0 if self.use_postln else 1.0
159
+
160
+ self.gamma_1 = 1.0
161
+ self.gamma_2 = 1.0
162
+ if use_layerscale:
163
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
164
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
165
+
166
+ self.H = None
167
+ self.W = None
168
+
169
+ def forward(self, x):
170
+ H, W = self.H, self.W
171
+ B, L, C = x.shape
172
+ shortcut = x
173
+
174
+ # Focal Modulation
175
+ if not self.use_postln:
176
+ x = self.norm1(x)
177
+ x = x.view(B, H, W, C)
178
+ x = self.modulation(x).view(B, H * W, C)
179
+
180
+ # FFN
181
+ x = shortcut*self.alpha + self.drop_path(self.gamma_1 * x)
182
+ if self.use_postln:
183
+ x = self.norm1(x)
184
+
185
+ if not self.use_postln:
186
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
187
+ else:
188
+ x = x*self.alpha + self.drop_path(self.gamma_2 * self.mlp(x))
189
+ x = self.norm2(x)
190
+
191
+ return x
192
+
193
+ def extra_repr(self) -> str:
194
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, " \
195
+ f"mlp_ratio={self.mlp_ratio}"
196
+
197
+ def flops(self):
198
+ flops = 0
199
+ H, W = self.input_resolution
200
+ # norm1
201
+ flops += self.dim * H * W
202
+
203
+ # W-MSA/SW-MSA
204
+ flops += self.modulation.flops(H*W)
205
+
206
+ # mlp
207
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
208
+ # norm2
209
+ flops += self.dim * H * W
210
+ return flops
211
+
212
+ class BasicLayer(nn.Module):
213
+ """ A basic Focal Transformer layer for one stage.
214
+
215
+ Args:
216
+ dim (int): Number of input channels.
217
+ input_resolution (tuple[int]): Input resolution.
218
+ depth (int): Number of blocks.
219
+ window_size (int): Local window size.
220
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
221
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
222
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
223
+ drop (float, optional): Dropout rate. Default: 0.0
224
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
225
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
226
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
227
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
228
+ focal_level (int): Number of focal levels
229
+ focal_window (int): Focal window size at first focal level
230
+ use_layerscale (bool): Whether use layerscale
231
+ layerscale_value (float): Initial layerscale value
232
+ use_postln (bool): Whether use layernorm after modulation
233
+ """
234
+
235
+ def __init__(self, dim, out_dim, input_resolution, depth,
236
+ mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm,
237
+ downsample=None, use_checkpoint=False,
238
+ focal_level=1, focal_window=1,
239
+ use_conv_embed=False,
240
+ use_layerscale=False, layerscale_value=1e-4, use_postln=False):
241
+
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.input_resolution = input_resolution
245
+ self.depth = depth
246
+ self.use_checkpoint = use_checkpoint
247
+
248
+ # build blocks
249
+ self.blocks = nn.ModuleList([
250
+ FocalNetBlock(
251
+ dim=dim,
252
+ input_resolution=input_resolution,
253
+ mlp_ratio=mlp_ratio,
254
+ drop=drop,
255
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
256
+ norm_layer=norm_layer,
257
+ focal_level=focal_level,
258
+ focal_window=focal_window,
259
+ use_layerscale=use_layerscale,
260
+ layerscale_value=layerscale_value,
261
+ use_postln=use_postln,
262
+ )
263
+ for i in range(depth)])
264
+
265
+ if downsample is not None:
266
+ self.downsample = downsample(
267
+ img_size=input_resolution,
268
+ patch_size=2,
269
+ in_chans=dim,
270
+ embed_dim=out_dim,
271
+ use_conv_embed=use_conv_embed,
272
+ norm_layer=norm_layer,
273
+ is_stem=False
274
+ )
275
+ else:
276
+ self.downsample = None
277
+
278
+ def forward(self, x, H, W):
279
+ for blk in self.blocks:
280
+ blk.H, blk.W = H, W
281
+ if self.use_checkpoint:
282
+ x = checkpoint.checkpoint(blk, x)
283
+ else:
284
+ x = blk(x)
285
+
286
+ if self.downsample is not None:
287
+ x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
288
+ x, Ho, Wo = self.downsample(x)
289
+ else:
290
+ Ho, Wo = H, W
291
+ return x, Ho, Wo
292
+
293
+ def extra_repr(self) -> str:
294
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
295
+
296
+ def flops(self):
297
+ flops = 0
298
+ for blk in self.blocks:
299
+ flops += blk.flops()
300
+ if self.downsample is not None:
301
+ flops += self.downsample.flops()
302
+ return flops
303
+
304
+ class PatchEmbed(nn.Module):
305
+ r""" Image to Patch Embedding
306
+
307
+ Args:
308
+ img_size (int): Image size. Default: 224.
309
+ patch_size (int): Patch token size. Default: 4.
310
+ in_chans (int): Number of input image channels. Default: 3.
311
+ embed_dim (int): Number of linear projection output channels. Default: 96.
312
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
313
+ """
314
+
315
+ def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, use_conv_embed=False, norm_layer=None, is_stem=False):
316
+ super().__init__()
317
+ patch_size = to_2tuple(patch_size)
318
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
319
+ self.img_size = img_size
320
+ self.patch_size = patch_size
321
+ self.patches_resolution = patches_resolution
322
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
323
+
324
+ self.in_chans = in_chans
325
+ self.embed_dim = embed_dim
326
+
327
+ if use_conv_embed:
328
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
329
+ if is_stem:
330
+ kernel_size = 7; padding = 2; stride = 4
331
+ else:
332
+ kernel_size = 3; padding = 1; stride = 2
333
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
334
+ else:
335
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
336
+
337
+ if norm_layer is not None:
338
+ self.norm = norm_layer(embed_dim)
339
+ else:
340
+ self.norm = None
341
+
342
+ def forward(self, x):
343
+ B, C, H, W = x.shape
344
+
345
+ x = self.proj(x)
346
+ H, W = x.shape[2:]
347
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
348
+ if self.norm is not None:
349
+ x = self.norm(x)
350
+ return x, H, W
351
+
352
+ def flops(self):
353
+ Ho, Wo = self.patches_resolution
354
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
355
+ if self.norm is not None:
356
+ flops += Ho * Wo * self.embed_dim
357
+ return flops
358
+
359
+ class FocalNet(nn.Module):
360
+ r""" Focal Modulation Networks (FocalNets)
361
+
362
+ Args:
363
+ img_size (int | tuple(int)): Input image size. Default 224
364
+ patch_size (int | tuple(int)): Patch size. Default: 4
365
+ in_chans (int): Number of input image channels. Default: 3
366
+ num_classes (int): Number of classes for classification head. Default: 1000
367
+ embed_dim (int): Patch embedding dimension. Default: 96
368
+ depths (tuple(int)): Depth of each Focal Transformer layer.
369
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
370
+ drop_rate (float): Dropout rate. Default: 0
371
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
372
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
373
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
374
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
375
+ focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [1, 1, 1, 1]
376
+ focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
377
+ use_conv_embed (bool): Whether use convolutional embedding. We noted that using convolutional embedding usually improve the performance, but we do not use it by default. Default: False
378
+ use_layerscale (bool): Whether use layerscale proposed in CaiT. Default: False
379
+ layerscale_value (float): Value for layer scale. Default: 1e-4
380
+ use_postln (bool): Whether use layernorm after modulation (it helps stablize training of large models)
381
+ """
382
+ def __init__(self,
383
+ img_size=224,
384
+ patch_size=4,
385
+ in_chans=3,
386
+ num_classes=1000,
387
+ embed_dim=96,
388
+ depths=[2, 2, 6, 2],
389
+ mlp_ratio=4.,
390
+ drop_rate=0.,
391
+ drop_path_rate=0.1,
392
+ norm_layer=nn.LayerNorm,
393
+ patch_norm=True,
394
+ use_checkpoint=False,
395
+ focal_levels=[2, 2, 2, 2],
396
+ focal_windows=[3, 3, 3, 3],
397
+ use_conv_embed=False,
398
+ use_layerscale=False,
399
+ layerscale_value=1e-4,
400
+ use_postln=False,
401
+ **kwargs):
402
+ super().__init__()
403
+
404
+ self.num_layers = len(depths)
405
+ embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]
406
+
407
+ self.num_classes = num_classes
408
+ self.embed_dim = embed_dim
409
+ self.patch_norm = patch_norm
410
+ self.num_features = embed_dim[-1]
411
+ self.mlp_ratio = mlp_ratio
412
+
413
+ # split image into patches using either non-overlapped embedding or overlapped embedding
414
+ self.patch_embed = PatchEmbed(
415
+ img_size=to_2tuple(img_size),
416
+ patch_size=patch_size,
417
+ in_chans=in_chans,
418
+ embed_dim=embed_dim[0],
419
+ use_conv_embed=use_conv_embed,
420
+ norm_layer=norm_layer if self.patch_norm else None,
421
+ is_stem=True)
422
+
423
+ num_patches = self.patch_embed.num_patches
424
+ patches_resolution = self.patch_embed.patches_resolution
425
+ self.patches_resolution = patches_resolution
426
+ self.pos_drop = nn.Dropout(p=drop_rate)
427
+
428
+ # stochastic depth
429
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
430
+
431
+ # build layers
432
+ self.layers = nn.ModuleList()
433
+ for i_layer in range(self.num_layers):
434
+ layer = BasicLayer(dim=embed_dim[i_layer],
435
+ out_dim=embed_dim[i_layer+1] if (i_layer < self.num_layers - 1) else None,
436
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
437
+ patches_resolution[1] // (2 ** i_layer)),
438
+ depth=depths[i_layer],
439
+ mlp_ratio=self.mlp_ratio,
440
+ drop=drop_rate,
441
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
442
+ norm_layer=norm_layer,
443
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
444
+ focal_level=focal_levels[i_layer],
445
+ focal_window=focal_windows[i_layer],
446
+ use_conv_embed=use_conv_embed,
447
+ use_checkpoint=use_checkpoint,
448
+ use_layerscale=use_layerscale,
449
+ layerscale_value=layerscale_value,
450
+ use_postln=use_postln,
451
+ )
452
+ self.layers.append(layer)
453
+
454
+ self.norm = norm_layer(self.num_features)
455
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
456
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
457
+ self.dim_out = self.num_features
458
+
459
+ self.apply(self._init_weights)
460
+
461
+ def _init_weights(self, m):
462
+ if isinstance(m, nn.Linear):
463
+ trunc_normal_(m.weight, std=.02)
464
+ if isinstance(m, nn.Linear) and m.bias is not None:
465
+ nn.init.constant_(m.bias, 0)
466
+ elif isinstance(m, nn.LayerNorm):
467
+ nn.init.constant_(m.bias, 0)
468
+ nn.init.constant_(m.weight, 1.0)
469
+
470
+ @torch.jit.ignore
471
+ def no_weight_decay(self):
472
+ return {''}
473
+
474
+ @torch.jit.ignore
475
+ def no_weight_decay_keywords(self):
476
+ return {''}
477
+
478
+ def forward_features(self, x):
479
+ x, H, W = self.patch_embed(x)
480
+ x = self.pos_drop(x)
481
+
482
+ for layer in self.layers:
483
+ x, H, W = layer(x, H, W)
484
+ x = self.norm(x) # B L C
485
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
486
+ x = torch.flatten(x, 1)
487
+ return x
488
+
489
+ def forward(self, x):
490
+ x = self.forward_features(x)
491
+ x = self.head(x)
492
+ return x
493
+
494
+ def flops(self):
495
+ flops = 0
496
+ flops += self.patch_embed.flops()
497
+ for i, layer in enumerate(self.layers):
498
+ flops += layer.flops()
499
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
500
+ flops += self.num_features * self.num_classes
501
+ return flops
502
+
503
+ def build_transforms(img_size, center_crop=False):
504
+ t = []
505
+ if center_crop:
506
+ size = int((256 / 224) * img_size)
507
+ t.append(
508
+ transforms.Resize(size, interpolation=_pil_interp('bicubic'))
509
+ )
510
+ t.append(
511
+ transforms.CenterCrop(img_size)
512
+ )
513
+ else:
514
+ t.append(
515
+ transforms.Resize(img_size, interpolation=_pil_interp('bicubic'))
516
+ )
517
+ t.append(transforms.ToTensor())
518
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
519
+ return transforms.Compose(t)
520
+
521
+ def build_transforms4display(img_size, center_crop=False):
522
+ t = []
523
+ if center_crop:
524
+ size = int((256 / 224) * img_size)
525
+ t.append(
526
+ transforms.Resize(size, interpolation=_pil_interp('bicubic'))
527
+ )
528
+ t.append(
529
+ transforms.CenterCrop(img_size)
530
+ )
531
+ else:
532
+ t.append(
533
+ transforms.Resize(img_size, interpolation=_pil_interp('bicubic'))
534
+ )
535
+ t.append(transforms.ToTensor())
536
+ return transforms.Compose(t)
537
+
538
+ model_urls = {
539
+ "focalnet_tiny_srf": "",
540
+ "focalnet_small_srf": "",
541
+ "focalnet_base_srf": "",
542
+ "focalnet_tiny_lrf": "",
543
+ "focalnet_small_lrf": "",
544
+ "focalnet_base_lrf": "",
545
+ }
546
+
547
+ @register_model
548
+ def focalnet_tiny_srf(pretrained=False, **kwargs):
549
+ model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
550
+ if pretrained:
551
+ url = model_urls['focalnet_tiny_srf']
552
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
553
+ model.load_state_dict(checkpoint["model"])
554
+ return model
555
+
556
+ @register_model
557
+ def focalnet_small_srf(pretrained=False, **kwargs):
558
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
559
+ if pretrained:
560
+ url = model_urls['focalnet_small_srf']
561
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
562
+ model.load_state_dict(checkpoint["model"])
563
+ return model
564
+
565
+ @register_model
566
+ def focalnet_base_srf(pretrained=False, **kwargs):
567
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
568
+ if pretrained:
569
+ url = model_urls['focalnet_base_srf']
570
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
571
+ model.load_state_dict(checkpoint["model"])
572
+ return model
573
+
574
+ @register_model
575
+ def focalnet_tiny_lrf(pretrained=False, **kwargs):
576
+ model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
577
+ if pretrained:
578
+ url = model_urls['focalnet_tiny_lrf']
579
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
580
+ model.load_state_dict(checkpoint["model"])
581
+ return model
582
+
583
+ @register_model
584
+ def focalnet_small_lrf(pretrained=False, **kwargs):
585
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
586
+ if pretrained:
587
+ url = model_urls['focalnet_small_lrf']
588
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
589
+ model.load_state_dict(checkpoint["model"])
590
+ return model
591
+
592
+ @register_model
593
+ def focalnet_base_lrf(pretrained=False, **kwargs):
594
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
595
+ if pretrained:
596
+ url = model_urls['focalnet_base_lrf']
597
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
598
+ model.load_state_dict(checkpoint["model"])
599
+ return model
600
+
601
+ @register_model
602
+ def focalnet_giant_lrf(pretrained=False, **kwargs):
603
+ model = FocalNet(depths=[2, 2, 42, 2], embed_dim=512, focal_levels=[3, 3, 3, 3], **kwargs)
604
+ if pretrained:
605
+ url = model_urls['focalnet_giant_lrf']
606
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
607
+ model.load_state_dict(checkpoint["model"])
608
+ return model
609
+
610
+ @register_model
611
+ def focalnet_tiny_iso_16(pretrained=False, **kwargs):
612
+ model = FocalNet(depths=[12], patch_size=16, embed_dim=192, focal_levels=[3], focal_windows=[3], **kwargs)
613
+ if pretrained:
614
+ url = model_urls['focalnet_tiny_iso_16']
615
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
616
+ model.load_state_dict(checkpoint["model"])
617
+ return model
618
+
619
+ @register_model
620
+ def focalnet_small_iso_16(pretrained=False, **kwargs):
621
+ model = FocalNet(depths=[12], patch_size=16, embed_dim=384, focal_levels=[3], focal_windows=[3], **kwargs)
622
+ if pretrained:
623
+ url = model_urls['focalnet_small_iso_16']
624
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
625
+ model.load_state_dict(checkpoint["model"])
626
+ return model
627
+
628
+ @register_model
629
+ def focalnet_base_iso_16(pretrained=False, **kwargs):
630
+ model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True, **kwargs)
631
+ if pretrained:
632
+ url = model_urls['focalnet_base_iso_16']
633
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
634
+ model.load_state_dict(checkpoint["model"])
635
+ return model
636
+
637
+ if __name__ == '__main__':
638
+ img_size = 224
639
+ x = torch.rand(16, 3, img_size, img_size).cuda()
640
+ # model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96)
641
+ # model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], focal_factors=[2])
642
+ model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3]).cuda()
643
+ print(model); model(x)
644
+
645
+ flops = model.flops()
646
+ print(f"number of GFLOPs: {flops / 1e9}")
647
+
648
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
649
+ print(f"number of params: {n_parameters}")
model/image_encoder/swin_transformer.py ADDED
@@ -0,0 +1,586 @@