noelshin committed
Commit 35188e4
Parent: 6c45278

Add application file

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. .idea/.gitignore +8 -0
  3. .idea/deployment.xml +15 -0
  4. .idea/inspectionProfiles/Project_Default.xml +23 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +4 -0
  7. .idea/modules.xml +8 -0
  8. .idea/selfmask_demo.iml +8 -0
  9. .idea/sonarlint/issuestore/index.pb +0 -0
  10. .idea/webServers.xml +14 -0
  11. __pycache__/bilateral_solver.cpython-38.pyc +0 -0
  12. __pycache__/utils.cpython-38.pyc +0 -0
  13. app.py +134 -0
  14. bilateral_solver.py +206 -0
  15. duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml +56 -0
  16. networks/__init__.py +0 -0
  17. networks/__pycache__/__init__.cpython-38.pyc +0 -0
  18. networks/__pycache__/timm_deit.cpython-38.pyc +0 -0
  19. networks/__pycache__/timm_vit.cpython-38.pyc +0 -0
  20. networks/__pycache__/vision_transformer.cpython-38.pyc +0 -0
  21. networks/maskformer/__pycache__/maskformer.cpython-38.pyc +0 -0
  22. networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc +0 -0
  23. networks/maskformer/maskformer.py +267 -0
  24. networks/maskformer/positional_embedding.py +48 -0
  25. networks/maskformer/transformer_decoder.py +376 -0
  26. networks/module_helper.py +176 -0
  27. networks/resnet.py +60 -0
  28. networks/resnet_backbone.py +194 -0
  29. networks/resnet_models.py +273 -0
  30. networks/timm_deit.py +254 -0
  31. networks/timm_vit.py +819 -0
  32. networks/vision_transformer.py +569 -0
  33. resources/.DS_Store +0 -0
  34. resources/0053.jpg +0 -0
  35. resources/0236.jpg +0 -0
  36. resources/0239.jpg +0 -0
  37. resources/0403.jpg +0 -0
  38. resources/0412.jpg +0 -0
  39. resources/ILSVRC2012_test_00005309.jpg +0 -0
  40. resources/ILSVRC2012_test_00012622.jpg +0 -0
  41. resources/ILSVRC2012_test_00022698.jpg +0 -0
  42. resources/ILSVRC2012_test_00040725.jpg +0 -0
  43. resources/ILSVRC2012_test_00075738.jpg +0 -0
  44. resources/ILSVRC2012_test_00080683.jpg +0 -0
  45. resources/ILSVRC2012_test_00085874.jpg +0 -0
  46. resources/im052.jpg +0 -0
  47. resources/sun_ainjbonxmervsvpv.jpg +0 -0
  48. resources/sun_alfntqzssslakmss.jpg +0 -0
  49. resources/sun_amnrcxhisjfrliwa.jpg +0 -0
  50. resources/sun_bvyxpvkouzlfwwod.jpg +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/deployment.xml ADDED
@@ -0,0 +1,15 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PublishConfigData" autoUpload="Always" serverName="mydev" remoteFilesAllowedToDisappearOnAutoupload="false">
+     <serverData>
+       <paths name="mydev">
+         <serverdata>
+           <mappings>
+             <mapping deploy="/" local="$PROJECT_DIR$" web="/" />
+           </mappings>
+         </serverdata>
+       </paths>
+     </serverData>
+     <option name="myAutoUpload" value="ALWAYS" />
+   </component>
+ </project>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,23 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredPackages">
+         <value>
+           <list size="10">
+             <item index="0" class="java.lang.String" itemvalue="prettytable" />
+             <item index="1" class="java.lang.String" itemvalue="interrogate" />
+             <item index="2" class="java.lang.String" itemvalue="pytest" />
+             <item index="3" class="java.lang.String" itemvalue="yapf" />
+             <item index="4" class="java.lang.String" itemvalue="cityscapesscripts" />
+             <item index="5" class="java.lang.String" itemvalue="Wand" />
+             <item index="6" class="java.lang.String" itemvalue="isort" />
+             <item index="7" class="java.lang.String" itemvalue="xdoctest" />
+             <item index="8" class="java.lang.String" itemvalue="codecov" />
+             <item index="9" class="java.lang.String" itemvalue="flake8" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch)" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/selfmask_demo.iml" filepath="$PROJECT_DIR$/.idea/selfmask_demo.iml" />
+     </modules>
+   </component>
+ </project>
.idea/selfmask_demo.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/sonarlint/issuestore/index.pb ADDED
File without changes
.idea/webServers.xml ADDED
@@ -0,0 +1,14 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="WebServers">
+     <option name="servers">
+       <webServer id="12e2cf4d-3b81-4241-9665-54a333f70567" name="mydev">
+         <fileTransfer rootFolder="/users/gyungin/selfmask_demo" accessType="SFTP" host="mydev" port="22" sshConfigId="3e23a652-ab3c-4dc2-a117-84c2bf217891" sshConfig="gyungin@mydev:22 password">
+           <advancedOptions>
+             <advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
+           </advancedOptions>
+         </fileTransfer>
+       </webServer>
+     </option>
+   </component>
+ </project>
__pycache__/bilateral_solver.cpython-38.pyc ADDED
Binary file (6.76 kB).
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.9 kB).
 
app.py ADDED
@@ -0,0 +1,134 @@
+ from argparse import ArgumentParser, Namespace
+ from typing import Dict, List, Tuple
+ import yaml
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import torch
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import to_tensor, normalize, resize
+ import gradio as gr
+ from utils import get_model
+ from bilateral_solver import bilateral_solver_output
+ import os
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ state_dict: dict = torch.hub.load_state_dict_from_url(
+     "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
+     map_location=device  # "cuda" if torch.cuda.is_available() else "cpu"
+ )["model"]
+
+ parser = ArgumentParser("SelfMask demo")
+ parser.add_argument(
+     "--config",
+     type=str,
+     default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
+ )
+
+ # parser.add_argument(
+ #     "--p_state_dict",
+ #     type=str,
+ #     default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt",
+ # )
+ #
+ # parser.add_argument(
+ #     "--dataset_name", '-dn', type=str, default="duts",
+ #     choices=["dut_omron", "duts", "ecssd"]
+ # )
+
+ # independent variables
+ # parser.add_argument("--use_gpu", type=bool, default=True)
+ # parser.add_argument('--seed', default=0, type=int)
+ # parser.add_argument("--dir_root", type=str, default="..")
+ # parser.add_argument("--gpu_id", type=int, default=2)
+ # parser.add_argument("--suffix", type=str, default='')
+ args: Namespace = parser.parse_args()
+ base_args = yaml.safe_load(open(f"{args.config}", 'r'))
+ base_args.pop("dataset_name")
+ args: dict = vars(args)
+ args.update(base_args)
+ args: Namespace = Namespace(**args)
+
+ model = get_model(arch="maskformer", configs=args).to(device)
+ model.load_state_dict(state_dict)
+ model.eval()
+
+
+ @torch.no_grad()
+ def main(
+     image: Image.Image,
+     size: int = 384,
+     max_size: int = 512,
+     mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
+     std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
+ ):
+     pil_image: Image.Image = resize(image, size=size, max_size=max_size)
+     image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
+     dict_outputs = model(image[None].to(device))
+
+     batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
+     batch_objectness: torch.Tensor = dict_outputs.get("objectness", None)  # [0, 1]
+
+     if len(batch_pred_masks.shape) == 5:
+         # b x n_layers x n_queries x h x w -> b x n_queries x h x w
+         batch_pred_masks = batch_pred_masks[:, -1, ...]  # extract the output from the last decoder layer
+
+         if batch_objectness is not None:
+             # b x n_layers x n_queries x 1 -> b x n_queries x 1
+             batch_objectness = batch_objectness[:, -1, ...]
+
+     # resize prediction to original resolution
+     # note: upsampling by 4 and cutting the padded region allows for a better result
+     H, W = image.shape[-2:]
+     batch_pred_masks = F.interpolate(
+         batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
+     )[..., :H, :W]
+
+     # iterate over batch dimension
+     for batch_index, pred_masks in enumerate(batch_pred_masks):
+         # n_queries x 1 -> n_queries
+         objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
+         ranks = torch.argsort(objectness, descending=True)  # n_queries
+         pred_mask: torch.Tensor = pred_masks[ranks[0]]  # H x W
+         pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255
+
+         pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
+         pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)
+
+         attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
+         super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
+         return super_imposed_img
+         # return pred_mask_bi
+
+ demo = gr.Interface(
+     fn=main,
+     inputs=gr.inputs.Image(type="pil"),
+     outputs="image",
+     examples=[f"resources/{fname}.jpg" for fname in [
+         "0053",
+         "0236",
+         "0239",
+         "0403",
+         "0412",
+         "ILSVRC2012_test_00005309",
+         "ILSVRC2012_test_00012622",
+         "ILSVRC2012_test_00022698",
+         "ILSVRC2012_test_00040725",
+         "ILSVRC2012_test_00075738",
+         "ILSVRC2012_test_00080683",
+         "ILSVRC2012_test_00085874",
+         "im052",
+         "sun_ainjbonxmervsvpv",
+         "sun_alfntqzssslakmss",
+         "sun_amnrcxhisjfrliwa",
+         "sun_bvyxpvkouzlfwwod"
+     ]],
+     title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
+     allow_flagging="never",
+     analytics_enabled=False
+ )
+
+ demo.launch(
+     # share=True
+ )
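Usage note (not part of the commit): the preprocessing that main() applies before inference can be reproduced in isolation. This is a minimal sketch, assuming torchvision and Pillow are installed; "example.jpg" is a placeholder path, not a file in this commit.

from PIL import Image
from torchvision.transforms.functional import to_tensor, normalize, resize

pil_image = Image.open("example.jpg").convert("RGB")       # placeholder input image
pil_image = resize(pil_image, size=384, max_size=512)      # short side 384, long side capped at 512
tensor = normalize(
    to_tensor(pil_image),                                  # 3 x H x W, values in [0, 1]
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]  # ImageNet statistics, as in main()
)
batch = tensor[None]                                       # 1 x 3 x H x W, the shape the model receives
print(batch.shape)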
bilateral_solver.py ADDED
@@ -0,0 +1,206 @@
+ from scipy.sparse import diags
+ from scipy.sparse.linalg import cg
+ from scipy.sparse import csr_matrix
+ import numpy as np
+ from skimage.io import imread
+ from scipy import ndimage
+ import torch
+ import PIL.Image as Image
+ import os
+ from argparse import ArgumentParser, Namespace
+ from typing import Dict, Union
+ from collections import defaultdict
+ import yaml
+ import ujson as json
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+
+
+ RGB_TO_YUV = np.array([
+     [0.299, 0.587, 0.114],
+     [-0.168736, -0.331264, 0.5],
+     [0.5, -0.418688, -0.081312]])
+ YUV_TO_RGB = np.array([
+     [1.0, 0.0, 1.402],
+     [1.0, -0.34414, -0.71414],
+     [1.0, 1.772, 0.0]])
+ YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
+ MAX_VAL = 255.0
+
+
+ def rgb2yuv(im):
+     return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET
+
+
+ def yuv2rgb(im):
+     return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))
+
+
+ def get_valid_idx(valid, candidates):
+     """Find which values are present in a list and where they are located"""
+     locs = np.searchsorted(valid, candidates)
+     # Handle edge case where the candidate is larger than all valid values
+     locs = np.clip(locs, 0, len(valid) - 1)
+     # Identify which values are actually present
+     valid_idx = np.flatnonzero(valid[locs] == candidates)
+     locs = locs[valid_idx]
+     return valid_idx, locs
+
+
+ class BilateralGrid(object):
+     def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
+         im_yuv = rgb2yuv(im)
+         # Compute 5-dimensional XYLUV bilateral-space coordinates
+         Iy, Ix = np.mgrid[:im.shape[0], :im.shape[1]]
+         x_coords = (Ix / sigma_spatial).astype(int)
+         y_coords = (Iy / sigma_spatial).astype(int)
+         luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
+         chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
+         coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
+         coords_flat = coords.reshape(-1, coords.shape[-1])
+         self.npixels, self.dim = coords_flat.shape
+         # Hacky "hash vector" for coordinates,
+         # Requires all scaled coordinates be < MAX_VAL
+         self.hash_vec = (MAX_VAL ** np.arange(self.dim))
+         # Construct S and B matrix
+         self._compute_factorization(coords_flat)
+
+     def _compute_factorization(self, coords_flat):
+         # Hash each coordinate in grid to a unique value
+         hashed_coords = self._hash_coords(coords_flat)
+         unique_hashes, unique_idx, idx = \
+             np.unique(hashed_coords, return_index=True, return_inverse=True)
+         # Identify unique set of vertices
+         unique_coords = coords_flat[unique_idx]
+         self.nvertices = len(unique_coords)
+         # Construct sparse splat matrix that maps from pixels to vertices
+         self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
+         # Construct sparse blur matrices.
+         # Note that these represent [1 0 1] blurs, excluding the central element
+         self.blurs = []
+         for d in range(self.dim):
+             blur = 0.0
+             for offset in (-1, 1):
+                 offset_vec = np.zeros((1, self.dim))
+                 offset_vec[:, d] = offset
+                 neighbor_hash = self._hash_coords(unique_coords + offset_vec)
+                 valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
+                 blur = blur + csr_matrix((np.ones((len(valid_coord),)),
+                                           (valid_coord, idx)),
+                                          shape=(self.nvertices, self.nvertices))
+             self.blurs.append(blur)
+
+     def _hash_coords(self, coord):
+         """Hacky function to turn a coordinate into a unique value"""
+         return np.dot(coord.reshape(-1, self.dim), self.hash_vec)
+
+     def splat(self, x):
+         return self.S.dot(x)
+
+     def slice(self, y):
+         return self.S.T.dot(y)
+
+     def blur(self, x):
+         """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
+         assert x.shape[0] == self.nvertices
+         out = 2 * self.dim * x
+         for blur in self.blurs:
+             out = out + blur.dot(x)
+         return out
+
+     def filter(self, x):
+         """Apply bilateral filter to an input x"""
+         return self.slice(self.blur(self.splat(x))) / \
+                self.slice(self.blur(self.splat(np.ones_like(x))))
+
+
+ def bistochastize(grid, maxiter=10):
+     """Compute diagonal matrices to bistochastize a bilateral grid"""
+     m = grid.splat(np.ones(grid.npixels))
+     n = np.ones(grid.nvertices)
+     for i in range(maxiter):
+         n = np.sqrt(n * m / grid.blur(n))
+     # Correct m to satisfy the assumption of bistochastization regardless
+     # of how many iterations have been run.
+     m = n * grid.blur(n)
+     Dm = diags(m, 0)
+     Dn = diags(n, 0)
+     return Dn, Dm
+
+
+ class BilateralSolver(object):
+     def __init__(self, grid, params):
+         self.grid = grid
+         self.params = params
+         self.Dn, self.Dm = bistochastize(grid)
+
+     def solve(self, x, w):
+         # Check that w is a vector or a nx1 matrix
+         if w.ndim == 2:
+             assert (w.shape[1] == 1)
+         elif w.dim == 1:
+             w = w.reshape(w.shape[0], 1)
+         A_smooth = (self.Dm - self.Dn.dot(self.grid.blur(self.Dn)))
+         w_splat = self.grid.splat(w)
+         A_data = diags(w_splat[:, 0], 0)
+         A = self.params["lam"] * A_smooth + A_data
+         xw = x * w
+         b = self.grid.splat(xw)
+         # Use simple Jacobi preconditioner
+         A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
+         M = diags(1 / A_diag, 0)
+         # Flat initialization
+         y0 = self.grid.splat(xw) / w_splat
+         yhat = np.empty_like(y0)
+         for d in range(x.shape[-1]):
+             yhat[..., d], info = cg(A, b[..., d], x0=y0[..., d], M=M, maxiter=self.params["cg_maxiter"],
+                                     tol=self.params["cg_tol"])
+         xhat = self.grid.slice(yhat)
+         return xhat
+
+
+ def bilateral_solver_output(
+     img: Image.Image,
+     target: np.ndarray,
+     sigma_spatial=16,
+     sigma_luma=16,
+     sigma_chroma=8
+ ):
+     reference = np.array(img)
+     h, w = target.shape
+     confidence = np.ones((h, w)) * 0.999
+
+     grid_params = {
+         'sigma_luma': sigma_luma,  # Brightness bandwidth
+         'sigma_chroma': sigma_chroma,  # Color bandwidth
+         'sigma_spatial': sigma_spatial  # Spatial bandwidth
+     }
+
+     bs_params = {
+         'lam': 256,  # The strength of the smoothness parameter
+         'A_diag_min': 1e-5,  # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
+         'cg_tol': 1e-5,  # The tolerance on the convergence in PCG
+         'cg_maxiter': 25  # The number of PCG iterations
+     }
+
+     grid = BilateralGrid(reference, **grid_params)
+
+     t = target.reshape(-1, 1).astype(np.double)
+     c = confidence.reshape(-1, 1).astype(np.double)
+
+     ## output solver, which is a soft value
+     output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
+
+     binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
+     labeled, nr_objects = ndimage.label(binary_solver)
+
+     nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
+     pixel_order = np.argsort(nb_pixel)
+     try:
+         binary_solver = labeled == pixel_order[-2]
+     except:
+         binary_solver = np.ones((h, w), dtype=bool)
+
+     return output_solver, binary_solver
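Usage note (not part of the commit): bilateral_solver_output takes a PIL image and a single-channel mask of the same height and width, and returns a soft (float) refinement plus a filled binary mask. A minimal sketch, assuming the module's dependencies (NumPy, SciPy, scikit-image, ujson, PyTorch, Pillow, PyYAML) are installed and "example.jpg" is a placeholder image:

import numpy as np
from PIL import Image
from bilateral_solver import bilateral_solver_output

img = Image.open("example.jpg").convert("RGB")
coarse = (np.array(img.convert("L")) > 128).astype(np.uint8) * 255  # stand-in for a predicted mask
soft, binary = bilateral_solver_output(img=img, target=coarse)
print(soft.shape, binary.dtype)  # (H, W) float64 array and (H, W) boolean array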
duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml ADDED
@@ -0,0 +1,56 @@
+ # augmentations
+ use_copy_paste: false
+ scale_range: [ 0.1, 1.0 ]
+ repeat_image: false
+
+ # base directories
+ dir_ckpt: "/users/gyungin/selfmask/ckpt"  # "/work/gyungin/selfmask/ckpt"
+ dir_dataset: "/scratch/shared/beegfs/gyungin/datasets"
+
+ # clustering
+ k: [2, 3, 4]
+ clustering_mode: "spectral"
+ use_gpu: true  # if you want to use gpu-accelerated code for clustering
+ scale_factor: 2  # "how much you want to upsample encoder features before clustering"
+
+ # dataset
+ dataset_name: "duts"
+ use_pseudo_masks: true
+ train_image_size: 224
+ eval_image_size: 224
+ n_percent: 100
+ n_copy_pastes: null
+ pseudo_masks_fp: "/users/gyungin/selfmask/datasets/swav_mocov2_dino_p16_k234.json"
+
+ # dataloader:
+ batch_size: 8
+ num_workers: 4
+ pin_memory: true
+
+ # networks
+ abs_2d_pe_init: false
+ arch: "vit_small"
+ lateral_connection: false
+ learnable_pixel_decoder: false  # if False, use the bilinear interpolation
+ use_binary_classifier: true  # if True, use a binary classifier to get an objectness for each query from transformer decoder
+ n_decoder_layers: 6
+ n_queries: 20
+ num_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+ patch_size: 8
+ training_method: "dino"  # "supervised", "deit", "dino", "mocov2", "swav"
+
+ # objective
+ loss_every_decoder_layer: true
+ weight_dice_loss: 1.0
+ weight_focal_loss: 0.0
+
+ # optimizer
+ lr: 0.000006  # default: 0.00006
+ lr_warmup_duration: 0  # 5
+ momentum: 0.9
+ n_epochs: 12
+ weight_decay: 0.01
+ optimizer_type: "adamw"
+
+ # validation
+ benchmarks: null
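Usage note (not part of the commit): app.py above consumes this file by loading it with PyYAML and merging the result into the argparse namespace. A minimal sketch of that load step, assuming the YAML sits in the working directory:

import yaml

with open("duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml") as f:
    cfg: dict = yaml.safe_load(f)  # plain dict of the options listed above

print(cfg["arch"], cfg["n_queries"], cfg["patch_size"])  # vit_small 20 8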
networks/__init__.py ADDED
File without changes
networks/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (146 Bytes).
 
networks/__pycache__/timm_deit.cpython-38.pyc ADDED
Binary file (7.08 kB).
 
networks/__pycache__/timm_vit.cpython-38.pyc ADDED
Binary file (27.7 kB).
 
networks/__pycache__/vision_transformer.cpython-38.pyc ADDED
Binary file (15.8 kB).
 
networks/maskformer/__pycache__/maskformer.cpython-38.pyc ADDED
Binary file (8.51 kB).
 
networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc ADDED
Binary file (8.83 kB).
 
networks/maskformer/maskformer.py ADDED
@@ -0,0 +1,267 @@
+ from typing import Dict, List
+ from math import sqrt, log
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
+ from utils import get_model
+
+
+ class MaskFormer(nn.Module):
+     def __init__(
+         self,
+         n_queries: int = 100,
+         arch: str = "vit_small",
+         patch_size: int = 8,
+         training_method: str = "dino",
+         n_decoder_layers: int = 6,
+         normalize_before: bool = False,
+         return_intermediate: bool = False,
+         learnable_pixel_decoder: bool = False,
+         lateral_connection: bool = False,
+         scale_factor: int = 2,
+         abs_2d_pe_init: bool = False,
+         use_binary_classifier: bool = False
+     ):
+         """Define a encoder and decoder along with queries to be learned through the decoder."""
+         super(MaskFormer, self).__init__()
+
+         if arch == "vit_small":
+             self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
+             n_dims: int = self.encoder.n_embs
+             n_heads: int = self.encoder.n_heads
+             mlp_ratio: int = self.encoder.mlp_ratio
+         else:
+             self.encoder = get_model(arch=arch, training_method=training_method)
+             n_dims_resnet: int = self.encoder.n_embs
+             n_dims: int = 384
+             n_heads: int = 6
+             mlp_ratio: int = 4
+             self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)
+
+         decoder_layer = TransformerDecoderLayer(
+             n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
+         )
+         self.decoder = TransformerDecoder(
+             decoder_layer,
+             n_decoder_layers,
+             norm=nn.LayerNorm(n_dims),
+             return_intermediate=return_intermediate
+         )
+
+         self.query_embed = nn.Embedding(n_queries, n_dims).weight  # initialized with gaussian(0, 1)
+
+         if use_binary_classifier:
+             # self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
+             # self.linear_classifier = nn.Linear(n_dims, 1)
+             self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
+             # self.norm = nn.LayerNorm(n_dims)
+         else:
+             # self.ffn = None
+             # self.linear_classifier = None
+             # self.norm = None
+             self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
+             self.linear_classifier = nn.Linear(n_dims, 2)
+             self.norm = nn.LayerNorm(n_dims)
+
+         self.arch = arch
+         self.use_binary_classifier = use_binary_classifier
+         self.lateral_connection = lateral_connection
+         self.learnable_pixel_decoder = learnable_pixel_decoder
+         self.scale_factor = scale_factor
+
+     # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
+     @staticmethod
+     def positional_encoding_2d(n_dims: int, height: int, width: int):
+         """
+         :param n_dims: dimension of the model
+         :param height: height of the positions
+         :param width: width of the positions
+         :return: d_model*height*width position matrix
+         """
+         if n_dims % 4 != 0:
+             raise ValueError("Cannot use sin/cos positional encoding with "
+                              "odd dimension (got dim={:d})".format(n_dims))
+         pe = torch.zeros(n_dims, height, width)
+         # Each dimension use half of d_model
+         d_model = int(n_dims / 2)
+         div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
+         pos_w = torch.arange(0., width).unsqueeze(1)
+         pos_h = torch.arange(0., height).unsqueeze(1)
+         pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+         pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+         pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+         pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+
+         return pe
+
+     def forward_encoder(self, x: torch.Tensor):
+         """
+         :param x: b x c x h x w
+         :return patch_tokens: b x depth x hw x n_dims
+         """
+         if self.arch == "vit_small":
+             encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x)  # [:, 1:, :]
+             all_patch_tokens: List[torch.Tensor] = list()
+             for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
+                 patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :]  # b x hw x n_dims
+                 all_patch_tokens.append(patch_tokens)
+
+             all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0)  # depth x b x hw x n_dims
+             all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2)  # b x depth x n_dims x hw
+             return all_patch_tokens
+         else:
+             encoder_outputs = self.linear_layer(self.encoder(x)[-1])  # b x n_dims x h x w
+             return encoder_outputs
+
+     def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
+         """Forward transformer decoder given patch tokens from the encoder's last layer.
+         :param patch_tokens: b x n_dims x hw -> hw x b x n_dims
+         :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
+         between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
+         experiment.
+         :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+         """
+         b = patch_tokens.shape[0]
+         patch_tokens = patch_tokens.permute(2, 0, 1)  # b x n_dims x hw -> hw x b x n_dims
+
+         # n_queries x n_dims -> n_queries x b x n_dims
+         queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)
+         queries: torch.Tensor = self.decoder.forward(
+             tgt=torch.zeros_like(queries),
+             memory=patch_tokens,
+             query_pos=queries
+         ).squeeze(dim=0)
+
+         if len(queries.shape) == 3:
+             queries: torch.Tensor = queries.permute(1, 0, 2)  # n_queries x b x n_dims -> b x n_queries x n_dims
+         elif len(queries.shape) == 4:
+             # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
+             queries: torch.Tensor = queries.permute(2, 0, 1, 3)
+         return queries
+
+     def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
+         """ Upsample patch tokens by self.scale_factor and produce mask predictions
+         :param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w
+         :param queries: b x n_queries x n_dims
+         :return mask_predictions: b x n_queries x h x w
+         """
+
+         if input_size is None:
+             # assume square shape features
+             hw = patch_tokens.shape[-1]
+             h = w = int(sqrt(hw))
+         else:
+             # arbitrary shape features
+             h, w = input_size
+         patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)
+
+         assert len(patch_tokens.shape) == 4
+         patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
+         return patch_tokens
+
+     def forward(self, x, encoder_only=False, skip_decoder: bool = False):
+         """
+         x: b x c x h x w
+         patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims
+         query_emb: n_queries x n_dims -> n_queries x b x n_dims
+         """
+         dict_outputs: dict = dict()
+
+         # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
+         features: torch.Tensor = self.forward_encoder(x)
+
+         if self.arch == "vit_small":
+             # extract the last layer for decoder input
+             last_layer_features: torch.Tensor = features[:, -1, ...]  # b x n_dims x hw
+         else:
+             # transform the shape of the features to the one compatible with transformer decoder
+             b, n_dims, h, w = features.shape
+             last_layer_features: torch.Tensor = features.view(b, n_dims, h * w)  # b x n_dims x hw
+
+         if encoder_only:
+             _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
+             _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
+
+             b, n_dims, hw = last_layer_features.shape
+             dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)})
+             return dict_outputs
+
+         # transformer decoder forward
+         queries: torch.Tensor = self.forward_transformer_decoder(
+             last_layer_features,
+             skip_decoder=skip_decoder
+         )  # b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+
+         # pixel decoder forward (upsampling the patch tokens by self.scale_factor)
+         if self.arch == "vit_small":
+             _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
+             _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
+         else:
+             _h, _w = h, w
+         features: torch.Tensor = self.forward_pixel_decoder(
+             patch_tokens=features if self.lateral_connection else last_layer_features,
+             input_size=(_h, _w)
+         )  # b x n_dims x h x w
+
+         # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+         # features: b x n_dims x h x w
+         # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
+         if len(queries.shape) == 3:
+             mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
+         else:
+             if self.use_binary_classifier:
+                 mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
+             else:
+                 mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))
+
+         if self.use_binary_classifier:
+             # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
+             queries = queries.permute(1, 0, 2, 3)
+             objectness: List[torch.Tensor] = list()
+             for n_layer, queries_per_layer in enumerate(queries):  # queries_per_layer: b x n_queries x n_dims
+                 # objectness_per_layer = self.linear_classifier(
+                 #     self.ffn(self.norm(queries_per_layer))
+                 # )  # b x n_queries x 1
+                 objectness_per_layer = self.ffn(queries_per_layer)  # b x n_queries x 1
+                 objectness.append(objectness_per_layer)
+             # n_layers x b x n_queries x 1 -> # b x n_layers x n_queries x 1
+             objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
+             dict_outputs.update({
+                 "objectness": torch.sigmoid(objectness),
+                 "mask_pred": mask_pred
+             })
+
+         return dict_outputs
+
+
+ class MLP(nn.Module):
+     """Very simple multi-layer perceptron (also called FFN)"""
+
+     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(
+             nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+         )
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x
+
+
+ class UpsampleBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
+         super(UpsampleBlock, self).__init__()
+         self.block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
+             nn.GroupNorm(n_groups, out_channels),
+             nn.ReLU()
+         )
+         self.scale_factor = scale_factor
+
+     def forward(self, x):
+         return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")
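Usage note (not part of the commit): positional_encoding_2d is a @staticmethod, so it can be exercised without constructing the full model. A minimal sketch, assuming the repository (including the utils module that maskformer.py imports) and PyTorch are available:

from networks.maskformer.maskformer import MaskFormer

pe = MaskFormer.positional_encoding_2d(n_dims=384, height=28, width=28)
print(pe.shape)  # torch.Size([384, 28, 28]): sine/cosine encodings over a 28 x 28 token grid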
networks/maskformer/positional_embedding.py ADDED
@@ -0,0 +1,48 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+
+ import torch
+ from torch import nn
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, x, mask=None):
+         if mask is None:
+             mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
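Usage note (not part of the commit): a minimal sketch of PositionEmbeddingSine on a dummy feature map (only PyTorch is needed); num_pos_feats=192 yields a 384-channel encoding matching the ViT-S embedding width used elsewhere in this repository:

import torch
from networks.maskformer.positional_embedding import PositionEmbeddingSine

pos_enc = PositionEmbeddingSine(num_pos_feats=192, normalize=True)
feats = torch.randn(2, 384, 28, 28)  # b x c x h x w dummy features
pos = pos_enc(feats)                 # b x (2 * num_pos_feats) x h x w
print(pos.shape)                     # torch.Size([2, 384, 28, 28])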
networks/maskformer/transformer_decoder.py ADDED
@@ -0,0 +1,376 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+ """
+ Transformer class.
+ Copy-paste from torch.nn.Transformer with modifications:
+     * positional encodings are passed in MHattention
+     * extra LN at the end of encoder is removed
+     * decoder returns a stack of activations from all decoding layers
+ """
+ import copy
+ from typing import List, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         d_model=512,
+         nhead=8,
+         num_encoder_layers=6,
+         num_decoder_layers=6,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",  # noel - dino used GeLU
+         normalize_before=False,
+         return_intermediate_dec=False,
+     ):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+         )
+         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+         self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+         decoder_layer = TransformerDecoderLayer(
+             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+         )
+         decoder_norm = nn.LayerNorm(d_model)
+         self.decoder = TransformerDecoder(
+             decoder_layer,
+             num_decoder_layers,
+             decoder_norm,
+             return_intermediate=return_intermediate_dec,
+         )
+
+         self._reset_parameters()
+
+         self.d_model = d_model
+         self.nhead = nhead
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src, mask, query_embed, pos_embed):
+         # flatten NxCxHxW to HWxNxC
+         bs, c, h, w = src.shape
+         src = src.flatten(2).permute(2, 0, 1)
+         pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+         query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+         if mask is not None:
+             mask = mask.flatten(1)
+
+         tgt = torch.zeros_like(query_embed)
+         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+         hs = self.decoder(
+             tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+         )
+         return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(
+         self,
+         src,
+         mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         output = src
+
+         for layer in self.layers:
+             output = layer(
+                 output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+             )
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+         super().__init__()
+         self.layers: nn.ModuleList = _get_clones(decoder_layer, num_layers)
+         self.num_layers: int = num_layers
+         self.norm = norm
+         self.return_intermediate: bool = return_intermediate
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         output = tgt
+
+         intermediate = []
+
+         for layer in self.layers:
+             output = layer(
+                 output,
+                 memory,
+                 tgt_mask=tgt_mask,
+                 memory_mask=memory_mask,
+                 tgt_key_padding_mask=tgt_key_padding_mask,
+                 memory_key_padding_mask=memory_key_padding_mask,
+                 pos=pos,
+                 query_pos=query_pos,
+             )
+             if self.return_intermediate:
+                 intermediate.append(self.norm(output))
+
+         if self.norm is not None:
+             output = self.norm(output)
+             if self.return_intermediate:
+                 intermediate.pop()
+                 intermediate.append(output)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate)
+
+         return output.unsqueeze(0)
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+     ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         q = k = self.with_pos_embed(src, pos)
+         src2 = self.self_attn(
+             q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src
+
+     def forward_pre(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         src2 = self.norm1(src)
+         q = k = self.with_pos_embed(src2, pos)
+         src2 = self.self_attn(
+             q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src2 = self.norm2(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+         src = src + self.dropout2(src2)
+         return src
+
+     def forward(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+         return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+ class TransformerDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+     ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         q = k = self.with_pos_embed(tgt, query_pos)
+
+         tgt2 = self.self_attn(
+             q,
+             k,
+             value=tgt,
+             attn_mask=tgt_mask,
+             key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+
+         return tgt
+
+     def forward_pre(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt2 = self.norm2(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt2, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(
+                 tgt,
+                 memory,
+                 tgt_mask,
+                 memory_mask,
+                 tgt_key_padding_mask,
+                 memory_key_padding_mask,
+                 pos,
+                 query_pos,
+             )
+         return self.forward_post(
+             tgt,
+             memory,
+             tgt_mask,
+             memory_mask,
+             tgt_key_padding_mask,
+             memory_key_padding_mask,
+             pos,
+             query_pos,
+         )
+
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
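Usage note (not part of the commit): a minimal sketch of the decoder stack built with the same settings MaskFormer above uses (384-dim tokens, 6 heads, 6 layers, per-layer outputs). The tensors are dummies standing in for the learned queries and the encoder's patch tokens:

import torch
import torch.nn as nn
from networks.maskformer.transformer_decoder import TransformerDecoder, TransformerDecoderLayer

n_dims, n_heads, n_queries, hw, b = 384, 6, 20, 784, 2
layer = TransformerDecoderLayer(n_dims, n_heads, n_dims * 4, 0., activation="relu", normalize_before=False)
decoder = TransformerDecoder(layer, 6, norm=nn.LayerNorm(n_dims), return_intermediate=True)

query_pos = torch.randn(n_queries, b, n_dims)  # query embeddings (n_queries x b x n_dims)
memory = torch.randn(hw, b, n_dims)            # encoder patch tokens (hw x b x n_dims)
out = decoder(tgt=torch.zeros_like(query_pos), memory=memory, query_pos=query_pos)
print(out.shape)  # torch.Size([6, 20, 2, 384]): n_layers x n_queries x b x n_dims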
networks/module_helper.py ADDED
@@ -0,0 +1,176 @@
+ #!/usr/bin/env python
+ # -*- coding:utf-8 -*-
+ # Author: Donny You (youansheng@gmail.com)
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     from urllib import urlretrieve
+ except ImportError:
+     from urllib.request import urlretrieve
+
+
+ class FixedBatchNorm(nn.BatchNorm2d):
+     def forward(self, input):
+         return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)
+
+
+ class ModuleHelper(object):
+     @staticmethod
+     def BNReLU(num_features, norm_type=None, **kwargs):
+         if norm_type == 'batchnorm':
+             return nn.Sequential(
+                 nn.BatchNorm2d(num_features, **kwargs),
+                 nn.ReLU()
+             )
+         elif norm_type == 'encsync_batchnorm':
+             from encoding.nn import BatchNorm2d
+             return nn.Sequential(
+                 BatchNorm2d(num_features, **kwargs),
+                 nn.ReLU()
+             )
+         elif norm_type == 'instancenorm':
+             return nn.Sequential(
+                 nn.InstanceNorm2d(num_features, **kwargs),
+                 nn.ReLU()
+             )
+         elif norm_type == 'fixed_batchnorm':
+             return nn.Sequential(
+                 FixedBatchNorm(num_features, **kwargs),
+                 nn.ReLU()
+             )
+         else:
+             raise ValueError('Not support BN type: {}.'.format(norm_type))
+
+     @staticmethod
+     def BatchNorm3d(norm_type=None, ret_cls=False):
+         if norm_type == 'batchnorm':
+             return nn.BatchNorm3d
+         elif norm_type == 'encsync_batchnorm':
+             from encoding.nn import BatchNorm3d
+             return BatchNorm3d
+         elif norm_type == 'instancenorm':
+             return nn.InstanceNorm3d
+         else:
+             raise ValueError('Not support BN type: {}.'.format(norm_type))
+
+     @staticmethod
+     def BatchNorm2d(norm_type=None, ret_cls=False):
+         if norm_type == 'batchnorm':
+             return nn.BatchNorm2d
+         elif norm_type == 'encsync_batchnorm':
+             from encoding.nn import BatchNorm2d
+             return BatchNorm2d
+
+         elif norm_type == 'instancenorm':
+             return nn.InstanceNorm2d
+         else:
+             raise ValueError('Not support BN type: {}.'.format(norm_type))
+
+     @staticmethod
+     def BatchNorm1d(norm_type=None, ret_cls=False):
+         if norm_type == 'batchnorm':
+             return nn.BatchNorm1d
+         elif norm_type == 'encsync_batchnorm':
+             from encoding.nn import BatchNorm1d
+             return BatchNorm1d
+         elif norm_type == 'instancenorm':
+             return nn.InstanceNorm1d
+         else:
+             raise ValueError('Not support BN type: {}.'.format(norm_type))
+
+     @staticmethod
+     def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
+         if pretrained is None:
+             return model
+
+         if not os.path.exists(pretrained):
+             pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
+             if os.path.exists(pretrained):
+                 pass
+             else:
+                 raise FileNotFoundError('{} not exists.'.format(pretrained))
+
+         print('Loading pretrained model:{}'.format(pretrained))
+         if all_match:
+             pretrained_dict = torch.load(pretrained, map_location=map_location)
+             model_dict = model.state_dict()
+             load_dict = dict()
+             for k, v in pretrained_dict.items():
+                 if 'prefix.{}'.format(k) in model_dict:
+                     load_dict['prefix.{}'.format(k)] = v
+                 else:
+                     load_dict[k] = v
+             model.load_state_dict(load_dict)
+
+         else:
+             pretrained_dict = torch.load(pretrained)
+             model_dict = model.state_dict()
+             load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+             print('Matched Keys: {}'.format(load_dict.keys()))
+             model_dict.update(load_dict)
+             model.load_state_dict(model_dict)
+
+         return model
+
+     @staticmethod
+     def load_url(url, map_location=None):
+         model_dir = os.path.join('~', '.TorchCV', 'model')
+         if not os.path.exists(model_dir):
+             os.makedirs(model_dir)
+
+         filename = url.split('/')[-1]
+         cached_file = os.path.join(model_dir, filename)
+         if not os.path.exists(cached_file):
+             print('Downloading: "{}" to {}\n'.format(url, cached_file))
+             urlretrieve(url, cached_file)
+
+         print('Loading pretrained model:{}'.format(cached_file))
+         return torch.load(cached_file, map_location=map_location)
+
+     @staticmethod
+     def constant_init(module, val, bias=0):
+         nn.init.constant_(module.weight, val)
+         if hasattr(module, 'bias') and module.bias is not None:
+             nn.init.constant_(module.bias, bias)
+
+     @staticmethod
+     def xavier_init(module, gain=1, bias=0, distribution='normal'):
+         assert distribution in ['uniform', 'normal']
+         if distribution == 'uniform':
+             nn.init.xavier_uniform_(module.weight, gain=gain)
+         else:
+             nn.init.xavier_normal_(module.weight, gain=gain)
+         if hasattr(module, 'bias') and module.bias is not None:
+             nn.init.constant_(module.bias, bias)
+
+     @staticmethod
+     def normal_init(module, mean=0, std=1, bias=0):
+         nn.init.normal_(module.weight, mean, std)
+         if hasattr(module, 'bias') and module.bias is not None:
+             nn.init.constant_(module.bias, bias)
+
+     @staticmethod
+     def uniform_init(module, a=0, b=1, bias=0):
+         nn.init.uniform_(module.weight, a, b)
+         if hasattr(module, 'bias') and module.bias is not None:
+             nn.init.constant_(module.bias, bias)
+
+     @staticmethod
+     def kaiming_init(module,
+                      mode='fan_in',
+                      nonlinearity='leaky_relu',
+                      bias=0,
+                      distribution='normal'):
+         assert distribution in ['uniform', 'normal']
+         if distribution == 'uniform':
+             nn.init.kaiming_uniform_(
+                 module.weight, mode=mode, nonlinearity=nonlinearity)
+         else:
+             nn.init.kaiming_normal_(
+                 module.weight, mode=mode, nonlinearity=nonlinearity)
+         if hasattr(module, 'bias') and module.bias is not None:
+             nn.init.constant_(module.bias, bias)
+
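Usage note (not part of the commit): a minimal sketch of the ModuleHelper.BNReLU factory defined above; the 'batchnorm' branch needs only PyTorch:

import torch
from networks.module_helper import ModuleHelper

block = ModuleHelper.BNReLU(64, norm_type="batchnorm")  # nn.Sequential(BatchNorm2d(64), ReLU())
out = block(torch.randn(2, 64, 32, 32))
print(out.shape)  # torch.Size([2, 64, 32, 32])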
networks/resnet.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from .resnet_backbone import ResNetBackbone
6
+
7
+
8
+ class ResNet50(nn.Module):
9
+ def __init__(
10
+ self,
11
+ weight_type: str = "supervised",
12
+ use_dilated_resnet: bool = True
13
+ ):
14
+ super(ResNet50, self).__init__()
15
+ self.network = ResNetBackbone(backbone=f"resnet50{'_dilated8' if use_dilated_resnet else ''}", pretrained=None)
16
+ self.n_embs = self.network.num_features
17
+ self.use_dilated_resnet = use_dilated_resnet
18
+ self._load_pretrained(weight_type)
19
+
20
+ def _load_pretrained(self, training_method: str) -> None:
21
+ curr_state_dict = self.network.state_dict()
22
+ if training_method == "mocov2":
23
+ state_dict = torch.load("/users/gyungin/sos/networks/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]
24
+
25
+ for k in list(state_dict.keys()):
26
+ if any([k.find(w) != -1 for w in ("fc.0", "fc.2")]):
27
+ state_dict.pop(k)
28
+
29
+ elif training_method == "swav":
30
+ state_dict = torch.load("/users/gyungin/sos/networks/pretrained/swav_800ep_pretrain.pth.tar")
31
+ for k in list(state_dict.keys()):
32
+ if any([k.find(w) != -1 for w in ("projection_head", "prototypes")]):
33
+ state_dict.pop(k)
34
+
35
+ elif training_method == "supervised":
36
+ # Note - pytorch resnet50 model doesn't have num_batches_tracked layers. Need to know why.
37
+ # for k in list(curr_state_dict.keys()):
38
+ # if k.find("num_batches_tracked") != -1:
39
+ # curr_state_dict.pop(k)
40
+ # state_dict = torch.load("../networks/pretrained/resnet50-pytorch.pth")
41
+
42
+ from torchvision.models.resnet import resnet50
43
+ resnet50_supervised = resnet50(True, True)
44
+ state_dict = resnet50_supervised.state_dict()
45
+ for k in list(state_dict.keys()):
46
+ if any([k.find(w) != -1 for w in ("fc.weight", "fc.bias")]):
47
+ state_dict.pop(k)
48
+
49
+ assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
50
+ for k_curr, k in zip(curr_state_dict.keys(), state_dict.keys()):
51
+ curr_state_dict[k_curr].copy_(state_dict[k])
52
+ print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} initialised with {training_method} weights has been loaded.")
53
+ return
54
+
55
+ def forward(self, x):
56
+ return self.network(x)
57
+
58
+
59
+ if __name__ == '__main__':
60
+ resnet = ResNet50("mocov2")
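For reference, a short usage sketch of the wrapper above. With weight_type="supervised" the weights come through torchvision, so no local checkpoint path is needed; the forward pass returns the backbone's list of four feature maps, and with the dilated variant the deepest map is expected to keep an output stride of 8 (the shapes in the comments are an expectation, not a logged result):

import torch
from networks.resnet import ResNet50

model = ResNet50(weight_type="supervised", use_dilated_resnet=True).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # list of layer1..layer4 feature maps
for name, f in zip(("layer1", "layer2", "layer3", "layer4"), feats):
    print(name, tuple(f.shape))  # layer4 -> (1, 2048, 28, 28) at output stride 8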
networks/resnet_backbone.py ADDED
@@ -0,0 +1,194 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Author: Donny You(youansheng@gmail.com)
4
+
5
+
6
+ import torch.nn as nn
7
+ from networks.resnet_models import *
8
+
9
+
10
+ class NormalResnetBackbone(nn.Module):
11
+ def __init__(self, orig_resnet):
12
+ super(NormalResnetBackbone, self).__init__()
13
+
14
+ self.num_features = 2048
15
+ # take pretrained resnet, except AvgPool and FC
16
+ self.prefix = orig_resnet.prefix
17
+ self.maxpool = orig_resnet.maxpool
18
+ self.layer1 = orig_resnet.layer1
19
+ self.layer2 = orig_resnet.layer2
20
+ self.layer3 = orig_resnet.layer3
21
+ self.layer4 = orig_resnet.layer4
22
+
23
+ def get_num_features(self):
24
+ return self.num_features
25
+
26
+ def forward(self, x):
27
+ tuple_features = list()
28
+ x = self.prefix(x)
29
+ x = self.maxpool(x)
30
+ x = self.layer1(x)
31
+ tuple_features.append(x)
32
+ x = self.layer2(x)
33
+ tuple_features.append(x)
34
+ x = self.layer3(x)
35
+ tuple_features.append(x)
36
+ x = self.layer4(x)
37
+ tuple_features.append(x)
38
+
39
+ return tuple_features
40
+
41
+
42
+ class DilatedResnetBackbone(nn.Module):
43
+ def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
44
+ super(DilatedResnetBackbone, self).__init__()
45
+
46
+ self.num_features = 2048
47
+ from functools import partial
48
+
49
+ if dilate_scale == 8:
50
+ orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
51
+ if multi_grid is None:
52
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
53
+ else:
54
+ for i, r in enumerate(multi_grid):
55
+ orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r)))
56
+
57
+ elif dilate_scale == 16:
58
+ if multi_grid is None:
59
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
60
+ else:
61
+ for i, r in enumerate(multi_grid):
62
+ orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r)))
63
+
64
+ # Take pretrained resnet, except AvgPool and FC
65
+ self.prefix = orig_resnet.prefix
66
+ self.maxpool = orig_resnet.maxpool
67
+ self.layer1 = orig_resnet.layer1
68
+ self.layer2 = orig_resnet.layer2
69
+ self.layer3 = orig_resnet.layer3
70
+ self.layer4 = orig_resnet.layer4
71
+
72
+ def _nostride_dilate(self, m, dilate):
73
+ classname = m.__class__.__name__
74
+ if classname.find('Conv') != -1:
75
+ # the convolution with stride
76
+ if m.stride == (2, 2):
77
+ m.stride = (1, 1)
78
+ if m.kernel_size == (3, 3):
79
+ m.dilation = (dilate // 2, dilate // 2)
80
+ m.padding = (dilate // 2, dilate // 2)
81
+ # other convolutions
82
+ else:
83
+ if m.kernel_size == (3, 3):
84
+ m.dilation = (dilate, dilate)
85
+ m.padding = (dilate, dilate)
86
+
87
+ def get_num_features(self):
88
+ return self.num_features
89
+
90
+ def forward(self, x):
91
+ tuple_features = list()
92
+
93
+ x = self.prefix(x)
94
+ x = self.maxpool(x)
95
+
96
+ x = self.layer1(x)
97
+ tuple_features.append(x)
98
+ x = self.layer2(x)
99
+ tuple_features.append(x)
100
+ x = self.layer3(x)
101
+ tuple_features.append(x)
102
+ x = self.layer4(x)
103
+ tuple_features.append(x)
104
+
105
+ return tuple_features
106
+
107
+
108
+ def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
109
+ arch = backbone
110
+
111
+ if arch == 'resnet18':
112
+ orig_resnet = resnet18(pretrained=pretrained)
113
+ arch_net = NormalResnetBackbone(orig_resnet)
114
+ arch_net.num_features = 512
115
+
116
+ elif arch == 'resnet18_dilated8':
117
+ orig_resnet = resnet18(pretrained=pretrained)
118
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
119
+ arch_net.num_features = 512
120
+
121
+ elif arch == 'resnet34':
122
+ orig_resnet = resnet34(pretrained=pretrained)
123
+ arch_net = NormalResnetBackbone(orig_resnet)
124
+ arch_net.num_features = 512
125
+
126
+ elif arch == 'resnet34_dilated8':
127
+ orig_resnet = resnet34(pretrained=pretrained)
128
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
129
+ arch_net.num_features = 512
130
+
131
+ elif arch == 'resnet34_dilated16':
132
+ orig_resnet = resnet34(pretrained=pretrained)
133
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
134
+ arch_net.num_features = 512
135
+
136
+ elif arch == 'resnet50':
137
+ orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
138
+ arch_net = NormalResnetBackbone(orig_resnet)
139
+
140
+ elif arch == 'resnet50_dilated8':
141
+ orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
142
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
143
+
144
+ elif arch == 'resnet50_dilated16':
145
+ orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)  # resnet50() requires width_multiplier, as in the other resnet50 branches
146
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
147
+
148
+ elif arch == 'deepbase_resnet50':
149
+ if pretrained:
150
+ pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
151
+ orig_resnet = deepbase_resnet50(pretrained=pretrained)
152
+ arch_net = NormalResnetBackbone(orig_resnet)
153
+
154
+ elif arch == 'deepbase_resnet50_dilated8':
155
+ if pretrained:
156
+ pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
157
+ # pretrained = "/home/gishin/Projects/DeepLearning/Oxford/cct/models/backbones/pretrained/3x3resnet50-imagenet.pth"
158
+ orig_resnet = deepbase_resnet50(pretrained=pretrained)
159
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
160
+
161
+ elif arch == 'deepbase_resnet50_dilated16':
162
+ orig_resnet = deepbase_resnet50(pretrained=pretrained)
163
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
164
+
165
+ elif arch == 'resnet101':
166
+ orig_resnet = resnet101(pretrained=pretrained)
167
+ arch_net = NormalResnetBackbone(orig_resnet)
168
+
169
+ elif arch == 'resnet101_dilated8':
170
+ orig_resnet = resnet101(pretrained=pretrained)
171
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
172
+
173
+ elif arch == 'resnet101_dilated16':
174
+ orig_resnet = resnet101(pretrained=pretrained)
175
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
176
+
177
+ elif arch == 'deepbase_resnet101':
178
+ orig_resnet = deepbase_resnet101(pretrained=pretrained)
179
+ arch_net = NormalResnetBackbone(orig_resnet)
180
+
181
+ elif arch == 'deepbase_resnet101_dilated8':
182
+ if pretrained:
183
+ pretrained = 'backbones/backbones/pretrained/3x3resnet101-imagenet.pth'
184
+ orig_resnet = deepbase_resnet101(pretrained=pretrained)
185
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
186
+
187
+ elif arch == 'deepbase_resnet101_dilated16':
188
+ orig_resnet = deepbase_resnet101(pretrained=pretrained)
189
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
190
+
191
+ else:
192
+ raise Exception('Architecture undefined!')
193
+
194
+ return arch_net
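The factory above only dispatches on the architecture string; what matters for dense prediction is that the *_dilated8 variants trade stride for dilation in layer3/layer4 and therefore keep a finer feature map. A small sketch of the expected difference (pretrained=None, so no checkpoint is loaded and the weights stay random):

import torch
from networks.resnet_backbone import ResNetBackbone

x = torch.randn(1, 3, 224, 224)
plain = ResNetBackbone(backbone="resnet50", pretrained=None).eval()
dilated = ResNetBackbone(backbone="resnet50_dilated8", pretrained=None).eval()

with torch.no_grad():
    f32 = plain(x)[-1]   # last feature map, overall stride 32
    f8 = dilated(x)[-1]  # layer3/layer4 strides replaced by dilation, stride 8
print(tuple(f32.shape), tuple(f8.shape))  # expected: (1, 2048, 7, 7) (1, 2048, 28, 28)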
networks/resnet_models.py ADDED
@@ -0,0 +1,273 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Author: Donny You(youansheng@gmail.com)
4
+ import math
5
+ import torch.nn as nn
6
+ from collections import OrderedDict
7
+ from .module_helper import ModuleHelper
8
+
9
+
10
+ model_urls = {
11
+ 'resnet18': 'https://download.pytorch.org/backbones/resnet18-5c106cde.pth',
12
+ 'resnet34': 'https://download.pytorch.org/backbones/resnet34-333f7ec4.pth',
13
+ 'resnet50': 'https://download.pytorch.org/backbones/resnet50-19c8e357.pth',
14
+ 'resnet101': 'https://download.pytorch.org/backbones/resnet101-5d3b4d8f.pth',
15
+ 'resnet152': 'https://download.pytorch.org/backbones/resnet152-b121ed2d.pth'
16
+ }
17
+
18
+
19
+ def conv3x3(in_planes, out_planes, stride=1):
20
+ "3x3 convolution with padding"
21
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22
+ padding=1, bias=False)
23
+
24
+
25
+ class BasicBlock(nn.Module):
26
+ expansion = 1
27
+
28
+ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
29
+ super(BasicBlock, self).__init__()
30
+ self.conv1 = conv3x3(inplanes, planes, stride)
31
+ self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
32
+ self.relu = nn.ReLU(inplace=True)
33
+ self.conv2 = conv3x3(planes, planes)
34
+ self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
35
+ self.downsample = downsample
36
+ self.stride = stride
37
+
38
+ def forward(self, x):
39
+ residual = x
40
+
41
+ out = self.conv1(x)
42
+ out = self.bn1(out)
43
+ out = self.relu(out)
44
+
45
+ out = self.conv2(out)
46
+ out = self.bn2(out)
47
+
48
+ if self.downsample is not None:
49
+ residual = self.downsample(x)
50
+
51
+ out += residual
52
+ out = self.relu(out)
53
+
54
+ return out
55
+
56
+
57
+ class Bottleneck(nn.Module):
58
+ expansion = 4
59
+
60
+ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
61
+ super(Bottleneck, self).__init__()
62
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
63
+ self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
64
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
65
+ padding=1, bias=False)
66
+ self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
67
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68
+ self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4)
69
+ self.relu = nn.ReLU(inplace=True)
70
+ self.downsample = downsample
71
+ self.stride = stride
72
+
73
+ def forward(self, x):
74
+ residual = x
75
+
76
+ out = self.conv1(x)
77
+ out = self.bn1(out)
78
+ out = self.relu(out)
79
+
80
+ out = self.conv2(out)
81
+ out = self.bn2(out)
82
+ out = self.relu(out)
83
+
84
+ out = self.conv3(out)
85
+ out = self.bn3(out)
86
+
87
+ if self.downsample is not None:
88
+ residual = self.downsample(x)
89
+
90
+ out += residual
91
+ out = self.relu(out)
92
+
93
+ return out
94
+
95
+
96
+ class ResNet(nn.Module):
97
+ def __init__(self, block, layers, width_multiplier=1.0, num_classes=1000, deep_base=False, norm_type=None):
98
+ super(ResNet, self).__init__()
99
+ self.inplanes = 128 if deep_base else int(64 * width_multiplier)
100
+ self.width_multiplier = width_multiplier
101
+ if deep_base:
102
+ self.prefix = nn.Sequential(OrderedDict([
103
+ ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
104
+ ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
105
+ ('relu1', nn.ReLU(inplace=False)),
106
+ ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
107
+ ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
108
+ ('relu2', nn.ReLU(inplace=False)),
109
+ ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
110
+ ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
111
+ ('relu3', nn.ReLU(inplace=False))]
112
+ ))
113
+ else:
114
+ self.prefix = nn.Sequential(OrderedDict([
115
+ ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
116
+ ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
117
+ ('relu', nn.ReLU(inplace=False))]
118
+ ))
119
+
120
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) # change.
121
+
122
+ self.layer1 = self._make_layer(block, int(64 * width_multiplier), layers[0], norm_type=norm_type)
123
+ self.layer2 = self._make_layer(block, int(128 * width_multiplier), layers[1], stride=2, norm_type=norm_type)
124
+ self.layer3 = self._make_layer(block, int(256 * width_multiplier), layers[2], stride=2, norm_type=norm_type)
125
+ self.layer4 = self._make_layer(block, int(512 * width_multiplier), layers[3], stride=2, norm_type=norm_type)
126
+ self.avgpool = nn.AvgPool2d(7, stride=1)
127
+ self.fc = nn.Linear(int(512 * block.expansion * width_multiplier), num_classes)
128
+
129
+ for m in self.modules():
130
+ if isinstance(m, nn.Conv2d):
131
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
132
+ m.weight.data.normal_(0, math.sqrt(2. / n))
133
+ elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
134
+ m.weight.data.fill_(1)
135
+ m.bias.data.zero_()
136
+
137
+ def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
138
+ downsample = None
139
+ if stride != 1 or self.inplanes != planes * block.expansion:
140
+ downsample = nn.Sequential(
141
+ nn.Conv2d(self.inplanes, planes * block.expansion,
142
+ kernel_size=1, stride=stride, bias=False),
143
+ ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * block.expansion),  # match the conv above; planes already carries the width multiplier
144
+ )
145
+
146
+ layers = []
147
+ layers.append(block(self.inplanes, planes,
148
+ stride, downsample, norm_type=norm_type))
149
+
150
+ self.inplanes = planes * block.expansion
151
+ for i in range(1, blocks):
152
+ layers.append(block(self.inplanes, planes, norm_type=norm_type))
153
+
154
+ return nn.Sequential(*layers)
155
+
156
+ def forward(self, x):
157
+ x = self.prefix(x)
158
+ x = self.maxpool(x)
159
+
160
+ x = self.layer1(x)
161
+ x = self.layer2(x)
162
+ x = self.layer3(x)
163
+ x = self.layer4(x)
164
+
165
+ x = self.avgpool(x)
166
+ x = x.view(x.size(0), -1)
167
+ x = self.fc(x)
168
+
169
+ return x
170
+
171
+
172
+ def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
173
+ """Constructs a ResNet-18 model.
174
+ Args:
175
+ pretrained (bool): If True, returns a model pre-trained on Places
176
+ norm_type (str): choose norm type
177
+ """
178
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type)
179
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
180
+ return model
181
+
182
+
183
+ def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
184
+ """Constructs a ResNet-18 model.
185
+ Args:
186
+ pretrained (bool): If True, returns a model pre-trained on Places
187
+ """
188
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type)
189
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
190
+ return model
191
+
192
+
193
+ def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
194
+ """Constructs a ResNet-34 model.
195
+ Args:
196
+ pretrained (bool): If True, returns a model pre-trained on Places
197
+ """
198
+ model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
199
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
200
+ return model
201
+
202
+
203
+ def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
204
+ """Constructs a ResNet-34 model.
205
+ Args:
206
+ pretrained (bool): If True, returns a model pre-trained on Places
207
+ """
208
+ model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
209
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
210
+ return model
211
+
212
+
213
+ def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
214
+ """Constructs a ResNet-50 model.
215
+ Args:
216
+ pretrained (bool): If True, returns a model pre-trained on Places
217
+ """
218
+ model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type,
219
+ width_multiplier=kwargs["width_multiplier"])
220
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
221
+ return model
222
+
223
+
224
+ def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
225
+ """Constructs a ResNet-50 model.
226
+ Args:
227
+ pretrained (bool): If True, returns a model pre-trained on Places
228
+ """
229
+ model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
230
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
231
+ return model
232
+
233
+
234
+ def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
235
+ """Constructs a ResNet-101 model.
236
+ Args:
237
+ pretrained (bool): If True, returns a model pre-trained on Places
238
+ """
239
+ model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
240
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
241
+ return model
242
+
243
+
244
+ def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
245
+ """Constructs a ResNet-101 model.
246
+ Args:
247
+ pretrained (bool): If True, returns a model pre-trained on Places
248
+ """
249
+ model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
250
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
251
+ return model
252
+
253
+
254
+ def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
255
+ """Constructs a ResNet-152 model.
256
+
257
+ Args:
258
+ pretrained (bool): If True, returns a model pre-trained on Places
259
+ """
260
+ model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
261
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
262
+ return model
263
+
264
+
265
+ def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
266
+ """Constructs a ResNet-152 model.
267
+
268
+ Args:
269
+ pretrained (bool): If True, returns a model pre-trained on Places
270
+ """
271
+ model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
272
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
273
+ return model
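All the constructors above share the same ResNet body and differ only in block type, depth, and stem: the deepbase_* variants swap the single 7x7 stem convolution for three 3x3 convolutions, and resnet50 additionally threads a width_multiplier through the layer widths. A minimal sketch with randomly initialised weights (pretrained=None is assumed to make ModuleHelper.load_model a no-op, as its use elsewhere in this commit suggests):

import torch
from networks.resnet_models import resnet50, deepbase_resnet50

plain = resnet50(pretrained=None, width_multiplier=1.0)  # 7x7 stem
deep = deepbase_resnet50(pretrained=None)                # three 3x3 convs in the stem

with torch.no_grad():
    print(plain(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
    print(deep(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1000])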
networks/timm_deit.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from functools import partial
7
+
8
+ from networks.timm_vit import VisionTransformer, _cfg
9
+ from timm.models.registry import register_model
10
+ from timm.models.layers import trunc_normal_
11
+
12
+
13
+ __all__ = [
14
+ 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
15
+ 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
16
+ 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
17
+ 'deit_base_distilled_patch16_384',
18
+ ]
19
+
20
+
21
+ class DistilledVisionTransformer(VisionTransformer):
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
25
+ num_patches = self.patch_embed.num_patches
26
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
27
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
28
+
29
+ trunc_normal_(self.dist_token, std=.02)
30
+ trunc_normal_(self.pos_embed, std=.02)
31
+ self.head_dist.apply(self._init_weights)
32
+
33
+ def forward_features(self, x):
34
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
35
+ # with slight modifications to add the dist_token
36
+ B = x.shape[0]
37
+ x = self.patch_embed(x)
38
+
39
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
40
+ dist_token = self.dist_token.expand(B, -1, -1)
41
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
42
+
43
+ x = x + self.pos_embed
44
+ x = self.pos_drop(x)
45
+
46
+ for blk in self.blocks:
47
+ x = blk(x)
48
+
49
+ x = self.norm(x)
50
+ return x[:, 0], x[:, 1]
51
+
52
+ def forward(self, x):
53
+ x, x_dist = self.forward_features(x)
54
+ x = self.head(x)
55
+ x_dist = self.head_dist(x_dist)
56
+ if self.training:
57
+ return x, x_dist
58
+ else:
59
+ # during inference, return the average of both classifier predictions
60
+ return (x + x_dist) / 2
61
+
62
+ def interpolate_pos_encoding(self, x, pos_embed):
63
+ """Interpolate the learnable positional encoding to match the number of patches.
64
+
65
+ x: B x (1 + 1 + N patches) x dim_embedding
66
+ pos_embed: B x (1 + 1 + N patches) x dim_embedding
67
+
68
+ return interpolated positional embedding
69
+ """
70
+
71
+ npatch = x.shape[1] - 2 # (H // patch_size * W // patch_size)
72
+ N = pos_embed.shape[1] - 2 # 784 (= 28 x 28)
73
+
74
+ if npatch == N:
75
+ return pos_embed
76
+
77
+ class_emb, distil_token, pos_embed = pos_embed[:, 0], pos_embed[:, 1], pos_embed[:, 2:] # a learnable CLS token, learnable position embeddings
78
+
79
+ dim = x.shape[-1] # dimension of embeddings
80
+ pos_embed = nn.functional.interpolate(
81
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
82
+ scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer.
83
+ recompute_scale_factor=True,
84
+ mode='bicubic'
85
+ )
86
+ # print("pos_embed", pos_embed.shape, npatch, N, math.sqrt(npatch/N), math.sqrt(npatch/N) * int(math.sqrt(N)))
87
+ # exit(12)
88
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
89
+ pos_embed = torch.cat((class_emb.unsqueeze(0), distil_token.unsqueeze(0), pos_embed), dim=1)
90
+ return pos_embed
91
+
92
+ def get_tokens(
93
+ self,
94
+ x,
95
+ layers: list,
96
+ patch_tokens: bool = False,
97
+ norm: bool = True,
98
+ input_tokens: bool = False,
99
+ post_pe: bool = False
100
+ ):
101
+ """Return intermediate tokens."""
102
+ list_tokens: list = []
103
+
104
+ B = x.shape[0]
105
+ x = self.patch_embed(x)
106
+
107
+ cls_tokens = self.cls_token.expand(B, -1, -1)
108
+ dist_token = self.dist_token.expand(B, -1, -1)
109
+
110
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
111
+
112
+ if input_tokens:
113
+ list_tokens.append(x)
114
+
115
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
116
+ x = x + pos_embed
117
+
118
+ if post_pe:
119
+ list_tokens.append(x)
120
+
121
+ x = self.pos_drop(x)
122
+
123
+ for i, blk in enumerate(self.blocks):
124
+ x = blk(x) # B x # patches x dim
125
+ if layers is None or i in layers:
126
+ list_tokens.append(self.norm(x) if norm else x)
127
+
128
+ tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
129
+
130
+ if not patch_tokens:
131
+ return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
132
+
133
+ else:
134
+ return torch.cat((tokens[:, :, 0, :].unsqueeze(dim=2), tokens[:, :, 2:, :]), dim=2) # exclude distil token.
135
+
136
+
137
+ @register_model
138
+ def deit_tiny_patch16_224(pretrained=False, **kwargs):
139
+ model = VisionTransformer(
140
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
141
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
142
+ model.default_cfg = _cfg()
143
+ if pretrained:
144
+ checkpoint = torch.hub.load_state_dict_from_url(
145
+ url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
146
+ map_location="cpu", check_hash=True
147
+ )
148
+ model.load_state_dict(checkpoint["model"])
149
+ return model
150
+
151
+
152
+ @register_model
153
+ def deit_small_patch16_224(pretrained=False, **kwargs):
154
+ model = VisionTransformer(
155
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
156
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
157
+ model.default_cfg = _cfg()
158
+ if pretrained:
159
+ checkpoint = torch.hub.load_state_dict_from_url(
160
+ url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
161
+ map_location="cpu", check_hash=True
162
+ )
163
+ model.load_state_dict(checkpoint["model"])
164
+ return model
165
+
166
+
167
+ @register_model
168
+ def deit_base_patch16_224(pretrained=False, **kwargs):
169
+ model = VisionTransformer(
170
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
171
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
172
+ model.default_cfg = _cfg()
173
+ if pretrained:
174
+ checkpoint = torch.hub.load_state_dict_from_url(
175
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
176
+ map_location="cpu", check_hash=True
177
+ )
178
+ model.load_state_dict(checkpoint["model"])
179
+ return model
180
+
181
+
182
+ @register_model
183
+ def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
184
+ model = DistilledVisionTransformer(
185
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
186
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
187
+ model.default_cfg = _cfg()
188
+ if pretrained:
189
+ checkpoint = torch.hub.load_state_dict_from_url(
190
+ url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
191
+ map_location="cpu", check_hash=True
192
+ )
193
+ model.load_state_dict(checkpoint["model"])
194
+ return model
195
+
196
+
197
+ @register_model
198
+ def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
199
+ model = DistilledVisionTransformer(
200
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
201
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
202
+ model.default_cfg = _cfg()
203
+ if pretrained:
204
+ checkpoint = torch.hub.load_state_dict_from_url(
205
+ url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
206
+ map_location="cpu", check_hash=True
207
+ )
208
+ model.load_state_dict(checkpoint["model"])
209
+ return model
210
+
211
+
212
+ @register_model
213
+ def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
214
+ model = DistilledVisionTransformer(
215
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
216
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
217
+ model.default_cfg = _cfg()
218
+ if pretrained:
219
+ checkpoint = torch.hub.load_state_dict_from_url(
220
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
221
+ map_location="cpu", check_hash=True
222
+ )
223
+ model.load_state_dict(checkpoint["model"])
224
+ return model
225
+
226
+
227
+ @register_model
228
+ def deit_base_patch16_384(pretrained=False, **kwargs):
229
+ model = VisionTransformer(
230
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
231
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
232
+ model.default_cfg = _cfg()
233
+ if pretrained:
234
+ checkpoint = torch.hub.load_state_dict_from_url(
235
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
236
+ map_location="cpu", check_hash=True
237
+ )
238
+ model.load_state_dict(checkpoint["model"])
239
+ return model
240
+
241
+
242
+ @register_model
243
+ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
244
+ model = DistilledVisionTransformer(
245
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
246
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
247
+ model.default_cfg = _cfg()
248
+ if pretrained:
249
+ checkpoint = torch.hub.load_state_dict_from_url(
250
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
251
+ map_location="cpu", check_hash=True
252
+ )
253
+ model.load_state_dict(checkpoint["model"])
254
+ return model
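A short sketch of how the get_tokens helper on the distilled model can be used to pull per-layer features: with layers=None every block contributes, so for the small model the [CLS]-only view is B x 12 x 384, and the patch view drops the distillation token and keeps [CLS] plus the 196 patch tokens. The shapes in the comments are what the code above implies for a 224x224 input, not measured output:

import torch
from networks.timm_deit import deit_small_distilled_patch16_224

model = deit_small_distilled_patch16_224(pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    cls_only = model.get_tokens(x, layers=None, patch_tokens=False)      # 1 x 12 x 384
    with_patches = model.get_tokens(x, layers=None, patch_tokens=True)   # 1 x 12 x 197 x 384
print(cls_only.shape, with_patches.shape)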
networks/timm_vit.py ADDED
@@ -0,0 +1,819 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implementation of Vision Transformers as described in
4
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5
+
6
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
7
+
8
+ DeiT model defs and weights from https://github.com/facebookresearch/deit,
9
+ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
10
+
11
+ Acknowledgments:
12
+ * The paper authors for releasing code and weights, thanks!
13
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
14
+ for some einops/einsum fun
15
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
16
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+ """
20
+ import math
21
+ import logging
22
+ from functools import partial
23
+ from collections import OrderedDict
24
+ from copy import deepcopy
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
31
+ from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
32
+ from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
33
+ from timm.models.registry import register_model
34
+
35
+ _logger = logging.getLogger(__name__)
36
+
37
+
38
+ def _cfg(url='', **kwargs):
39
+ return {
40
+ 'url': url,
41
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
42
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
43
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
44
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
45
+ **kwargs
46
+ }
47
+
48
+
49
+ default_cfgs = {
50
+ # patch models (my experiments)
51
+ 'vit_small_patch16_224': _cfg(
52
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
53
+ ),
54
+
55
+ # patch models (weights ported from official Google JAX impl)
56
+ 'vit_base_patch16_224': _cfg(
57
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
58
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
59
+ ),
60
+ 'vit_base_patch32_224': _cfg(
61
+ url='', # no official model weights for this combo, only for in21k
62
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
63
+ 'vit_base_patch16_384': _cfg(
64
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
65
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
66
+ 'vit_base_patch32_384': _cfg(
67
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
68
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
69
+ 'vit_large_patch16_224': _cfg(
70
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
71
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
72
+ 'vit_large_patch32_224': _cfg(
73
+ url='', # no official model weights for this combo, only for in21k
74
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
75
+ 'vit_large_patch16_384': _cfg(
76
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
77
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
78
+ 'vit_large_patch32_384': _cfg(
79
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
80
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
81
+
82
+ # patch models, imagenet21k (weights ported from official Google JAX impl)
83
+ 'vit_base_patch16_224_in21k': _cfg(
84
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
85
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
86
+ 'vit_base_patch32_224_in21k': _cfg(
87
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
88
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
89
+ 'vit_large_patch16_224_in21k': _cfg(
90
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
91
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
92
+ 'vit_large_patch32_224_in21k': _cfg(
93
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
94
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
95
+ 'vit_huge_patch14_224_in21k': _cfg(
96
+ hf_hub='timm/vit_huge_patch14_224_in21k',
97
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
98
+
99
+ # deit models (FB weights)
100
+ 'vit_deit_tiny_patch16_224': _cfg(
101
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
102
+ 'vit_deit_small_patch16_224': _cfg(
103
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
104
+ 'vit_deit_base_patch16_224': _cfg(
105
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
106
+ 'vit_deit_base_patch16_384': _cfg(
107
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
108
+ input_size=(3, 384, 384), crop_pct=1.0),
109
+ 'vit_deit_tiny_distilled_patch16_224': _cfg(
110
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
111
+ classifier=('head', 'head_dist')),
112
+ 'vit_deit_small_distilled_patch16_224': _cfg(
113
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
114
+ classifier=('head', 'head_dist')),
115
+ 'vit_deit_base_distilled_patch16_224': _cfg(
116
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
117
+ classifier=('head', 'head_dist')),
118
+ 'vit_deit_base_distilled_patch16_384': _cfg(
119
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
120
+ input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
121
+
122
+ # ViT ImageNet-21K-P pretraining
123
+ 'vit_base_patch16_224_miil_in21k': _cfg(
124
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
125
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
126
+ ),
127
+ 'vit_base_patch16_224_miil': _cfg(
128
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
129
+ '/vit_base_patch16_224_1k_miil_84_4.pth',
130
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
131
+ ),
132
+ }
133
+
134
+
135
+ class Attention(nn.Module):
136
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
137
+ super().__init__()
138
+ self.num_heads = num_heads
139
+ head_dim = dim // num_heads
140
+ self.scale = qk_scale or head_dim ** -0.5
141
+
142
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
143
+ self.attn_drop = nn.Dropout(attn_drop)
144
+ self.proj = nn.Linear(dim, dim)
145
+ self.proj_drop = nn.Dropout(proj_drop)
146
+
147
+ def forward(self, x):
148
+ B, N, C = x.shape
149
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
150
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
151
+
152
+ attn = (q @ k.transpose(-2, -1)) * self.scale
153
+ attn = attn.softmax(dim=-1)
154
+ attn = self.attn_drop(attn)
155
+
156
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
157
+ x = self.proj(x)
158
+ x = self.proj_drop(x)
159
+ return x
160
+
161
+
162
+ class Block(nn.Module):
163
+
164
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
165
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
166
+ super().__init__()
167
+ self.norm1 = norm_layer(dim)
168
+ self.attn = Attention(
169
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
170
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
171
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
172
+ self.norm2 = norm_layer(dim)
173
+ mlp_hidden_dim = int(dim * mlp_ratio)
174
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
175
+
176
+ def forward(self, x):
177
+ x = x + self.drop_path(self.attn(self.norm1(x)))
178
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
179
+ return x
180
+
181
+
182
+ class VisionTransformer(nn.Module):
183
+ """ Vision Transformer
184
+
185
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
186
+ - https://arxiv.org/abs/2010.11929
187
+
188
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
189
+ - https://arxiv.org/abs/2012.12877
190
+ """
191
+
192
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
193
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
194
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
195
+ act_layer=None, weight_init='',
196
+ # noel
197
+ img_size_eval: int = 224):
198
+ """
199
+ Args:
200
+ img_size (int, tuple): input image size
201
+ patch_size (int, tuple): patch size
202
+ in_chans (int): number of input channels
203
+ num_classes (int): number of classes for classification head
204
+ embed_dim (int): embedding dimension
205
+ depth (int): depth of transformer
206
+ num_heads (int): number of attention heads
207
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
208
+ qkv_bias (bool): enable bias for qkv if True
209
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
210
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
211
+ distilled (bool): model includes a distillation token and head as in DeiT models
212
+ drop_rate (float): dropout rate
213
+ attn_drop_rate (float): attention dropout rate
214
+ drop_path_rate (float): stochastic depth rate
215
+ embed_layer (nn.Module): patch embedding layer
216
+ norm_layer: (nn.Module): normalization layer
217
+ weight_init: (str): weight init scheme
218
+ """
219
+ super().__init__()
220
+ self.num_classes = num_classes
221
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
222
+ self.num_tokens = 2 if distilled else 1
223
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
224
+ act_layer = act_layer or nn.GELU
225
+
226
+ self.patch_embed = embed_layer(
227
+ img_size=img_size,
228
+ patch_size=patch_size,
229
+ in_chans=in_chans,
230
+ embed_dim=embed_dim
231
+ )
232
+ num_patches = self.patch_embed.num_patches
233
+
234
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
235
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
236
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
237
+ self.pos_drop = nn.Dropout(p=drop_rate)
238
+
239
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
240
+ self.blocks = nn.Sequential(*[
241
+ Block(
242
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
243
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
244
+ for i in range(depth)])
245
+ self.norm = norm_layer(embed_dim)
246
+
247
+ # Representation layer
248
+ if representation_size and not distilled:
249
+ self.num_features = representation_size
250
+ self.pre_logits = nn.Sequential(OrderedDict([
251
+ ('fc', nn.Linear(embed_dim, representation_size)),
252
+ ('act', nn.Tanh())
253
+ ]))
254
+ else:
255
+ self.pre_logits = nn.Identity()
256
+
257
+ # Classifier head(s)
258
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
259
+ self.head_dist = None
260
+ if distilled:
261
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
262
+
263
+ # Weight init
264
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
265
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
266
+ trunc_normal_(self.pos_embed, std=.02)
267
+ if self.dist_token is not None:
268
+ trunc_normal_(self.dist_token, std=.02)
269
+ if weight_init.startswith('jax'):
270
+ # leave cls token as zeros to match jax impl
271
+ for n, m in self.named_modules():
272
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
273
+ else:
274
+ trunc_normal_(self.cls_token, std=.02)
275
+ self.apply(_init_vit_weights)
276
+
277
+ # noel
278
+ self.depth = depth
279
+ self.distilled = distilled
280
+ self.patch_size = patch_size
281
+ self.patch_embed.img_size = (img_size_eval, img_size_eval)
282
+
283
+ def _init_weights(self, m):
284
+ # this fn left here for compat with downstream users
285
+ _init_vit_weights(m)
286
+
287
+ @torch.jit.ignore
288
+ def no_weight_decay(self):
289
+ return {'pos_embed', 'cls_token', 'dist_token'}
290
+
291
+ def get_classifier(self):
292
+ if self.dist_token is None:
293
+ return self.head
294
+ else:
295
+ return self.head, self.head_dist
296
+
297
+ def reset_classifier(self, num_classes, global_pool=''):
298
+ self.num_classes = num_classes
299
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
300
+ if self.num_tokens == 2:
301
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
302
+
303
+ def forward_features(self, x):
304
+ x = self.patch_embed(x)
305
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
306
+ if self.dist_token is None:
307
+ x = torch.cat((cls_token, x), dim=1)
308
+ else:
309
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
310
+ x = self.pos_drop(x + self.pos_embed)
311
+ x = self.blocks(x)
312
+ x = self.norm(x)
313
+ if self.dist_token is None:
314
+ return self.pre_logits(x[:, 0])
315
+ else:
316
+ return x[:, 0], x[:, 1]
317
+
318
+ # def forward(self, x):
319
+ # x = self.forward_features(x)
320
+ # if self.head_dist is not None:
321
+ # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
322
+ # if self.training and not torch.jit.is_scripting():
323
+ # # during inference, return the average of both classifier predictions
324
+ # return x, x_dist
325
+ # else:
326
+ # return (x + x_dist) / 2
327
+ # else:
328
+ # x = self.head(x)
329
+ # return x
330
+
331
+ # noel - start
332
+ def make_square(self, x: torch.Tensor):
333
+ """Pad some pixels to make the input size divisible by the patch size."""
334
+ B, _, H_0, W_0 = x.shape
335
+ pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
336
+ pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
337
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=x.mean())
338
+
339
+ H_p, W_p = H_0 + pad_h, W_0 + pad_w
340
+ x = nn.functional.pad(x, (0, H_p - W_p, 0, 0) if H_p > W_p else (0, 0, 0, W_p - H_p), value=x.mean())
341
+ return x
342
+
343
+ def interpolate_pos_encoding(self, x, pos_embed, size):
344
+ """Interpolate the learnable positional encoding to match the number of patches.
345
+
346
+ x: B x (1 + N patches) x dim_embedding
347
+ pos_embed: B x (1 + N patches) x dim_embedding
348
+
349
+ return interpolated positional embedding
350
+ """
351
+ npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
352
+ N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
353
+ if npatch == N:
354
+ return pos_embed
355
+ class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings
356
+
357
+ dim = x.shape[-1] # dimension of embeddings
358
+ pos_embed = nn.functional.interpolate(
359
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
360
+ size=size,
361
+ mode='bicubic',
362
+ align_corners=False
363
+ )
364
+
365
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
366
+ pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
367
+ return pos_embed
368
+
369
+ # def interpolate_pos_encoding(self, x, pos_embed):
370
+ # """Interpolate the learnable positional encoding to match the number of patches.
371
+ #
372
+ # x: B x (1 + N patches) x dim_embedding
373
+ # pos_embed: B x (1 + N patches) x dim_embedding
374
+ #
375
+ # return interpolated positional embedding
376
+ # """
377
+ # npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
378
+ # N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
379
+ # if npatch == N:
380
+ # return pos_embed
381
+ # class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings
382
+ #
383
+ # dim = x.shape[-1] # dimension of embeddings
384
+ # pos_embed = nn.functional.interpolate(
385
+ # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
386
+ # scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer.
387
+ # recompute_scale_factor=True,
388
+ # mode='bicubic',
389
+ # align_corners=False
390
+ # )
391
+ #
392
+ # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
393
+ # pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
394
+ # return pos_embed
395
+
396
+ def prepare_tokens(self, x):
397
+ B, nc, h, w = x.shape
398
+ patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
399
+ x = self.patch_embed(x) # patch linear embedding
400
+
401
+ # add the [CLS] token to the embed patch tokens
402
+ cls_tokens = self.cls_token.expand(B, -1, -1)
403
+ x = torch.cat((cls_tokens, x), dim=1)
404
+
405
+ # add positional encoding to each token
406
+ x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
407
+ return self.pos_drop(x)
408
+
409
+ def get_tokens(
410
+ self,
411
+ x,
412
+ layers: list,
413
+ patch_tokens: bool = False,
414
+ norm: bool = True,
415
+ input_tokens: bool = False,
416
+ post_pe: bool = False
417
+ ):
418
+ """Return intermediate tokens."""
419
+ list_tokens: list = []
420
+
421
+ B, _, H, W = x.shape  # keep the spatial size; it is needed for the positional-encoding interpolation below
422
+ x = self.patch_embed(x)
423
+
424
+ cls_tokens = self.cls_token.expand(B, -1, -1)
425
+
426
+ x = torch.cat((cls_tokens, x), dim=1)
427
+
428
+ if input_tokens:
429
+ list_tokens.append(x)
430
+
431
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed, size=(H // self.patch_size, W // self.patch_size))  # this class's interpolate_pos_encoding expects the patch-grid size
432
+ x = x + pos_embed
433
+
434
+ if post_pe:
435
+ list_tokens.append(x)
436
+
437
+ x = self.pos_drop(x)
438
+
439
+ for i, blk in enumerate(self.blocks):
440
+ x = blk(x) # B x # patches x dim
441
+ if layers is None or i in layers:
442
+ list_tokens.append(self.norm(x) if norm else x)
443
+
444
+ tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
445
+
446
+ if not patch_tokens:
447
+ return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
448
+
449
+ else:
450
+ return tokens
451
+
452
+ def forward(self, x, layer: str = None):
453
+ x = self.prepare_tokens(x)
454
+
455
+ features: dict = {}
456
+ for i, blk in enumerate(self.blocks):
457
+ x = blk(x)
458
+ features[f"layer{i + 1}"] = self.norm(x)
459
+
460
+ if layer is not None:
461
+ return features[layer]
462
+ else:
463
+ return features["layer12"]
464
+ # noel - end
465
+
466
+
467
+ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
468
+ """ ViT weight initialization
469
+ * When called without n, head_bias, jax_impl args it will behave exactly the same
470
+ as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
471
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
472
+ """
473
+ if isinstance(m, nn.Linear):
474
+ if n.startswith('head'):
475
+ nn.init.zeros_(m.weight)
476
+ nn.init.constant_(m.bias, head_bias)
477
+ elif n.startswith('pre_logits'):
478
+ lecun_normal_(m.weight)
479
+ nn.init.zeros_(m.bias)
480
+ else:
481
+ if jax_impl:
482
+ nn.init.xavier_uniform_(m.weight)
483
+ if m.bias is not None:
484
+ if 'mlp' in n:
485
+ nn.init.normal_(m.bias, std=1e-6)
486
+ else:
487
+ nn.init.zeros_(m.bias)
488
+ else:
489
+ trunc_normal_(m.weight, std=.02)
490
+ if m.bias is not None:
491
+ nn.init.zeros_(m.bias)
492
+ elif jax_impl and isinstance(m, nn.Conv2d):
493
+ # NOTE conv was left to pytorch default in my original init
494
+ lecun_normal_(m.weight)
495
+ if m.bias is not None:
496
+ nn.init.zeros_(m.bias)
497
+ elif isinstance(m, nn.LayerNorm):
498
+ nn.init.zeros_(m.bias)
499
+ nn.init.ones_(m.weight)
500
+
501
+
502
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
503
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
504
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
505
+ _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
506
+ ntok_new = posemb_new.shape[1]
507
+ if num_tokens:
508
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
509
+ ntok_new -= num_tokens
510
+ else:
511
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
512
+ gs_old = int(math.sqrt(len(posemb_grid)))
513
+ if not len(gs_new): # backwards compatibility
514
+ gs_new = [int(math.sqrt(ntok_new))] * 2
515
+ assert len(gs_new) >= 2
516
+ _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
517
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
518
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
519
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
520
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
521
+ return posemb
522
+
523
+
524
+ def checkpoint_filter_fn(state_dict, model):
525
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
526
+ out_dict = {}
527
+ if 'model' in state_dict:
528
+ # For deit models
529
+ state_dict = state_dict['model']
530
+ for k, v in state_dict.items():
531
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
532
+ # For old models that I trained prior to conv based patchification
533
+ O, I, H, W = model.patch_embed.proj.weight.shape
534
+ v = v.reshape(O, -1, H, W)
535
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
536
+ # To resize pos embedding when using model at different size from pretrained weights
537
+ v = resize_pos_embed(
538
+ v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
539
+ out_dict[k] = v
540
+ return out_dict
541
+
542
+
543
+ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
544
+ default_cfg = default_cfg or default_cfgs[variant]
545
+ if kwargs.get('features_only', None):
546
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
547
+
548
+ # NOTE this extra code to support handling of repr size for in21k pretrained models
549
+ default_num_classes = default_cfg['num_classes']
550
+ num_classes = kwargs.get('num_classes', default_num_classes)
551
+ repr_size = kwargs.pop('representation_size', None)
552
+ if repr_size is not None and num_classes != default_num_classes:
553
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
554
+ # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
555
+ _logger.warning("Removing representation layer for fine-tuning.")
556
+ repr_size = None
557
+
558
+ model = build_model_with_cfg(
559
+ VisionTransformer, variant, pretrained,
560
+ default_cfg=default_cfg,
561
+ representation_size=repr_size,
562
+ pretrained_filter_fn=checkpoint_filter_fn,
563
+ **kwargs)
564
+ return model
565
+
566
+
567
+ @register_model
568
+ def vit_small_patch16_224(pretrained=False, **kwargs):
569
+ """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
570
+ NOTE:
571
+ * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
572
+ * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
573
+ """
574
+ model_kwargs = dict(
575
+ patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
576
+ qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
577
+ if pretrained:
578
+ # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
579
+ model_kwargs.setdefault('qk_scale', 768 ** -0.5)
580
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
581
+ return model
582
+
583
+
584
+ @register_model
585
+ def vit_base_patch16_224(pretrained=False, **kwargs):
586
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
587
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
588
+ """
589
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
590
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
591
+ return model
592
+
593
+
594
+ @register_model
595
+ def vit_base_patch32_224(pretrained=False, **kwargs):
596
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
597
+ """
598
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
599
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
600
+ return model
601
+
602
+
603
+ @register_model
604
+ def vit_base_patch16_384(pretrained=False, **kwargs):
605
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
606
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
607
+ """
608
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
609
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
610
+ return model
611
+
612
+
613
+ @register_model
614
+ def vit_base_patch32_384(pretrained=False, **kwargs):
615
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
616
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
617
+ """
618
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
619
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
620
+ return model
621
+
622
+
623
+ @register_model
624
+ def vit_large_patch16_224(pretrained=False, **kwargs):
625
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
626
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
627
+ """
628
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
629
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
630
+ return model
631
+
632
+
633
+ @register_model
634
+ def vit_large_patch32_224(pretrained=False, **kwargs):
635
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
636
+ """
637
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
638
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
639
+ return model
640
+
641
+
642
+ @register_model
643
+ def vit_large_patch16_384(pretrained=False, **kwargs):
644
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
645
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
646
+ """
647
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
648
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
649
+ return model
650
+
651
+
652
+ @register_model
653
+ def vit_large_patch32_384(pretrained=False, **kwargs):
654
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
655
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
656
+ """
657
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
658
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
659
+ return model
660
+
661
+
662
+ @register_model
663
+ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
664
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
665
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
666
+ """
667
+ model_kwargs = dict(
668
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
669
+ model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
670
+ return model
671
+
672
+
673
+ @register_model
674
+ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
675
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
676
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
677
+ """
678
+ model_kwargs = dict(
679
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
680
+ model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
681
+ return model
682
+
683
+
684
+ @register_model
685
+ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
686
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
687
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
688
+ """
689
+ model_kwargs = dict(
690
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
691
+ model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
692
+ return model
693
+
694
+
695
+ @register_model
696
+ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
697
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
698
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
699
+ """
700
+ model_kwargs = dict(
701
+ patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
702
+ model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
703
+ return model
704
+
705
+
706
+ @register_model
707
+ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
708
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
709
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
710
+ NOTE: converted weights not currently available, too large for github release hosting.
711
+ """
712
+ model_kwargs = dict(
713
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
714
+ model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
715
+ return model
716
+
717
+
718
+ @register_model
719
+ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
720
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
721
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
722
+ """
723
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
724
+ model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
725
+ return model
726
+
727
+
728
+ @register_model
729
+ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
730
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
731
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
732
+ """
733
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
734
+ model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
735
+ return model
736
+
737
+
738
+ @register_model
739
+ def vit_deit_base_patch16_224(pretrained=False, **kwargs):
740
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
741
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
742
+ """
743
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
744
+ model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
745
+ return model
746
+
747
+
748
+ @register_model
749
+ def vit_deit_base_patch16_384(pretrained=False, **kwargs):
750
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
751
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
752
+ """
753
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
754
+ model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
755
+ return model
756
+
757
+
758
+ @register_model
759
+ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
760
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
761
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
762
+ """
763
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
764
+ model = _create_vision_transformer(
765
+ 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
766
+ return model
767
+
768
+
769
+ @register_model
770
+ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
771
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
772
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
773
+ """
774
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
775
+ model = _create_vision_transformer(
776
+ 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
777
+ return model
778
+
779
+
780
+ @register_model
781
+ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
782
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
783
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
784
+ """
785
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
786
+ model = _create_vision_transformer(
787
+ 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
788
+ return model
789
+
790
+
791
+ @register_model
792
+ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
793
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
794
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
795
+ """
796
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
797
+ model = _create_vision_transformer(
798
+ 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
799
+ return model
800
+
801
+
802
+ @register_model
803
+ def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
804
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
805
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
806
+ """
807
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
808
+ model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
809
+ return model
810
+
811
+
812
+ @register_model
813
+ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
814
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
815
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
816
+ """
817
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
818
+ model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
819
+ return model
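
The `@register_model` factories above all follow the same pattern: each builds a `model_kwargs` dict for a particular patch size / width / depth and hands it to `_create_vision_transformer`, which resolves the default config and, when requested, loads pretrained weights through `checkpoint_filter_fn`. The snippet below is a minimal usage sketch, not part of the diff itself; it assumes this file is importable as `networks.timm_vit` (as in this repo) and that the timm helpers it relies on (`register_model`, `build_model_with_cfg`, `default_cfgs`) resolve at import time.

# Minimal usage sketch (assumption: this repo's `networks.timm_vit` imports cleanly).
import torch
from networks.timm_vit import vit_small_patch16_224

model = vit_small_patch16_224(pretrained=False)  # custom 'small' ViT: embed_dim=768, depth=8, num_heads=8
model.eval()

x = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
with torch.no_grad():
    out = model(x)  # inspect whatever this variant's forward() returns (features or logits)
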
networks/vision_transformer.py ADDED
@@ -0,0 +1,569 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Mostly copy-paste from timm library.
4
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
5
+ """
6
+ from typing import Optional
7
+ import math
+ import warnings
8
+ from functools import partial
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
15
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
16
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
17
+ def norm_cdf(x):
18
+ # Computes standard normal cumulative distribution function
19
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
20
+
21
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
22
+ warnings.warn(
23
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.",
24
+ stacklevel=2
25
+ )
26
+
27
+ with torch.no_grad():
28
+ # Values are generated by using a truncated uniform distribution and
29
+ # then using the inverse CDF for the normal distribution.
30
+ # Get upper and lower cdf values
31
+ l = norm_cdf((a - mean) / std)
32
+ u = norm_cdf((b - mean) / std)
33
+
34
+ # Uniformly fill tensor with values from [l, u], then translate to
35
+ # [2l-1, 2u-1].
36
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
37
+
38
+ # Use inverse cdf transform for normal distribution to get truncated
39
+ # standard normal
40
+ tensor.erfinv_()
41
+
42
+ # Transform to proper mean, std
43
+ tensor.mul_(std * math.sqrt(2.))
44
+ tensor.add_(mean)
45
+
46
+ # Clamp to ensure it's in the proper range
47
+ tensor.clamp_(min=a, max=b)
48
+ return tensor
49
+
50
+
51
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
52
+ # type: (Tensor, float, float, float, float) -> Tensor
53
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
54
+
55
+
56
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
57
+ if drop_prob == 0. or not training:
58
+ return x
59
+ keep_prob = 1 - drop_prob
60
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
61
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
62
+ random_tensor.floor_() # binarize
63
+ output = x.div(keep_prob) * random_tensor
64
+ return output
65
+
66
+
67
+ class DropPath(nn.Module):
68
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
69
+ """
70
+ def __init__(self, drop_prob=None):
71
+ super(DropPath, self).__init__()
72
+ self.drop_prob = drop_prob
73
+
74
+ def forward(self, x):
75
+ return drop_path(x, self.drop_prob, self.training)
76
+
77
+
78
+ class Mlp(nn.Module):
79
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
80
+ super().__init__()
81
+ out_features = out_features or in_features
82
+ hidden_features = hidden_features or in_features
83
+ self.fc1 = nn.Linear(in_features, hidden_features)
84
+ self.act = act_layer()
85
+ self.fc2 = nn.Linear(hidden_features, out_features)
86
+ self.drop = nn.Dropout(drop)
87
+
88
+ def forward(self, x):
89
+ x = self.fc1(x)
90
+ x = self.act(x)
91
+ x = self.drop(x)
92
+ x = self.fc2(x)
93
+ x = self.drop(x)
94
+ return x
95
+
96
+
97
+ class Attention(nn.Module):
98
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
99
+ super().__init__()
100
+ self.num_heads = num_heads
101
+ head_dim = dim // num_heads
102
+ self.scale = qk_scale or head_dim ** -0.5 # scale attention logits by 1/sqrt(head_dim)
103
+
104
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
105
+ self.attn_drop = nn.Dropout(attn_drop)
106
+
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x):
111
+ B, N, C = x.shape # B x (cls token + # patch tokens) x dim
112
+
113
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
114
+ # qkv: 3 x B x Nh x (cls token + # patch tokens) x (dim // Nh)
115
+
116
+ q, k, v = qkv[0], qkv[1], qkv[2]
117
+ # q, k, v: B x Nh x (cls token + # patch tokens) x (dim // Nh)
118
+
119
+ # q: B x Nh x (cls token + # patch tokens) x (dim // Nh)
120
+ # k.transpose(-2, -1) = B x Nh x (dim // Nh) x (cls token + # patch tokens)
121
+ # attn: B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
122
+ attn = (q @ k.transpose(-2, -1)) * self.scale # @ operator is for matrix multiplication
123
+ attn = attn.softmax(dim=-1) # B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
124
+ attn = self.attn_drop(attn)
125
+
126
+ # attn = B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
127
+ # v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
128
+ # attn @ v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
129
+ # (attn @ v).transpose(1, 2) = B x (cls token + # patch tokens) x Nh x (dim // Nh)
130
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C) # B x (cls token + # patch tokens) x dim
131
+ x = self.proj(x) # B x (cls token + # patch tokens) x dim
132
+ x = self.proj_drop(x)
133
+ return x, attn
134
+
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self,
138
+ dim, num_heads,
139
+ mlp_ratio=4.,
140
+ qkv_bias=False,
141
+ qk_scale=None,
142
+ drop=0.,
143
+ attn_drop=0.,
144
+ drop_path=0.,
145
+ act_layer=nn.GELU,
146
+ norm_layer=nn.LayerNorm):
147
+ super().__init__()
148
+ self.norm1 = norm_layer(dim)
149
+ self.attn = Attention(
150
+ dim,
151
+ num_heads=num_heads,
152
+ qkv_bias=qkv_bias,
153
+ qk_scale=qk_scale,
154
+ attn_drop=attn_drop,
155
+ proj_drop=drop
156
+ )
157
+
158
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
159
+
160
+ self.norm2 = norm_layer(dim)
161
+ mlp_hidden_dim = int(dim * mlp_ratio)
162
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
163
+
164
+ def forward(self, x, return_attention=False):
165
+ y, attn = self.attn(self.norm1(x))
166
+ if return_attention:
167
+ return attn
168
+ x = x + self.drop_path(y)
169
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
170
+ return x
171
+
172
+
173
+ class PatchEmbed(nn.Module):
174
+ """ Image to Patch Embedding"""
175
+ def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
176
+ super().__init__()
177
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
178
+ self.img_size = img_size
179
+ self.patch_size = patch_size
180
+ self.num_patches = num_patches
181
+
182
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
183
+
184
+ def forward(self, x):
185
+ B, C, H, W = x.shape
186
+ x = self.proj(x)
187
+ x = x.flatten(2).transpose(1, 2) # B x num_patches x embed_dim
188
+ return x
189
+
190
+
191
+ class VisionTransformer(nn.Module):
192
+ """ Vision Transformer """
193
+ def __init__(self,
194
+ img_size=(224, 224),
195
+ patch_size=16,
196
+ in_chans=3,
197
+ num_classes=0,
198
+ embed_dim=768,
199
+ depth=12,
200
+ num_heads=12,
201
+ mlp_ratio=4.,
202
+ qkv_bias=False,
203
+ qk_scale=None,
204
+ drop_rate=0.,
205
+ attn_drop_rate=0.,
206
+ drop_path_rate=0.,
207
+ norm_layer=nn.LayerNorm):
208
+ super().__init__()
209
+ self.num_features = self.embed_dim = embed_dim
210
+
211
+ self.patch_embed = PatchEmbed(
212
+ img_size=(224, 224), # noel: fixed to 224x224 so that pretrained weights can be loaded.
213
+ patch_size=patch_size,
214
+ in_chans=in_chans,
215
+ embed_dim=embed_dim
216
+ )
217
+ num_patches = self.patch_embed.num_patches
218
+
219
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
220
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
221
+ self.pos_drop = nn.Dropout(p=drop_rate)
222
+
223
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
224
+ self.blocks = nn.ModuleList([
225
+ Block(
226
+ dim=embed_dim,
227
+ num_heads=num_heads,
228
+ mlp_ratio=mlp_ratio,
229
+ qkv_bias=qkv_bias,
230
+ qk_scale=qk_scale,
231
+ drop=drop_rate,
232
+ attn_drop=attn_drop_rate,
233
+ drop_path=dpr[i],
234
+ norm_layer=norm_layer
235
+ ) for i in range(depth)])
236
+ self.norm = norm_layer(embed_dim)
237
+
238
+ # Classifier head
239
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
240
+
241
+ trunc_normal_(self.pos_embed, std=.02)
242
+ trunc_normal_(self.cls_token, std=.02)
243
+ self.apply(self._init_weights)
244
+
245
+ self.depth = depth
246
+ self.embed_dim = self.n_embs = embed_dim
247
+ self.mlp_ratio = mlp_ratio
248
+ self.n_heads = num_heads
249
+ self.patch_size = patch_size
250
+
251
+ def _init_weights(self, m):
252
+ if isinstance(m, nn.Linear):
253
+ trunc_normal_(m.weight, std=.02)
254
+ if isinstance(m, nn.Linear) and m.bias is not None:
255
+ nn.init.constant_(m.bias, 0)
256
+ elif isinstance(m, nn.LayerNorm):
257
+ nn.init.constant_(m.bias, 0)
258
+ nn.init.constant_(m.weight, 1.0)
259
+
260
+ def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
261
+ """Pad some pixels to make the input size divisible by the patch size."""
262
+ B, _, H_0, W_0 = x.shape
263
+ pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
264
+ pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
265
+
266
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
267
+ return x
268
+
269
+ def prepare_tokens(self, x):
270
+ B, nc, h, w = x.shape
271
+ x: torch.Tensor = self.make_input_divisible(x)
272
+ patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
273
+
274
+ x = self.patch_embed(x) # patch linear embedding
275
+
276
+ # add positional encoding to each token
277
+ # add the [CLS] token to the embed patch tokens
278
+ cls_tokens = self.cls_token.expand(B, -1, -1)
279
+ x = torch.cat((cls_tokens, x), dim=1)
280
+ x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
281
+ return self.pos_drop(x)
282
+
283
+ @staticmethod
284
+ def split_token(x, token_type: str):
285
+ if token_type == "cls":
286
+ return x[:, 0, :]
287
+ elif token_type == "patch":
288
+ return x[:, 1:, :]
289
+ else:
290
+ return x
291
+
292
+ # noel
293
+ def forward(self, x, layer: Optional[str] = None):
294
+ x: torch.Tensor = self.prepare_tokens(x)
295
+
296
+ features: dict = {}
297
+ for i, blk in enumerate(self.blocks):
298
+ x = blk(x)
299
+ features[f"layer{i + 1}"] = self.norm(x)
300
+
301
+ if layer is not None:
302
+ return features[layer]
303
+ else:
304
+ return features
305
+
306
+ # noel - for DINO's attention visualisation
307
+ def get_last_selfattention(self, x):
308
+ x = self.prepare_tokens(x)
309
+ for i, blk in enumerate(self.blocks):
310
+ if i < len(self.blocks) - 1:
311
+ x = blk(x)
312
+ else:
313
+ # return attention of the last block
314
+ return blk(x, return_attention=True)
315
+
316
+ def get_tokens(
317
+ self,
318
+ x,
319
+ layers: list,
320
+ patch_tokens: bool = False,
321
+ norm: bool = True,
322
+ input_tokens: bool = False,
323
+ post_pe: bool = False
324
+ ):
325
+ """Return intermediate tokens."""
326
+ list_tokens: list = []
327
+
328
+ B = x.shape[0]
329
+ x = self.patch_embed(x)
330
+
331
+ cls_tokens = self.cls_token.expand(B, -1, -1)
332
+
333
+ x = torch.cat((cls_tokens, x), dim=1)
334
+
335
+ if input_tokens:
336
+ list_tokens.append(x)
337
+
338
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
339
+ x = x + pos_embed
340
+
341
+ if post_pe:
342
+ list_tokens.append(x)
343
+
344
+ x = self.pos_drop(x)
345
+
346
+ for i, blk in enumerate(self.blocks):
347
+ x = blk(x) # B x # patches x dim
348
+ if layers is None or i in layers:
349
+ list_tokens.append(self.norm(x) if norm else x)
350
+
351
+ tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
352
+
353
+ if not patch_tokens:
354
+ return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
355
+
356
+ else:
357
+ return tokens
358
+
359
+ def forward_features(self, x):
360
+ B = x.shape[0]
361
+ x = self.patch_embed(x)
362
+
363
+ cls_tokens = self.cls_token.expand(B, -1, -1)
364
+ x = torch.cat((cls_tokens, x), dim=1)
365
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
366
+ x = x + pos_embed
367
+ x = self.pos_drop(x)
368
+
369
+ for blk in self.blocks:
370
+ x = blk(x)
371
+
372
+ if self.norm is not None:
373
+ x = self.norm(x)
374
+
375
+ return x[:, 0]
376
+
377
+ def interpolate_pos_encoding(self, x, pos_embed, size=None):
378
+ """Interpolate the learnable positional encoding to match the number of patches.
379
+
380
+ x: B x (1 + N patches) x dim_embedding
381
+ pos_embed: B x (1 + N patches) x dim_embedding
382
+
383
+ return interpolated positional embedding
384
+ """
385
+ npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
386
+ N = pos_embed.shape[1] - 1 # number of patch position embeddings, e.g. 784 (= 28 x 28) for a 224x224 input with patch size 8
387
+ if npatch == N:
388
+ return pos_embed
389
+ class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # learnable [CLS] token and learnable patch position embeddings
+ if size is None: # callers such as forward_features() omit `size`; assume a square patch grid
+ size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))
390
+
391
+ dim = x.shape[-1] # dimension of embeddings
392
+ pos_embed = nn.functional.interpolate(
393
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
394
+ size=size,
395
+ mode='bicubic',
396
+ align_corners=False
397
+ )
398
+
399
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
400
+ pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
401
+ return pos_embed
402
+
403
+ def forward_selfattention(self, x, return_interm_attn=False):
404
+ B, nc, w, h = x.shape
405
+ N = self.pos_embed.shape[1] - 1
406
+ x = self.patch_embed(x)
407
+
408
+ # interpolate patch embeddings
409
+ dim = x.shape[-1]
410
+ w0 = w // self.patch_embed.patch_size
411
+ h0 = h // self.patch_embed.patch_size
412
+ class_pos_embed = self.pos_embed[:, 0]
413
+ patch_pos_embed = self.pos_embed[:, 1:]
414
+ patch_pos_embed = nn.functional.interpolate(
415
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
416
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
417
+ mode='bicubic'
418
+ )
419
+ if w0 != patch_pos_embed.shape[-2]:
420
+ helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
421
+ patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
422
+ if h0 != patch_pos_embed.shape[-1]:
423
+ helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
424
+ patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
425
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
426
+ pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
427
+
428
+ cls_tokens = self.cls_token.expand(B, -1, -1) # self.cls_token: 1 x 1 x emb_dim -> ?
429
+ x = torch.cat((cls_tokens, x), dim=1)
430
+ x = x + pos_embed
431
+ x = self.pos_drop(x)
432
+
433
+ if return_interm_attn:
434
+ list_attn = []
435
+ for i, blk in enumerate(self.blocks):
436
+ attn = blk(x, return_attention=True)
437
+ x = blk(x)
438
+ list_attn.append(attn)
439
+ return torch.cat(list_attn, dim=0)
440
+
441
+ else:
442
+ for i, blk in enumerate(self.blocks):
443
+ if i < len(self.blocks) - 1:
444
+ x = blk(x)
445
+ else:
446
+ return blk(x, return_attention=True)
447
+
448
+ def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
449
+ B = x.shape[0]
450
+ x = self.patch_embed(x)
451
+
452
+ cls_tokens = self.cls_token.expand(B, -1, -1)
453
+
454
+ x = torch.cat((cls_tokens, x), dim=1)
455
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
456
+ x = x + pos_embed
457
+ x = self.pos_drop(x)
458
+
459
+ # we will return the [CLS] tokens from the `n` last blocks
460
+ output = []
461
+ for i, blk in enumerate(self.blocks):
462
+ x = blk(x)
463
+ if len(self.blocks) - i <= n:
464
+ # get only CLS token (B x dim)
465
+ output.append(self.norm(x)[:, 0])
466
+ if return_patch_avgpool:
467
+ x = self.norm(x)
468
+ # In addition to the [CLS] tokens from the `n` last blocks, we also return
469
+ # the patch tokens from the last block. This is useful for linear eval.
470
+ output.append(torch.mean(x[:, 1:], dim=1))
471
+ return torch.cat(output, dim=-1)
472
+
473
+ def return_patch_emb_from_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
474
+ """Return intermediate patch embeddings, rather than CLS token, from the last n blocks."""
475
+ B = x.shape[0]
476
+ x = self.patch_embed(x)
477
+
478
+ cls_tokens = self.cls_token.expand(B, -1, -1)
479
+
480
+ x = torch.cat((cls_tokens, x), dim=1)
481
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
482
+ x = x + pos_embed
483
+ x = self.pos_drop(x)
484
+
485
+ # we will return the [CLS] tokens from the `n` last blocks
486
+ output = []
487
+ for i, blk in enumerate(self.blocks):
488
+ x = blk(x)
489
+ if len(self.blocks) - i <= n:
490
+ output.append(self.norm(x)[:, 1:]) # get only patch tokens (B x n_patches x dim)
491
+
492
+ if return_patch_avgpool:
493
+ x = self.norm(x)
494
+ # In addition to the patch tokens from the `n` last blocks, we also return
495
+ # the patch tokens from the last block. This is useful for linear eval.
496
+ output.append(torch.mean(x[:, 1:], dim=1))
497
+ return torch.stack(output, dim=-1) # B x n_patches x dim x n
498
+
499
+
500
+ def deit_tiny(patch_size=16, **kwargs):
501
+ model = VisionTransformer(
502
+ patch_size=patch_size,
503
+ embed_dim=192,
504
+ depth=12,
505
+ num_heads=3,
506
+ mlp_ratio=4,
507
+ qkv_bias=True,
508
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
509
+ **kwargs)
510
+ return model
511
+
512
+
513
+ def deit_small(patch_size=16, **kwargs):
514
+ depth = kwargs.pop("depth") if "depth" in kwargs else 12
515
+ model = VisionTransformer(
516
+ patch_size=patch_size,
517
+ embed_dim=384,
518
+ depth=depth,
519
+ num_heads=6,
520
+ mlp_ratio=4,
521
+ qkv_bias=True,
522
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
523
+ **kwargs
524
+ )
525
+ return model
526
+
527
+
528
+ def vit_base(patch_size=16, **kwargs):
529
+ model = VisionTransformer(
530
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
531
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
532
+ return model
533
+
534
+
535
+ class DINOHead(nn.Module):
536
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
537
+ super().__init__()
538
+ nlayers = max(nlayers, 1)
539
+ if nlayers == 1:
540
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
541
+ else:
542
+ layers = [nn.Linear(in_dim, hidden_dim)]
543
+ if use_bn:
544
+ layers.append(nn.BatchNorm1d(hidden_dim))
545
+ layers.append(nn.GELU())
546
+ for _ in range(nlayers - 2):
547
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
548
+ if use_bn:
549
+ layers.append(nn.BatchNorm1d(hidden_dim))
550
+ layers.append(nn.GELU())
551
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
552
+ self.mlp = nn.Sequential(*layers)
553
+ self.apply(self._init_weights)
554
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
555
+ self.last_layer.weight_g.data.fill_(1)
556
+ if norm_last_layer:
557
+ self.last_layer.weight_g.requires_grad = False
558
+
559
+ def _init_weights(self, m):
560
+ if isinstance(m, nn.Linear):
561
+ trunc_normal_(m.weight, std=.02)
562
+ if isinstance(m, nn.Linear) and m.bias is not None:
563
+ nn.init.constant_(m.bias, 0)
564
+
565
+ def forward(self, x):
566
+ x = self.mlp(x)
567
+ x = nn.functional.normalize(x, dim=-1, p=2)
568
+ x = self.last_layer(x)
569
+ return x
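
Taken together, `networks/vision_transformer.py` defines a DINO-style ViT whose `forward` pads the input to a multiple of the patch size, interpolates the positional encoding to the resulting patch grid, and returns a dictionary of layer-normalised tokens (`"layer1"` … `"layerN"`), while `get_last_selfattention` exposes the attention map of the final block. Below is a minimal usage sketch of that interface; it assumes the module is importable as `networks.vision_transformer` (as in this repo) and uses a random tensor in place of a real image.

# Minimal usage sketch (assumption: this repo's `networks.vision_transformer` imports cleanly).
import torch
from networks.vision_transformer import deit_small

model = deit_small(patch_size=16)  # embed_dim=384, depth=12, num_heads=6; head is Identity (num_classes=0)
model.eval()

img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(img)                        # dict: "layer1" ... "layer12" -> 1 x 197 x 384 normalised tokens
    last = model(img, layer="layer12")        # tokens from the final block only
    attn = model.get_last_selfattention(img)  # 1 x 6 x 197 x 197 attention of the last block
print(last.shape, attn.shape)
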
resources/.DS_Store ADDED
Binary file (8.2 kB). View file
 
resources/0053.jpg ADDED
resources/0236.jpg ADDED
resources/0239.jpg ADDED
resources/0403.jpg ADDED
resources/0412.jpg ADDED
resources/ILSVRC2012_test_00005309.jpg ADDED
resources/ILSVRC2012_test_00012622.jpg ADDED
resources/ILSVRC2012_test_00022698.jpg ADDED
resources/ILSVRC2012_test_00040725.jpg ADDED
resources/ILSVRC2012_test_00075738.jpg ADDED
resources/ILSVRC2012_test_00080683.jpg ADDED
resources/ILSVRC2012_test_00085874.jpg ADDED
resources/im052.jpg ADDED
resources/sun_ainjbonxmervsvpv.jpg ADDED
resources/sun_alfntqzssslakmss.jpg ADDED
resources/sun_amnrcxhisjfrliwa.jpg ADDED
resources/sun_bvyxpvkouzlfwwod.jpg ADDED