diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8935f33351def1143b1450f471f567883a0c0ae9
Binary files /dev/null and b/.DS_Store differ
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..13566b81b018ad684f3a35fee301741b2734c8f4
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e482f5b5a58d40a37161eff92355e13ca618f4cb
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d1a23f47254192ad5c9f5cb8a9c4310c35ae07a3
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,23 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a2f7bd01a70dfab81da858710e1095b5f75ff03a
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c05e55baa78ca1129ab64def195fc2944b23f28c
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/selfmask_demo.iml b/.idea/selfmask_demo.iml
new file mode 100644
index 0000000000000000000000000000000000000000..d0876a78d06ac03b5d78c8dcdb95570281c6f1d6
--- /dev/null
+++ b/.idea/selfmask_demo.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sonarlint/issuestore/index.pb b/.idea/sonarlint/issuestore/index.pb
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.idea/webServers.xml b/.idea/webServers.xml
new file mode 100644
index 0000000000000000000000000000000000000000..faeee9bbfd00bc2fe119a6052475fc0d3e718fcb
--- /dev/null
+++ b/.idea/webServers.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/__pycache__/bilateral_solver.cpython-38.pyc b/__pycache__/bilateral_solver.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d421a9f6a31a3b4665b76dc92bd05c5139232f4f
Binary files /dev/null and b/__pycache__/bilateral_solver.cpython-38.pyc differ
diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c33b934e0f314cbd4faffbbdb6de7631c554b4d
Binary files /dev/null and b/__pycache__/utils.cpython-38.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ba09295e29046ede81f5ea268f3f251ed750e3
--- /dev/null
+++ b/app.py
@@ -0,0 +1,134 @@
+from argparse import ArgumentParser, Namespace
+from typing import Dict, List, Tuple
+import yaml
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, normalize, resize
+import gradio as gr
+from utils import get_model
+from bilateral_solver import bilateral_solver_output
+import os
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+state_dict: dict = torch.hub.load_state_dict_from_url(
+ "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
+ map_location=device # "cuda" if torch.cuda.is_available() else "cpu"
+)["model"]
+
+parser = ArgumentParser("SelfMask demo")
+parser.add_argument(
+ "--config",
+ type=str,
+ default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
+)
+
+# parser.add_argument(
+# "--p_state_dict",
+# type=str,
+# default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt",
+# )
+#
+# parser.add_argument(
+# "--dataset_name", '-dn', type=str, default="duts",
+# choices=["dut_omron", "duts", "ecssd"]
+# )
+
+# independent variables
+# parser.add_argument("--use_gpu", type=bool, default=True)
+# parser.add_argument('--seed', default=0, type=int)
+# parser.add_argument("--dir_root", type=str, default="..")
+# parser.add_argument("--gpu_id", type=int, default=2)
+# parser.add_argument("--suffix", type=str, default='')
+args: Namespace = parser.parse_args()
+base_args = yaml.safe_load(open(f"{args.config}", 'r'))
+base_args.pop("dataset_name")
+args: dict = vars(args)
+args.update(base_args)
+args: Namespace = Namespace(**args)
+
+model = get_model(arch="maskformer", configs=args).to(device)
+model.load_state_dict(state_dict)
+model.eval()
+
+
+@torch.no_grad()
+def main(
+ image: Image.Image,
+ size: int = 384,
+ max_size: int = 512,
+ mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
+ std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
+):
+ pil_image: Image.Image = resize(image, size=size, max_size=max_size)
+ image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std)) # 3 x H x W
+ dict_outputs = model(image[None].to(device))
+
+ batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"] # [0, 1]
+ batch_objectness: torch.Tensor = dict_outputs.get("objectness", None) # [0, 1]
+
+ if len(batch_pred_masks.shape) == 5:
+ # b x n_layers x n_queries x h x w -> b x n_queries x h x w
+ batch_pred_masks = batch_pred_masks[:, -1, ...] # extract the output from the last decoder layer
+
+ if batch_objectness is not None:
+ # b x n_layers x n_queries x 1 -> b x n_queries x 1
+ batch_objectness = batch_objectness[:, -1, ...]
+
+ # resize predictions back to the (resized) input image resolution
+ # note: upsampling by 4 and cutting the padded region allows for a better result
+ H, W = image.shape[-2:]
+ batch_pred_masks = F.interpolate(
+ batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
+ )[..., :H, :W]
+
+ # iterate over batch dimension
+ for batch_index, pred_masks in enumerate(batch_pred_masks):
+ # n_queries x 1 -> n_queries
+ objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
+ ranks = torch.argsort(objectness, descending=True) # n_queries
+ pred_mask: torch.Tensor = pred_masks[ranks[0]] # H x W
+ pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255
+
+ pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask) # float64
+ pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)
+
+ attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
+ super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
+ return super_imposed_img
+ # return pred_mask_bi
+
+demo = gr.Interface(
+ fn=main,
+ inputs=gr.inputs.Image(type="pil"),
+ outputs="image",
+ examples=[f"resources/{fname}.jpg" for fname in [
+ "0053",
+ "0236",
+ "0239",
+ "0403",
+ "0412",
+ "ILSVRC2012_test_00005309",
+ "ILSVRC2012_test_00012622",
+ "ILSVRC2012_test_00022698",
+ "ILSVRC2012_test_00040725",
+ "ILSVRC2012_test_00075738",
+ "ILSVRC2012_test_00080683",
+ "ILSVRC2012_test_00085874",
+ "im052",
+ "sun_ainjbonxmervsvpv",
+ "sun_alfntqzssslakmss",
+ "sun_amnrcxhisjfrliwa",
+ "sun_bvyxpvkouzlfwwod"
+ ]],
+ title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
+ allow_flagging="never",
+ analytics_enabled=False
+)
+
+demo.launch(
+ # share=True
+)
\ No newline at end of file
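
The inference path in `app.py` above: the input is resized (shorter side 384, capped at 512), normalised with ImageNet statistics, passed through the model, and the mask belonging to the query with the highest predicted objectness is kept, refined with the bilateral solver and overlaid on the image. Below is a minimal sketch of just the mask-selection step on dummy tensors; all shapes and values are illustrative, not taken from a real checkpoint.

```python
import torch

# stand-ins for dict_outputs["mask_pred"] / ["objectness"] after taking the last decoder layer
n_queries, h, w = 20, 96, 96
pred_masks = torch.rand(n_queries, h, w)   # per-query soft masks in [0, 1]
objectness = torch.rand(n_queries, 1)      # per-query objectness in [0, 1]

# same logic as the loop in main(): rank queries by objectness, keep the best mask, binarise it
ranks = torch.argsort(objectness.squeeze(dim=-1), descending=True)
best_mask = (pred_masks[ranks[0]] > 0.5).to(torch.uint8) * 255
print(best_mask.shape)  # torch.Size([96, 96])
```
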
diff --git a/bilateral_solver.py b/bilateral_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e75876397756d7ae364b746a15191423c402dd4
--- /dev/null
+++ b/bilateral_solver.py
@@ -0,0 +1,206 @@
+import numpy as np
+from PIL import Image
+from scipy import ndimage
+from scipy.sparse import csr_matrix, diags
+from scipy.sparse.linalg import cg
+
+
+RGB_TO_YUV = np.array([
+ [0.299, 0.587, 0.114],
+ [-0.168736, -0.331264, 0.5],
+ [0.5, -0.418688, -0.081312]])
+YUV_TO_RGB = np.array([
+ [1.0, 0.0, 1.402],
+ [1.0, -0.34414, -0.71414],
+ [1.0, 1.772, 0.0]])
+YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
+MAX_VAL = 255.0
+
+
+def rgb2yuv(im):
+ return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET
+
+
+def yuv2rgb(im):
+ return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))
+
+
+def get_valid_idx(valid, candidates):
+ """Find which values are present in a list and where they are located"""
+ locs = np.searchsorted(valid, candidates)
+ # Handle edge case where the candidate is larger than all valid values
+ locs = np.clip(locs, 0, len(valid) - 1)
+ # Identify which values are actually present
+ valid_idx = np.flatnonzero(valid[locs] == candidates)
+ locs = locs[valid_idx]
+ return valid_idx, locs
+
+
+class BilateralGrid(object):
+ def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
+ im_yuv = rgb2yuv(im)
+ # Compute 5-dimensional XYLUV bilateral-space coordinates
+ Iy, Ix = np.mgrid[:im.shape[0], :im.shape[1]]
+ x_coords = (Ix / sigma_spatial).astype(int)
+ y_coords = (Iy / sigma_spatial).astype(int)
+ luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
+ chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
+ coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
+ coords_flat = coords.reshape(-1, coords.shape[-1])
+ self.npixels, self.dim = coords_flat.shape
+ # Hacky "hash vector" for coordinates,
+ # Requires all scaled coordinates be < MAX_VAL
+ self.hash_vec = (MAX_VAL ** np.arange(self.dim))
+ # Construct S and B matrix
+ self._compute_factorization(coords_flat)
+
+ def _compute_factorization(self, coords_flat):
+ # Hash each coordinate in grid to a unique value
+ hashed_coords = self._hash_coords(coords_flat)
+ unique_hashes, unique_idx, idx = \
+ np.unique(hashed_coords, return_index=True, return_inverse=True)
+ # Identify unique set of vertices
+ unique_coords = coords_flat[unique_idx]
+ self.nvertices = len(unique_coords)
+ # Construct sparse splat matrix that maps from pixels to vertices
+ self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
+ # Construct sparse blur matrices.
+ # Note that these represent [1 0 1] blurs, excluding the central element
+ self.blurs = []
+ for d in range(self.dim):
+ blur = 0.0
+ for offset in (-1, 1):
+ offset_vec = np.zeros((1, self.dim))
+ offset_vec[:, d] = offset
+ neighbor_hash = self._hash_coords(unique_coords + offset_vec)
+ valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
+ blur = blur + csr_matrix((np.ones((len(valid_coord),)),
+ (valid_coord, idx)),
+ shape=(self.nvertices, self.nvertices))
+ self.blurs.append(blur)
+
+ def _hash_coords(self, coord):
+ """Hacky function to turn a coordinate into a unique value"""
+ return np.dot(coord.reshape(-1, self.dim), self.hash_vec)
+
+ def splat(self, x):
+ return self.S.dot(x)
+
+ def slice(self, y):
+ return self.S.T.dot(y)
+
+ def blur(self, x):
+ """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
+ assert x.shape[0] == self.nvertices
+ out = 2 * self.dim * x
+ for blur in self.blurs:
+ out = out + blur.dot(x)
+ return out
+
+ def filter(self, x):
+ """Apply bilateral filter to an input x"""
+ return self.slice(self.blur(self.splat(x))) / \
+ self.slice(self.blur(self.splat(np.ones_like(x))))
+
+
+def bistochastize(grid, maxiter=10):
+ """Compute diagonal matrices to bistochastize a bilateral grid"""
+ m = grid.splat(np.ones(grid.npixels))
+ n = np.ones(grid.nvertices)
+ for i in range(maxiter):
+ n = np.sqrt(n * m / grid.blur(n))
+ # Correct m to satisfy the assumption of bistochastization regardless
+ # of how many iterations have been run.
+ m = n * grid.blur(n)
+ Dm = diags(m, 0)
+ Dn = diags(n, 0)
+ return Dn, Dm
+
+
+class BilateralSolver(object):
+ def __init__(self, grid, params):
+ self.grid = grid
+ self.params = params
+ self.Dn, self.Dm = bistochastize(grid)
+
+ def solve(self, x, w):
+ # Check that w is a vector or a nx1 matrix
+ if w.ndim == 2:
+ assert (w.shape[1] == 1)
+ elif w.ndim == 1:
+ w = w.reshape(w.shape[0], 1)
+ A_smooth = (self.Dm - self.Dn.dot(self.grid.blur(self.Dn)))
+ w_splat = self.grid.splat(w)
+ A_data = diags(w_splat[:, 0], 0)
+ A = self.params["lam"] * A_smooth + A_data
+ xw = x * w
+ b = self.grid.splat(xw)
+ # Use simple Jacobi preconditioner
+ A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
+ M = diags(1 / A_diag, 0)
+ # Flat initialization
+ y0 = self.grid.splat(xw) / w_splat
+ yhat = np.empty_like(y0)
+ for d in range(x.shape[-1]):
+ yhat[..., d], info = cg(A, b[..., d], x0=y0[..., d], M=M, maxiter=self.params["cg_maxiter"],
+ tol=self.params["cg_tol"])
+ xhat = self.grid.slice(yhat)
+ return xhat
+
+
+def bilateral_solver_output(
+ img: Image.Image,
+ target: np.ndarray,
+ sigma_spatial=16,
+ sigma_luma=16,
+ sigma_chroma=8
+):
+ reference = np.array(img)
+ h, w = target.shape
+ confidence = np.ones((h, w)) * 0.999
+
+ grid_params = {
+ 'sigma_luma': sigma_luma, # Brightness bandwidth
+ 'sigma_chroma': sigma_chroma, # Color bandwidth
+ 'sigma_spatial': sigma_spatial # Spatial bandwidth
+ }
+
+ bs_params = {
+ 'lam': 256, # The strength of the smoothness parameter
+ 'A_diag_min': 1e-5, # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
+ 'cg_tol': 1e-5, # The tolerance on the convergence in PCG
+ 'cg_maxiter': 25 # The number of PCG iterations
+ }
+
+ grid = BilateralGrid(reference, **grid_params)
+
+ t = target.reshape(-1, 1).astype(np.double)
+ c = confidence.reshape(-1, 1).astype(np.double)
+
+ ## output solver, which is a soft value
+ output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
+
+ binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
+ labeled, nr_objects = ndimage.label(binary_solver)
+
+ nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
+ pixel_order = np.argsort(nb_pixel)
+ try:
+ # keep the largest foreground component (the background label typically has the most pixels)
+ binary_solver = labeled == pixel_order[-2]
+ except IndexError:
+ # no foreground component was found; fall back to an all-foreground mask
+ binary_solver = np.ones((h, w), dtype=bool)
+
+ return output_solver, binary_solver
\ No newline at end of file
diff --git a/duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml b/duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0cebb9c24cdb795291eb4df27ef45e1071d5aa9b
--- /dev/null
+++ b/duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml
@@ -0,0 +1,56 @@
+# augmentations
+use_copy_paste: false
+scale_range: [ 0.1, 1.0 ]
+repeat_image: false
+
+# base directories
+dir_ckpt: "/users/gyungin/selfmask/ckpt" # "/work/gyungin/selfmask/ckpt"
+dir_dataset: "/scratch/shared/beegfs/gyungin/datasets"
+
+# clustering
+k: [2, 3, 4]
+clustering_mode: "spectral"
+use_gpu: true # if you want to use gpu-accelerated code for clustering
+scale_factor: 2 # "how much you want to upsample encoder features before clustering"
+
+# dataset
+dataset_name: "duts"
+use_pseudo_masks: true
+train_image_size: 224
+eval_image_size: 224
+n_percent: 100
+n_copy_pastes: null
+pseudo_masks_fp: "/users/gyungin/selfmask/datasets/swav_mocov2_dino_p16_k234.json"
+
+# dataloader:
+batch_size: 8
+num_workers: 4
+pin_memory: true
+
+# networks
+abs_2d_pe_init: false
+arch: "vit_small"
+lateral_connection: false
+learnable_pixel_decoder: false # if false, upsample with bilinear interpolation instead of a learned decoder
+use_binary_classifier: true # if true, use a binary classifier to predict an objectness score for each query from the transformer decoder
+n_decoder_layers: 6
+n_queries: 20
+num_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+patch_size: 8
+training_method: "dino" # "supervised", "deit", "dino", "mocov2", "swav"
+
+# objective
+loss_every_decoder_layer: true
+weight_dice_loss: 1.0
+weight_focal_loss: 0.0
+
+# optimizer
+lr: 0.000006 # default: 0.00006
+lr_warmup_duration: 0 # 5
+momentum: 0.9
+n_epochs: 12
+weight_decay: 0.01
+optimizer_type: "adamw"
+
+# validation
+benchmarks: null
\ No newline at end of file
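
This YAML is consumed by `app.py`: the file is read with `yaml.safe_load`, `dataset_name` is dropped, and the remaining keys are merged into the argparse `Namespace`, so for instance `n_queries: 20` and `patch_size: 8` become `args.n_queries` / `args.patch_size` when `get_model` builds the MaskFormer. A small sketch of that merge, following the same pattern as `app.py`:

```python
from argparse import Namespace
import yaml

args = Namespace(config="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml")
with open(args.config) as f:
    base_args = yaml.safe_load(f)   # plain dict of the keys above
base_args.pop("dataset_name")       # app.py removes this key before merging
merged = vars(args)
merged.update(base_args)
args = Namespace(**merged)
print(args.n_queries, args.patch_size)  # 20 8
```
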
diff --git a/networks/__init__.py b/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/networks/__pycache__/__init__.cpython-38.pyc b/networks/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb0f4d3d9de3c2020e8d0d03610ce22dc18186aa
Binary files /dev/null and b/networks/__pycache__/__init__.cpython-38.pyc differ
diff --git a/networks/__pycache__/timm_deit.cpython-38.pyc b/networks/__pycache__/timm_deit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28b6bd1424fc3768496f2bd9185d57dce2116eb3
Binary files /dev/null and b/networks/__pycache__/timm_deit.cpython-38.pyc differ
diff --git a/networks/__pycache__/timm_vit.cpython-38.pyc b/networks/__pycache__/timm_vit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b494e2e88b94c3fd4b43fe416047bf1420239a44
Binary files /dev/null and b/networks/__pycache__/timm_vit.cpython-38.pyc differ
diff --git a/networks/__pycache__/vision_transformer.cpython-38.pyc b/networks/__pycache__/vision_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f28b4f5bc8b1c2f5b50989873ddc04317635dd9e
Binary files /dev/null and b/networks/__pycache__/vision_transformer.cpython-38.pyc differ
diff --git a/networks/maskformer/__pycache__/maskformer.cpython-38.pyc b/networks/maskformer/__pycache__/maskformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d86e6997bd4f473f54befe77d4394cba393612bf
Binary files /dev/null and b/networks/maskformer/__pycache__/maskformer.cpython-38.pyc differ
diff --git a/networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc b/networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f94f2f482506505c296f3fa23af62aadb4c1f93
Binary files /dev/null and b/networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc differ
diff --git a/networks/maskformer/maskformer.py b/networks/maskformer/maskformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..003896f1f4c8d58c361dea9699625d1a29fd4782
--- /dev/null
+++ b/networks/maskformer/maskformer.py
@@ -0,0 +1,267 @@
+from typing import Dict, List
+from math import sqrt, log
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
+from utils import get_model
+
+
+class MaskFormer(nn.Module):
+ def __init__(
+ self,
+ n_queries: int = 100,
+ arch: str = "vit_small",
+ patch_size: int = 8,
+ training_method: str = "dino",
+ n_decoder_layers: int = 6,
+ normalize_before: bool = False,
+ return_intermediate: bool = False,
+ learnable_pixel_decoder: bool = False,
+ lateral_connection: bool = False,
+ scale_factor: int = 2,
+ abs_2d_pe_init: bool = False,
+ use_binary_classifier: bool = False
+ ):
+ """Define a encoder and decoder along with queries to be learned through the decoder."""
+ super(MaskFormer, self).__init__()
+
+ if arch == "vit_small":
+ self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
+ n_dims: int = self.encoder.n_embs
+ n_heads: int = self.encoder.n_heads
+ mlp_ratio: int = self.encoder.mlp_ratio
+ else:
+ self.encoder = get_model(arch=arch, training_method=training_method)
+ n_dims_resnet: int = self.encoder.n_embs
+ n_dims: int = 384
+ n_heads: int = 6
+ mlp_ratio: int = 4
+ self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)
+
+ decoder_layer = TransformerDecoderLayer(
+ n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
+ )
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ n_decoder_layers,
+ norm=nn.LayerNorm(n_dims),
+ return_intermediate=return_intermediate
+ )
+
+ self.query_embed = nn.Embedding(n_queries, n_dims).weight # initialized with gaussian(0, 1)
+
+ if use_binary_classifier:
+ # self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
+ # self.linear_classifier = nn.Linear(n_dims, 1)
+ self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
+ # self.norm = nn.LayerNorm(n_dims)
+ else:
+ # self.ffn = None
+ # self.linear_classifier = None
+ # self.norm = None
+ self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
+ self.linear_classifier = nn.Linear(n_dims, 2)
+ self.norm = nn.LayerNorm(n_dims)
+
+ self.arch = arch
+ self.use_binary_classifier = use_binary_classifier
+ self.lateral_connection = lateral_connection
+ self.learnable_pixel_decoder = learnable_pixel_decoder
+ self.scale_factor = scale_factor
+
+ # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
+ @staticmethod
+ def positional_encoding_2d(n_dims: int, height: int, width: int):
+ """
+ :param n_dims: dimension of the model
+ :param height: height of the positions
+ :param width: width of the positions
+ :return: d_model*height*width position matrix
+ """
+ if n_dims % 4 != 0:
+ raise ValueError("Cannot use sin/cos positional encoding with "
+ "odd dimension (got dim={:d})".format(n_dims))
+ pe = torch.zeros(n_dims, height, width)
+ # Each dimension use half of d_model
+ d_model = int(n_dims / 2)
+ div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
+ pos_w = torch.arange(0., width).unsqueeze(1)
+ pos_h = torch.arange(0., height).unsqueeze(1)
+ pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+ pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+ pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+ pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+
+ return pe
+
+ def forward_encoder(self, x: torch.Tensor):
+ """
+ :param x: b x c x h x w
+ :return patch_tokens: b x depth x n_dims x hw (ViT) or b x n_dims x h x w (ResNet)
+ """
+ if self.arch == "vit_small":
+ encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x) # [:, 1:, :]
+ all_patch_tokens: List[torch.Tensor] = list()
+ for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
+ patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :] # b x hw x n_dims
+ all_patch_tokens.append(patch_tokens)
+
+ all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0) # depth x b x hw x n_dims
+ all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2) # b x depth x n_dims x hw
+ return all_patch_tokens
+ else:
+ encoder_outputs = self.linear_layer(self.encoder(x)[-1]) # b x n_dims x h x w
+ return encoder_outputs
+
+ def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
+ """Forward transformer decoder given patch tokens from the encoder's last layer.
+ :param patch_tokens: b x n_dims x hw -> hw x b x n_dims
+ :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
+ between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
+ experiment.
+ :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+ """
+ b = patch_tokens.shape[0]
+ patch_tokens = patch_tokens.permute(2, 0, 1) # b x n_dims x hw -> hw x b x n_dims
+
+ # n_queries x n_dims -> n_queries x b x n_dims
+ queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)
+ queries: torch.Tensor = self.decoder.forward(
+ tgt=torch.zeros_like(queries),
+ memory=patch_tokens,
+ query_pos=queries
+ ).squeeze(dim=0)
+
+ if len(queries.shape) == 3:
+ queries: torch.Tensor = queries.permute(1, 0, 2) # n_queries x b x n_dims -> b x n_queries x n_dims
+ elif len(queries.shape) == 4:
+ # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
+ queries: torch.Tensor = queries.permute(2, 0, 1, 3)
+ return queries
+
+ def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
+ """ Upsample patch tokens by self.scale_factor and produce mask predictions
+ :param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w
+ :param queries: b x n_queries x n_dims
+ :return mask_predictions: b x n_queries x h x w
+ """
+
+ if input_size is None:
+ # assume square shape features
+ hw = patch_tokens.shape[-1]
+ h = w = int(sqrt(hw))
+ else:
+ # arbitrary shape features
+ h, w = input_size
+ patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)
+
+ assert len(patch_tokens.shape) == 4
+ patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
+ return patch_tokens
+
+ def forward(self, x, encoder_only=False, skip_decoder: bool = False):
+ """
+ x: b x c x h x w
+ patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims
+ query_emb: n_queries x n_dims -> n_queries x b x n_dims
+ """
+ dict_outputs: dict = dict()
+
+ # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
+ features: torch.Tensor = self.forward_encoder(x)
+
+ if self.arch == "vit_small":
+ # extract the last layer for decoder input
+ last_layer_features: torch.Tensor = features[:, -1, ...] # b x n_dims x hw
+ else:
+ # transform the shape of the features to the one compatible with transformer decoder
+ b, n_dims, h, w = features.shape
+ last_layer_features: torch.Tensor = features.view(b, n_dims, h * w) # b x n_dims x hw
+
+ if encoder_only:
+ _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
+ _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
+
+ b, n_dims, hw = last_layer_features.shape
+ dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)})
+ return dict_outputs
+
+ # transformer decoder forward
+ queries: torch.Tensor = self.forward_transformer_decoder(
+ last_layer_features,
+ skip_decoder=skip_decoder
+ ) # b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+
+ # pixel decoder forward (upsampling the patch tokens by self.scale_factor)
+ if self.arch == "vit_small":
+ _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
+ _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
+ else:
+ _h, _w = h, w
+ features: torch.Tensor = self.forward_pixel_decoder(
+ patch_tokens=features if self.lateral_connection else last_layer_features,
+ input_size=(_h, _w)
+ ) # b x n_dims x h x w
+
+ # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
+ # features: b x n_dims x h x w
+ # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
+ if len(queries.shape) == 3:
+ mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
+ else:
+ if self.use_binary_classifier:
+ mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
+ else:
+ mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))
+
+ if self.use_binary_classifier:
+ # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
+ queries = queries.permute(1, 0, 2, 3)
+ objectness: List[torch.Tensor] = list()
+ for n_layer, queries_per_layer in enumerate(queries): # queries_per_layer: b x n_queries x n_dims
+ # objectness_per_layer = self.linear_classifier(
+ # self.ffn(self.norm(queries_per_layer))
+ # ) # b x n_queries x 1
+ objectness_per_layer = self.ffn(queries_per_layer) # b x n_queries x 1
+ objectness.append(objectness_per_layer)
+ # n_layers x b x n_queries x 1 -> # b x n_layers x n_queries x 1
+ objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
+ dict_outputs.update({
+ "objectness": torch.sigmoid(objectness),
+ "mask_pred": mask_pred
+ })
+
+ return dict_outputs
+
+
+class MLP(nn.Module):
+ """Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class UpsampleBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
+ super(UpsampleBlock, self).__init__()
+ self.block = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
+ nn.GroupNorm(n_groups, out_channels),
+ nn.ReLU()
+ )
+ self.scale_factor = scale_factor
+
+ def forward(self, x):
+ return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")
\ No newline at end of file
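
The core of `MaskFormer.forward` is the einsum that dots each query embedding with every pixel of the (upsampled) feature map, producing one mask per query and per decoder layer. A shape-only sketch with dummy tensors, using dimensions matching the defaults in this commit (20 queries, 6 decoder layers, 384-dim ViT-S features):

```python
import torch

b, n_layers, n_queries, n_dims, h, w = 2, 6, 20, 384, 28, 28
queries = torch.randn(b, n_layers, n_queries, n_dims)   # output of forward_transformer_decoder
features = torch.randn(b, n_dims, h, w)                 # output of forward_pixel_decoder

mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
print(mask_pred.shape)  # torch.Size([2, 6, 20, 28, 28]) == b x n_layers x n_queries x h x w
```
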
diff --git a/networks/maskformer/positional_embedding.py b/networks/maskformer/positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d117e05e5a133424c9f9f7b883da129b9c4d45
--- /dev/null
+++ b/networks/maskformer/positional_embedding.py
@@ -0,0 +1,48 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
\ No newline at end of file
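
`PositionEmbeddingSine` is the DETR-style 2D generalisation of the sinusoidal encoding: cumulative-sum coordinates along y and x are each expanded into `num_pos_feats` sin/cos channels and concatenated. A hedged usage sketch on a dummy feature map; it assumes the file is importable under this package layout.

```python
import torch
from networks.maskformer.positional_embedding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feats = torch.randn(2, 256, 32, 32)   # b x c x h x w; only the spatial dims and device matter
pos = pe(feats)                       # per-pixel sin/cos codes
print(pos.shape)                      # torch.Size([2, 256, 32, 32]) == b x 2*num_pos_feats x h x w
```
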
diff --git a/networks/maskformer/transformer_decoder.py b/networks/maskformer/transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6111be730f7f063281fab0199f6dd413ba50e9ba
--- /dev/null
+++ b/networks/maskformer/transformer_decoder.py
@@ -0,0 +1,376 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+"""
+Transformer class.
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu", # noel - dino used GeLU
+ normalize_before=False,
+ return_intermediate_dec=False,
+ ):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ )
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, query_embed, pos_embed):
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ if mask is not None:
+ mask = mask.flatten(1)
+
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ hs = self.decoder(
+ tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+ )
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ output = src
+
+ for layer in self.layers:
+ output = layer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers: nn.ModuleList = _get_clones(decoder_layer, num_layers)
+ self.num_layers: int = num_layers
+ self.norm = norm
+ self.return_intermediate: bool = return_intermediate
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(tgt, query_pos)
+
+ tgt2 = self.self_attn(
+ q,
+ k,
+ value=tgt,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+
+ return tgt
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+ return self.forward_post(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
\ No newline at end of file
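
The decoder above follows DETR's conventions: tensors are sequence-first (`n_queries x b x d` for the queries, `hw x b x d` for the encoder memory), and with `return_intermediate=True` it returns one output per layer, which lets the training code attach a loss to every decoder layer (`loss_every_decoder_layer` in the YAML). A hedged shape sketch with illustrative values; the import assumes this package layout.

```python
import torch
import torch.nn as nn
from networks.maskformer.transformer_decoder import TransformerDecoder, TransformerDecoderLayer

d_model, n_heads, n_layers, n_queries, hw, b = 384, 6, 6, 20, 784, 2
layer = TransformerDecoderLayer(d_model, n_heads, dim_feedforward=4 * d_model)
decoder = TransformerDecoder(layer, n_layers, norm=nn.LayerNorm(d_model), return_intermediate=True)

tgt = torch.zeros(n_queries, b, d_model)        # decoder input, zeros as in MaskFormer
memory = torch.randn(hw, b, d_model)            # patch tokens from the encoder
query_pos = torch.randn(n_queries, b, d_model)  # learned query embeddings
out = decoder(tgt, memory, query_pos=query_pos)
print(out.shape)  # torch.Size([6, 20, 2, 384]) == n_layers x n_queries x b x d_model
```
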
diff --git a/networks/module_helper.py b/networks/module_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9a55e55b71f1d4f6951c625d90367a77d8f6d9
--- /dev/null
+++ b/networks/module_helper.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Author: Donny You (youansheng@gmail.com)
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ from urllib import urlretrieve
+except ImportError:
+ from urllib.request import urlretrieve
+
+
+class FixedBatchNorm(nn.BatchNorm2d):
+ def forward(self, input):
+ return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)
+
+
+class ModuleHelper(object):
+ @staticmethod
+ def BNReLU(num_features, norm_type=None, **kwargs):
+ if norm_type == 'batchnorm':
+ return nn.Sequential(
+ nn.BatchNorm2d(num_features, **kwargs),
+ nn.ReLU()
+ )
+ elif norm_type == 'encsync_batchnorm':
+ from encoding.nn import BatchNorm2d
+ return nn.Sequential(
+ BatchNorm2d(num_features, **kwargs),
+ nn.ReLU()
+ )
+ elif norm_type == 'instancenorm':
+ return nn.Sequential(
+ nn.InstanceNorm2d(num_features, **kwargs),
+ nn.ReLU()
+ )
+ elif norm_type == 'fixed_batchnorm':
+ return nn.Sequential(
+ FixedBatchNorm(num_features, **kwargs),
+ nn.ReLU()
+ )
+ else:
+ raise ValueError('Unsupported BN type: {}.'.format(norm_type))
+
+ @staticmethod
+ def BatchNorm3d(norm_type=None, ret_cls=False):
+ if norm_type == 'batchnorm':
+ return nn.BatchNorm3d
+ elif norm_type == 'encsync_batchnorm':
+ from encoding.nn import BatchNorm3d
+ return BatchNorm3d
+ elif norm_type == 'instancenorm':
+ return nn.InstanceNorm3d
+ else:
+ raise ValueError('Unsupported BN type: {}.'.format(norm_type))
+
+ @staticmethod
+ def BatchNorm2d(norm_type=None, ret_cls=False):
+ if norm_type == 'batchnorm':
+ return nn.BatchNorm2d
+ elif norm_type == 'encsync_batchnorm':
+ from encoding.nn import BatchNorm2d
+ return BatchNorm2d
+
+ elif norm_type == 'instancenorm':
+ return nn.InstanceNorm2d
+ else:
+ raise ValueError('Unsupported BN type: {}.'.format(norm_type))
+
+ @staticmethod
+ def BatchNorm1d(norm_type=None, ret_cls=False):
+ if norm_type == 'batchnorm':
+ return nn.BatchNorm1d
+ elif norm_type == 'encsync_batchnorm':
+ from encoding.nn import BatchNorm1d
+ return BatchNorm1d
+ elif norm_type == 'instancenorm':
+ return nn.InstanceNorm1d
+ else:
+ raise ValueError('Unsupported BN type: {}.'.format(norm_type))
+
+ @staticmethod
+ def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
+ if pretrained is None:
+ return model
+
+ if not os.path.exists(pretrained):
+ pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
+ if os.path.exists(pretrained):
+ pass
+ else:
+ raise FileNotFoundError('{} not exists.'.format(pretrained))
+
+ print('Loading pretrained model:{}'.format(pretrained))
+ if all_match:
+ pretrained_dict = torch.load(pretrained, map_location=map_location)
+ model_dict = model.state_dict()
+ load_dict = dict()
+ for k, v in pretrained_dict.items():
+ if 'prefix.{}'.format(k) in model_dict:
+ load_dict['prefix.{}'.format(k)] = v
+ else:
+ load_dict[k] = v
+ model.load_state_dict(load_dict)
+
+ else:
+ pretrained_dict = torch.load(pretrained)
+ model_dict = model.state_dict()
+ load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ print('Matched Keys: {}'.format(load_dict.keys()))
+ model_dict.update(load_dict)
+ model.load_state_dict(model_dict)
+
+ return model
+
+ @staticmethod
+ def load_url(url, map_location=None):
+ model_dir = os.path.join('~', '.TorchCV', 'model')
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+
+ filename = url.split('/')[-1]
+ cached_file = os.path.join(model_dir, filename)
+ if not os.path.exists(cached_file):
+ print('Downloading: "{}" to {}\n'.format(url, cached_file))
+ urlretrieve(url, cached_file)
+
+ print('Loading pretrained model:{}'.format(cached_file))
+ return torch.load(cached_file, map_location=map_location)
+
+ @staticmethod
+ def constant_init(module, val, bias=0):
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+ @staticmethod
+ def xavier_init(module, gain=1, bias=0, distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+ @staticmethod
+ def normal_init(module, mean=0, std=1, bias=0):
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+ @staticmethod
+ def uniform_init(module, a=0, b=1, bias=0):
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+ @staticmethod
+ def kaiming_init(module,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=0,
+ distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
diff --git a/networks/resnet.py b/networks/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b8e791d2248c0affdacc8f64af03094a8c9137
--- /dev/null
+++ b/networks/resnet.py
@@ -0,0 +1,60 @@
+import os
+
+import torch
+import torch.nn as nn
+from .resnet_backbone import ResNetBackbone
+
+
+class ResNet50(nn.Module):
+ def __init__(
+ self,
+ weight_type: str = "supervised",
+ use_dilated_resnet: bool = True
+ ):
+ super(ResNet50, self).__init__()
+ self.network = ResNetBackbone(backbone=f"resnet50{'_dilated8' if use_dilated_resnet else ''}", pretrained=None)
+ self.n_embs = self.network.num_features
+ self.use_dilated_resnet = use_dilated_resnet
+ self._load_pretrained(weight_type)
+
+ def _load_pretrained(self, training_method: str) -> None:
+ curr_state_dict = self.network.state_dict()
+ if training_method == "mocov2":
+ state_dict = torch.load("/users/gyungin/sos/networks/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]
+
+ for k in list(state_dict.keys()):
+ if any([k.find(w) != -1 for w in ("fc.0", "fc.2")]):
+ state_dict.pop(k)
+
+ elif training_method == "swav":
+ state_dict = torch.load("/users/gyungin/sos/networks/pretrained/swav_800ep_pretrain.pth.tar")
+ for k in list(state_dict.keys()):
+ if any([k.find(w) != -1 for w in ("projection_head", "prototypes")]):
+ state_dict.pop(k)
+
+ elif training_method == "supervised":
+ # Note - pytorch resnet50 model doesn't have num_batches_tracked layers. Need to know why.
+ # for k in list(curr_state_dict.keys()):
+ # if k.find("num_batches_tracked") != -1:
+ # curr_state_dict.pop(k)
+ # state_dict = torch.load("../networks/pretrained/resnet50-pytorch.pth")
+
+ from torchvision.models.resnet import resnet50
+ resnet50_supervised = resnet50(pretrained=True, progress=True)
+ state_dict = resnet50_supervised.state_dict()
+ for k in list(state_dict.keys()):
+ if any([k.find(w) != -1 for w in ("fc.weight", "fc.bias")]):
+ state_dict.pop(k)
+
+ assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
+ for k_curr, k in zip(curr_state_dict.keys(), state_dict.keys()):
+ curr_state_dict[k_curr].copy_(state_dict[k])
+ print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} intialised with {training_method} weights is loaded.")
+ return
+
+ def forward(self, x):
+ return self.network(x)
+
+
+if __name__ == '__main__':
+ resnet = ResNet50("mocov2")
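
`ResNet50._load_pretrained` above strips the projection/classification head from the downloaded checkpoint and then copies weights positionally, zipping the two state dicts rather than matching key names (which is why the length assert matters). A tiny sketch of that copy pattern on made-up modules:

```python
import torch.nn as nn

src = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))  # stands in for the pretrained backbone
dst = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))  # freshly initialised target

src_sd, dst_sd = src.state_dict(), dst.state_dict()
assert len(dst_sd) == len(src_sd)                 # positional copy only works if the layouts line up
for k_dst, k_src in zip(dst_sd.keys(), src_sd.keys()):
    dst_sd[k_dst].copy_(src_sd[k_src])            # in-place copy into the target parameters/buffers
```
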
diff --git a/networks/resnet_backbone.py b/networks/resnet_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..498ae808f02f60f109eb2c2302e2783360eb6db4
--- /dev/null
+++ b/networks/resnet_backbone.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Author: Donny You(youansheng@gmail.com)
+
+
+import torch.nn as nn
+from networks.resnet_models import *
+
+
+class NormalResnetBackbone(nn.Module):
+ def __init__(self, orig_resnet):
+ super(NormalResnetBackbone, self).__init__()
+
+ self.num_features = 2048
+ # take pretrained resnet, except AvgPool and FC
+ self.prefix = orig_resnet.prefix
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def get_num_features(self):
+ return self.num_features
+
+ def forward(self, x):
+ tuple_features = list()
+ x = self.prefix(x)
+ x = self.maxpool(x)
+ x = self.layer1(x)
+ tuple_features.append(x)
+ x = self.layer2(x)
+ tuple_features.append(x)
+ x = self.layer3(x)
+ tuple_features.append(x)
+ x = self.layer4(x)
+ tuple_features.append(x)
+
+ return tuple_features
+
+
+class DilatedResnetBackbone(nn.Module):
+ def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
+ super(DilatedResnetBackbone, self).__init__()
+
+ self.num_features = 2048
+ from functools import partial
+
+ if dilate_scale == 8:
+ orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
+ if multi_grid is None:
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
+ else:
+ for i, r in enumerate(multi_grid):
+ orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r)))
+
+ elif dilate_scale == 16:
+ if multi_grid is None:
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
+ else:
+ for i, r in enumerate(multi_grid):
+ orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r)))
+
+ # Take pretrained resnet, except AvgPool and FC
+ self.prefix = orig_resnet.prefix
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def _nostride_dilate(self, m, dilate):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ # the convolution with stride
+ if m.stride == (2, 2):
+ m.stride = (1, 1)
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate // 2, dilate // 2)
+ m.padding = (dilate // 2, dilate // 2)
+ # other convolutions
+ else:
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate, dilate)
+ m.padding = (dilate, dilate)
+
+ def get_num_features(self):
+ return self.num_features
+
+ def forward(self, x):
+ tuple_features = list()
+
+ x = self.prefix(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ tuple_features.append(x)
+ x = self.layer2(x)
+ tuple_features.append(x)
+ x = self.layer3(x)
+ tuple_features.append(x)
+ x = self.layer4(x)
+ tuple_features.append(x)
+
+ return tuple_features
+
+
+def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
+ arch = backbone
+
+ if arch == 'resnet18':
+ orig_resnet = resnet18(pretrained=pretrained)
+ arch_net = NormalResnetBackbone(orig_resnet)
+ arch_net.num_features = 512
+
+ elif arch == 'resnet18_dilated8':
+ orig_resnet = resnet18(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
+ arch_net.num_features = 512
+
+ elif arch == 'resnet34':
+ orig_resnet = resnet34(pretrained=pretrained)
+ arch_net = NormalResnetBackbone(orig_resnet)
+ arch_net.num_features = 512
+
+ elif arch == 'resnet34_dilated8':
+ orig_resnet = resnet34(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
+ arch_net.num_features = 512
+
+ elif arch == 'resnet34_dilated16':
+ orig_resnet = resnet34(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
+ arch_net.num_features = 512
+
+ elif arch == 'resnet50':
+ orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
+ arch_net = NormalResnetBackbone(orig_resnet)
+
+ elif arch == 'resnet50_dilated8':
+ orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
+
+ elif arch == 'resnet50_dilated16':
+ orig_resnet = resnet50(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
+
+ elif arch == 'deepbase_resnet50':
+ if pretrained:
+ pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
+ orig_resnet = deepbase_resnet50(pretrained=pretrained)
+ arch_net = NormalResnetBackbone(orig_resnet)
+
+ elif arch == 'deepbase_resnet50_dilated8':
+ if pretrained:
+ pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
+ # pretrained = "/home/gishin/Projects/DeepLearning/Oxford/cct/models/backbones/pretrained/3x3resnet50-imagenet.pth"
+ orig_resnet = deepbase_resnet50(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
+
+ elif arch == 'deepbase_resnet50_dilated16':
+ orig_resnet = deepbase_resnet50(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
+
+ elif arch == 'resnet101':
+ orig_resnet = resnet101(pretrained=pretrained)
+ arch_net = NormalResnetBackbone(orig_resnet)
+
+ elif arch == 'resnet101_dilated8':
+ orig_resnet = resnet101(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
+
+ elif arch == 'resnet101_dilated16':
+ orig_resnet = resnet101(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
+
+ elif arch == 'deepbase_resnet101':
+ orig_resnet = deepbase_resnet101(pretrained=pretrained)
+ arch_net = NormalResnetBackbone(orig_resnet)
+
+ elif arch == 'deepbase_resnet101_dilated8':
+ if pretrained:
+ pretrained = 'backbones/backbones/pretrained/3x3resnet101-imagenet.pth'
+ orig_resnet = deepbase_resnet101(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
+
+ elif arch == 'deepbase_resnet101_dilated16':
+ orig_resnet = deepbase_resnet101(pretrained=pretrained)
+ arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
+
+ else:
+ raise Exception('Architecture undefined!')
+
+ return arch_net
diff --git a/networks/resnet_models.py b/networks/resnet_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5fe1285a0eb657cdc6f865edcf95023db946dcb
--- /dev/null
+++ b/networks/resnet_models.py
@@ -0,0 +1,273 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Author: Donny You(youansheng@gmail.com)
+import math
+import torch.nn as nn
+from collections import OrderedDict
+from .module_helper import ModuleHelper
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/backbones/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/backbones/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/backbones/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/backbones/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/backbones/resnet152-b121ed2d.pth'
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers, width_multiplier=1.0, num_classes=1000, deep_base=False, norm_type=None):
+ super(ResNet, self).__init__()
+ self.inplanes = 128 if deep_base else int(64 * width_multiplier)
+ self.width_multiplier = width_multiplier
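+        # deep_base replaces the single 7x7 stem convolution with three stacked 3x3
+        # convolutions (3 -> 64 -> 64 -> 128 channels), as used by the deepbase_* variants.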
+ if deep_base:
+ self.prefix = nn.Sequential(OrderedDict([
+ ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
+ ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
+ ('relu1', nn.ReLU(inplace=False)),
+ ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
+ ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
+ ('relu2', nn.ReLU(inplace=False)),
+ ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
+ ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
+ ('relu3', nn.ReLU(inplace=False))]
+ ))
+ else:
+ self.prefix = nn.Sequential(OrderedDict([
+ ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
+ ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
+ ('relu', nn.ReLU(inplace=False))]
+ ))
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) # change.
+
+ self.layer1 = self._make_layer(block, int(64 * width_multiplier), layers[0], norm_type=norm_type)
+ self.layer2 = self._make_layer(block, int(128 * width_multiplier), layers[1], stride=2, norm_type=norm_type)
+ self.layer3 = self._make_layer(block, int(256 * width_multiplier), layers[2], stride=2, norm_type=norm_type)
+ self.layer4 = self._make_layer(block, int(512 * width_multiplier), layers[3], stride=2, norm_type=norm_type)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(int(512 * block.expansion * width_multiplier), num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+                # `planes` already includes the width multiplier, so match the conv output directly.
+                ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes,
+ stride, downsample, norm_type=norm_type))
+
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, norm_type=norm_type))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.prefix(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ norm_type (str): choose norm type
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type,
+ width_multiplier=kwargs["width_multiplier"])
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-152 model.
+
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
+
+
+def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
+ """Constructs a ResNet-152 model.
+
+ Args:
+        pretrained: checkpoint path (or None) forwarded to ModuleHelper.load_model
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
+ model = ModuleHelper.load_model(model, pretrained=pretrained)
+ return model
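+
+
+if __name__ == '__main__':
+    # Hedged smoke test: BasicBlock backbones (resnet18/34) end in 512 channels while
+    # Bottleneck backbones (resnet50/101/152) end in 512 * 4 = 2048 channels before the
+    # fc layer. Assumes ModuleHelper.BatchNorm2d('batchnorm') resolves to nn.BatchNorm2d
+    # and that load_model is a no-op when pretrained is None.
+    import torch
+
+    x = torch.randn(2, 3, 224, 224)
+    assert resnet18(pretrained=None)(x).shape == (2, 1000)
+    assert resnet50(pretrained=None, width_multiplier=1.0)(x).shape == (2, 1000)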
diff --git a/networks/timm_deit.py b/networks/timm_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..95754e99af6f3c18ae649e807f9d0ef800d466f5
--- /dev/null
+++ b/networks/timm_deit.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+import math
+import torch
+import torch.nn as nn
+from functools import partial
+
+from networks.timm_vit import VisionTransformer, _cfg
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_
+
+
+__all__ = [
+ 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
+ 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
+ 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
+ 'deit_base_distilled_patch16_384',
+]
+
+
+class DistilledVisionTransformer(VisionTransformer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ num_patches = self.patch_embed.num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
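+        # num_patches + 2: one position each for the [CLS] token and the distillation token.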
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.dist_token, std=.02)
+ trunc_normal_(self.pos_embed, std=.02)
+ self.head_dist.apply(self._init_weights)
+
+ def forward_features(self, x):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to add the dist_token
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+ return x[:, 0], x[:, 1]
+
+ def forward(self, x):
+ x, x_dist = self.forward_features(x)
+ x = self.head(x)
+ x_dist = self.head_dist(x_dist)
+ if self.training:
+ return x, x_dist
+ else:
+ # during inference, return the average of both classifier predictions
+ return (x + x_dist) / 2
+
+ def interpolate_pos_encoding(self, x, pos_embed):
+ """Interpolate the learnable positional encoding to match the number of patches.
+
+ x: B x (1 + 1 + N patches) x dim_embedding
+ pos_embed: B x (1 + 1 + N patches) x dim_embedding
+
+ return interpolated positional embedding
+ """
+
+ npatch = x.shape[1] - 2 # (H // patch_size * W // patch_size)
+ N = pos_embed.shape[1] - 2 # 784 (= 28 x 28)
+
+ if npatch == N:
+ return pos_embed
+
+ class_emb, distil_token, pos_embed = pos_embed[:, 0], pos_embed[:, 1], pos_embed[:, 2:] # a learnable CLS token, learnable position embeddings
+
+ dim = x.shape[-1] # dimension of embeddings
+ pos_embed = nn.functional.interpolate(
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
+ scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer.
+ recompute_scale_factor=True,
+ mode='bicubic'
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ pos_embed = torch.cat((class_emb.unsqueeze(0), distil_token.unsqueeze(0), pos_embed), dim=1)
+ return pos_embed
+
+ def get_tokens(
+ self,
+ x,
+ layers: list,
+ patch_tokens: bool = False,
+ norm: bool = True,
+ input_tokens: bool = False,
+ post_pe: bool = False
+ ):
+ """Return intermediate tokens."""
+ list_tokens: list = []
+
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ dist_token = self.dist_token.expand(B, -1, -1)
+
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+ if input_tokens:
+ list_tokens.append(x)
+
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+ x = x + pos_embed
+
+ if post_pe:
+ list_tokens.append(x)
+
+ x = self.pos_drop(x)
+
+ for i, blk in enumerate(self.blocks):
+ x = blk(x) # B x # patches x dim
+ if layers is None or i in layers:
+ list_tokens.append(self.norm(x) if norm else x)
+
+        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (2 + # patches) x dim
+
+ if not patch_tokens:
+ return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
+
+ else:
+ return torch.cat((tokens[:, :, 0, :].unsqueeze(dim=2), tokens[:, :, 2:, :]), dim=2) # exclude distil token.
+
+
+@register_model
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_small_patch16_224(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_patch16_224(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+ model = DistilledVisionTransformer(
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+ model = DistilledVisionTransformer(
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+ model = DistilledVisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+ model = DistilledVisionTransformer(
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
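+
+
+if __name__ == '__main__':
+    # Hedged usage sketch: collect [CLS] tokens from a few blocks of an untrained
+    # distilled DeiT-small (pretrained=True would instead download the Facebook checkpoint).
+    model = deit_small_distilled_patch16_224(pretrained=False)
+    model.eval()
+    with torch.no_grad():
+        cls_tokens = model.get_tokens(torch.randn(1, 3, 224, 224), layers=[3, 7, 11])
+    # cls_tokens: B x len(layers) x embed_dim, i.e. (1, 3, 384) for DeiT-small.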
diff --git a/networks/timm_vit.py b/networks/timm_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8aea2eb1fc17066c0d9b1a6541461dff12a46b
--- /dev/null
+++ b/networks/timm_vit.py
@@ -0,0 +1,819 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implement of Vision Transformers as described in
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import logging
+from functools import partial
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
+from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
+from timm.models.registry import register_model
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models (my experiments)
+ 'vit_small_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+ ),
+
+ # patch models (weights ported from official Google JAX impl)
+ 'vit_base_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ ),
+ 'vit_base_patch32_224': _cfg(
+ url='', # no official model weights for this combo, only for in21k
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_base_patch16_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+ 'vit_base_patch32_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+ 'vit_large_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch32_224': _cfg(
+ url='', # no official model weights for this combo, only for in21k
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch16_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+ 'vit_large_patch32_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+
+ # patch models, imagenet21k (weights ported from official Google JAX impl)
+ 'vit_base_patch16_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_base_patch32_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch16_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch32_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_huge_patch14_224_in21k': _cfg(
+ hf_hub='timm/vit_huge_patch14_224_in21k',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+
+ # deit models (FB weights)
+ 'vit_deit_tiny_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+ 'vit_deit_small_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+ 'vit_deit_base_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
+ 'vit_deit_base_patch16_384': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_deit_tiny_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+ classifier=('head', 'head_dist')),
+ 'vit_deit_small_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+ classifier=('head', 'head_dist')),
+ 'vit_deit_base_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+ classifier=('head', 'head_dist')),
+ 'vit_deit_base_distilled_patch16_384': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
+
+ # ViT ImageNet-21K-P pretraining
+ 'vit_base_patch16_224_miil_in21k': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
+ ),
+ 'vit_base_patch16_224_miil': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
+ '/vit_base_patch16_224_1k_miil_84_4.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
+ ),
+}
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+ - https://arxiv.org/abs/2012.12877
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+ act_layer=None, weight_init='',
+ # noel
+ img_size_eval: int = 224):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ distilled (bool): model includes a distillation token and head as in DeiT models
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ weight_init: (str): weight init scheme
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 2 if distilled else 1
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+
+ self.patch_embed = embed_layer(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.Sequential(*[
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Representation layer
+ if representation_size and not distilled:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ # Classifier head(s)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ # Weight init
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+ trunc_normal_(self.pos_embed, std=.02)
+ if self.dist_token is not None:
+ trunc_normal_(self.dist_token, std=.02)
+ if weight_init.startswith('jax'):
+ # leave cls token as zeros to match jax impl
+ for n, m in self.named_modules():
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+ else:
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(_init_vit_weights)
+
+ # noel
+ self.depth = depth
+ self.distilled = distilled
+ self.patch_size = patch_size
+ self.patch_embed.img_size = (img_size_eval, img_size_eval)
+
+ def _init_weights(self, m):
+ # this fn left here for compat with downstream users
+ _init_vit_weights(m)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token', 'dist_token'}
+
+ def get_classifier(self):
+ if self.dist_token is None:
+ return self.head
+ else:
+ return self.head, self.head_dist
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ if self.num_tokens == 2:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ if self.dist_token is None:
+ x = torch.cat((cls_token, x), dim=1)
+ else:
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+ x = self.blocks(x)
+ x = self.norm(x)
+ if self.dist_token is None:
+ return self.pre_logits(x[:, 0])
+ else:
+ return x[:, 0], x[:, 1]
+
+ # def forward(self, x):
+ # x = self.forward_features(x)
+ # if self.head_dist is not None:
+ # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
+ # if self.training and not torch.jit.is_scripting():
+ # # during inference, return the average of both classifier predictions
+ # return x, x_dist
+ # else:
+ # return (x + x_dist) / 2
+ # else:
+ # x = self.head(x)
+ # return x
+
+ # noel - start
+ def make_square(self, x: torch.Tensor):
+ """Pad some pixels to make the input size divisible by the patch size."""
+ B, _, H_0, W_0 = x.shape
+ pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
+ pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=x.mean())
+
+ H_p, W_p = H_0 + pad_h, W_0 + pad_w
+ x = nn.functional.pad(x, (0, H_p - W_p, 0, 0) if H_p > W_p else (0, 0, 0, W_p - H_p), value=x.mean())
+ return x
+
+ def interpolate_pos_encoding(self, x, pos_embed, size):
+ """Interpolate the learnable positional encoding to match the number of patches.
+
+ x: B x (1 + N patches) x dim_embedding
+ pos_embed: B x (1 + N patches) x dim_embedding
+
+ return interpolated positional embedding
+ """
+ npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
+ N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
+ if npatch == N:
+ return pos_embed
+ class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings
+
+ dim = x.shape[-1] # dimension of embeddings
+ pos_embed = nn.functional.interpolate(
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
+ size=size,
+ mode='bicubic',
+ align_corners=False
+ )
+
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+ return pos_embed
+
+ # def interpolate_pos_encoding(self, x, pos_embed):
+ # """Interpolate the learnable positional encoding to match the number of patches.
+ #
+ # x: B x (1 + N patches) x dim_embedding
+ # pos_embed: B x (1 + N patches) x dim_embedding
+ #
+ # return interpolated positional embedding
+ # """
+ # npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size)
+ # N = pos_embed.shape[1] - 1 # 784 (= 28 x 28)
+ # if npatch == N:
+ # return pos_embed
+ # class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings
+ #
+ # dim = x.shape[-1] # dimension of embeddings
+ # pos_embed = nn.functional.interpolate(
+ # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28
+ # scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer.
+ # recompute_scale_factor=True,
+ # mode='bicubic',
+ # align_corners=False
+ # )
+ #
+ # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ # pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+ # return pos_embed
+
+ def prepare_tokens(self, x):
+ B, nc, h, w = x.shape
+ patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+ x = self.patch_embed(x) # patch linear embedding
+
+ # add the [CLS] token to the embed patch tokens
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add positional encoding to each token
+ x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
+ return self.pos_drop(x)
+
+ def get_tokens(
+ self,
+ x,
+ layers: list,
+ patch_tokens: bool = False,
+ norm: bool = True,
+ input_tokens: bool = False,
+ post_pe: bool = False
+ ):
+ """Return intermediate tokens."""
+ list_tokens: list = []
+
+        B = x.shape[0]
+        # Record the patch grid before embedding so the positional encoding can be resized to it.
+        grid_size = (x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size)
+        x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if input_tokens:
+ list_tokens.append(x)
+
+        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed, size=grid_size)
+ x = x + pos_embed
+
+ if post_pe:
+ list_tokens.append(x)
+
+ x = self.pos_drop(x)
+
+ for i, blk in enumerate(self.blocks):
+ x = blk(x) # B x # patches x dim
+ if layers is None or i in layers:
+ list_tokens.append(self.norm(x) if norm else x)
+
+ tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
+
+ if not patch_tokens:
+ return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
+
+ else:
+ return tokens
+
+ def forward(self, x, layer: str = None):
+ x = self.prepare_tokens(x)
+
+ features: dict = {}
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ features[f"layer{i + 1}"] = self.norm(x)
+
+        if layer is not None:
+            return features[layer]
+        else:
+            # default to the final block rather than a hard-coded "layer12"
+            return features[f"layer{self.depth}"]
+ # noel - end
+
+
+def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+ """ ViT weight initialization
+ * When called without n, head_bias, jax_impl args it will behave exactly the same
+ as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+ """
+ if isinstance(m, nn.Linear):
+ if n.startswith('head'):
+ nn.init.zeros_(m.weight)
+ nn.init.constant_(m.bias, head_bias)
+ elif n.startswith('pre_logits'):
+ lecun_normal_(m.weight)
+ nn.init.zeros_(m.bias)
+ else:
+ if jax_impl:
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ if 'mlp' in n:
+ nn.init.normal_(m.bias, std=1e-6)
+ else:
+ nn.init.zeros_(m.bias)
+ else:
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif jax_impl and isinstance(m, nn.Conv2d):
+ # NOTE conv was left to pytorch default in my original init
+ lecun_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.zeros_(m.bias)
+ nn.init.ones_(m.weight)
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
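+    # Example: loading a 224px checkpoint (14x14 grid + class token) into a 384px model
+    # (24x24 grid) reshapes the 196 grid embeddings to 1 x dim x 14 x 14, resizes them
+    # bilinearly to 24 x 24, and flattens back to 576 positions before re-attaching the token.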
+ _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
+ ntok_new = posemb_new.shape[1]
+ if num_tokens:
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+ ntok_new -= num_tokens
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ if not len(gs_new): # backwards compatibility
+ gs_new = [int(math.sqrt(ntok_new))] * 2
+ assert len(gs_new) >= 2
+ _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ if 'model' in state_dict:
+ # For deit models
+ state_dict = state_dict['model']
+ for k, v in state_dict.items():
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
+ # For old models that I trained prior to conv based patchification
+ O, I, H, W = model.patch_embed.proj.weight.shape
+ v = v.reshape(O, -1, H, W)
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
+ # To resize pos embedding when using model at different size from pretrained weights
+ v = resize_pos_embed(
+ v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ out_dict[k] = v
+ return out_dict
+
+
+def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+ default_cfg = default_cfg or default_cfgs[variant]
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ # NOTE this extra code to support handling of repr size for in21k pretrained models
+ default_num_classes = default_cfg['num_classes']
+ num_classes = kwargs.get('num_classes', default_num_classes)
+ repr_size = kwargs.pop('representation_size', None)
+ if repr_size is not None and num_classes != default_num_classes:
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
+ # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
+ _logger.warning("Removing representation layer for fine-tuning.")
+ repr_size = None
+
+ model = build_model_with_cfg(
+ VisionTransformer, variant, pretrained,
+ default_cfg=default_cfg,
+ representation_size=repr_size,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch16_224(pretrained=False, **kwargs):
+ """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
+ NOTE:
+ * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
+ * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
+ qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
+ if pretrained:
+ # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
+ model_kwargs.setdefault('qk_scale', 768 ** -0.5)
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: converted weights not currently available, too large for github release hosting.
+ """
+ model_kwargs = dict(
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
+ model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_small_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_patch16_224(pretrained=False, **kwargs):
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_patch16_384(pretrained=False, **kwargs):
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_miil(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
+ return model
\ No newline at end of file
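+
+
+if __name__ == '__main__':
+    # Hedged sketch: the customised forward() above returns normalised token sequences per
+    # block instead of logits; with layer=None it returns the final block. For an untrained
+    # ViT-B/16 at 224px that is B x (1 + 14*14) x 768. Assumes a timm version compatible
+    # with the imports at the top of this file.
+    vit = vit_base_patch16_224(pretrained=False)
+    tokens = vit(torch.randn(1, 3, 224, 224))
+    assert tokens.shape == (1, 1 + 14 * 14, 768)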
diff --git a/networks/vision_transformer.py b/networks/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e1172f4dd482cd6789ec231902e8fc979f0cd4f
--- /dev/null
+++ b/networks/vision_transformer.py
@@ -0,0 +1,569 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Mostly copy-paste from timm library.
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+from typing import Optional
+import math
+import warnings
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.",
+ stacklevel=2
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
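+    # Stochastic depth: during training each sample's residual branch is dropped with
+    # probability drop_prob; surviving samples are scaled by 1 / keep_prob so the expected
+    # activation is unchanged (e.g. drop_prob=0.1 zeroes the branch for ~10% of samples).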
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5 # square root of dimension for normalisation
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape # B x (cls token + # patch tokens) x dim
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ # qkv: 3 x B x Nh x (cls token + # patch tokens) x (dim // Nh)
+
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ # q, k, v: B x Nh x (cls token + # patch tokens) x (dim // Nh)
+
+ # q: B x Nh x (cls token + # patch tokens) x (dim // Nh)
+ # k.transpose(-2, -1) = B x Nh x (dim // Nh) x (cls token + # patch tokens)
+ # attn: B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
+ attn = (q @ k.transpose(-2, -1)) * self.scale # @ operator is for matrix multiplication
+ attn = attn.softmax(dim=-1) # B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
+ attn = self.attn_drop(attn)
+
+ # attn = B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
+ # v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
+ # attn @ v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
+ # (attn @ v).transpose(1, 2) = B x (cls token + # patch tokens) x Nh x (dim // Nh)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C) # B x (cls token + # patch tokens) x dim
+ x = self.proj(x) # B x (cls token + # patch tokens) x dim
+ x = self.proj_drop(x)
+ return x, attn
+
+
+class Block(nn.Module):
+ def __init__(self,
+ dim, num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, return_attention=False):
+ y, attn = self.attn(self.norm1(x))
+ if return_attention:
+ return attn
+ x = x + self.drop_path(y)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding"""
+ def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x)
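+        # e.g. a 224x224 input with patch_size=16 gives a 14x14 grid, i.e. 196 patch tokens of size embed_dim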
+ x = x.flatten(2).transpose(1, 2) # B x (P_H * P_W) x C
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer """
+ def __init__(self,
+ img_size=(224, 224),
+ patch_size=16,
+ in_chans=3,
+ num_classes=0,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim
+
+ self.patch_embed = PatchEmbed(
+            img_size=(224, 224),  # noel: hard-coded to 224x224 so that the pretrained weights can be loaded.
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer
+ ) for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Classifier head
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ self.depth = depth
+ self.embed_dim = self.n_embs = embed_dim
+ self.mlp_ratio = mlp_ratio
+ self.n_heads = num_heads
+ self.patch_size = patch_size
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
+ """Pad some pixels to make the input size divisible by the patch size."""
+ B, _, H_0, W_0 = x.shape
+ pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
+ pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
+
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
+ return x
+
+ def prepare_tokens(self, x):
+ B, nc, h, w = x.shape
+ x: torch.Tensor = self.make_input_divisible(x)
+ patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+
+ x = self.patch_embed(x) # patch linear embedding
+
+ # add positional encoding to each token
+ # add the [CLS] token to the embed patch tokens
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
+ return self.pos_drop(x)
+
+ @staticmethod
+ def split_token(x, token_type: str):
+ if token_type == "cls":
+ return x[:, 0, :]
+ elif token_type == "patch":
+ return x[:, 1:, :]
+ else:
+ return x
+
+ # noel
+ def forward(self, x, layer: Optional[str] = None):
+ x: torch.Tensor = self.prepare_tokens(x)
+
+ features: dict = {}
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ features[f"layer{i + 1}"] = self.norm(x)
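+            # e.g. for a 12-block model, features["layer12"] holds the LayerNorm-ed tokens of the final block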
+
+ if layer is not None:
+ return features[layer]
+ else:
+ return features
+
+ # noel - for DINO's visual
+ def get_last_selfattention(self, x):
+ x = self.prepare_tokens(x)
+ for i, blk in enumerate(self.blocks):
+ if i < len(self.blocks) - 1:
+ x = blk(x)
+ else:
+ # return attention of the last block
+ return blk(x, return_attention=True)
+
+ def get_tokens(
+ self,
+ x,
+ layers: list,
+ patch_tokens: bool = False,
+ norm: bool = True,
+ input_tokens: bool = False,
+ post_pe: bool = False
+ ):
+ """Return intermediate tokens."""
+ list_tokens: list = []
+
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if input_tokens:
+ list_tokens.append(x)
+
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+ x = x + pos_embed
+
+ if post_pe:
+ list_tokens.append(x)
+
+ x = self.pos_drop(x)
+
+ for i, blk in enumerate(self.blocks):
+            x = blk(x)  # B x (1 + # patches) x dim
+ if layers is None or i in layers:
+ list_tokens.append(self.norm(x) if norm else x)
+
+ tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim
+
+ if not patch_tokens:
+ return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim
+
+ else:
+ return tokens
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ return x[:, 0]
+
+    def interpolate_pos_encoding(self, x, pos_embed, size=None):
+        """Interpolate the learnable positional encoding to match the number of patches.
+
+        x: B x (1 + N patches) x dim_embedding
+        pos_embed: 1 x (1 + N pretraining patches) x dim_embedding
+        size: (grid height, grid width) of the current input; if None, a square grid is assumed,
+              which keeps the callers that omit this argument working.
+
+        return interpolated positional embedding
+        """
+        npatch = x.shape[1] - 1  # number of patch tokens of the current input (H // patch_size * W // patch_size)
+        N = pos_embed.shape[1] - 1  # number of patch tokens the positional encoding was trained with, e.g. 196 (= 14 x 14)
+        if npatch == N:
+            return pos_embed
+        if size is None:
+            size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))
+        class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable CLS token, learnable position embeddings
+
+        dim = x.shape[-1]  # dimension of embeddings
+        pos_embed = nn.functional.interpolate(
+            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # 1 x dim x sqrt(N) x sqrt(N)
+            size=size,
+            mode='bicubic',
+            align_corners=False
+        )
+
+        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+        return pos_embed
+
+ def forward_selfattention(self, x, return_interm_attn=False):
+ B, nc, w, h = x.shape
+ N = self.pos_embed.shape[1] - 1
+ x = self.patch_embed(x)
+
+ # interpolate patch embeddings
+ dim = x.shape[-1]
+ w0 = w // self.patch_embed.patch_size
+ h0 = h // self.patch_embed.patch_size
+ class_pos_embed = self.pos_embed[:, 0]
+ patch_pos_embed = self.pos_embed[:, 1:]
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode='bicubic'
+ )
+ if w0 != patch_pos_embed.shape[-2]:
+ helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
+ patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
+        if h0 != patch_pos_embed.shape[-1]:
+            helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
+            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # self.cls_token: 1 x 1 x emb_dim -> B x 1 x emb_dim
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ if return_interm_attn:
+ list_attn = []
+ for i, blk in enumerate(self.blocks):
+ attn = blk(x, return_attention=True)
+ x = blk(x)
+ list_attn.append(attn)
+ return torch.cat(list_attn, dim=0)
+
+ else:
+ for i, blk in enumerate(self.blocks):
+ if i < len(self.blocks) - 1:
+ x = blk(x)
+ else:
+ return blk(x, return_attention=True)
+
+ def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ x = torch.cat((cls_tokens, x), dim=1)
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ # we will return the [CLS] tokens from the `n` last blocks
+ output = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if len(self.blocks) - i <= n:
+ # get only CLS token (B x dim)
+ output.append(self.norm(x)[:, 0])
+ if return_patch_avgpool:
+ x = self.norm(x)
+ # In addition to the [CLS] tokens from the `n` last blocks, we also return
+ # the patch tokens from the last block. This is useful for linear eval.
+ output.append(torch.mean(x[:, 1:], dim=1))
+ return torch.cat(output, dim=-1)
+
+ def return_patch_emb_from_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
+ """Return intermediate patch embeddings, rather than CLS token, from the last n blocks."""
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ x = torch.cat((cls_tokens, x), dim=1)
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+        # we will return the patch tokens from the `n` last blocks
+ output = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if len(self.blocks) - i <= n:
+                output.append(self.norm(x)[:, 1:])  # get only the patch tokens (B x # patches x dim)
+
+ if return_patch_avgpool:
+ x = self.norm(x)
+            # In addition to the patch tokens from the `n` last blocks, we also return
+            # the average-pooled patch tokens from the last block. This is useful for linear eval.
+ output.append(torch.mean(x[:, 1:], dim=1))
+ return torch.stack(output, dim=-1) # B x n_patches x dim x n
+
+
+def deit_tiny(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=192,
+ depth=12,
+ num_heads=3,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs)
+ return model
+
+
+def deit_small(patch_size=16, **kwargs):
+ depth = kwargs.pop("depth") if "depth" in kwargs else 12
+ model = VisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=depth,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def vit_base(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ return model
+
+
+class DINOHead(nn.Module):
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ if nlayers == 1:
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+ self.mlp = nn.Sequential(*layers)
+ self.apply(self._init_weights)
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+ if norm_last_layer:
+ self.last_layer.weight_g.requires_grad = False
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ x = nn.functional.normalize(x, dim=-1, p=2)
+ x = self.last_layer(x)
+ return x
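+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch, added for illustration only (not part of the upstream DINO/timm code).
+    # It shows how the factory functions and the dict-returning forward above are intended to be used;
+    # the input size and the layer name ("layer12" for a 12-block model) are illustrative choices.
+    model = deit_small(patch_size=16)
+    model.eval()
+    with torch.no_grad():
+        tokens = model(torch.randn(2, 3, 224, 224), layer="layer12")
+    print(tokens.shape)  # torch.Size([2, 197, 384]): CLS token + 14*14 patch tokens, embed_dim=384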
diff --git a/resources/.DS_Store b/resources/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..039236944271ae9bfcc54629b05f46e51106b004
Binary files /dev/null and b/resources/.DS_Store differ
diff --git a/resources/0053.jpg b/resources/0053.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..82abb9e649639af3906936a10cc894769681d6bc
Binary files /dev/null and b/resources/0053.jpg differ
diff --git a/resources/0236.jpg b/resources/0236.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2b2d5c52031370e67008f36ad86f52a1b0788ef1
Binary files /dev/null and b/resources/0236.jpg differ
diff --git a/resources/0239.jpg b/resources/0239.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e309e0338648952704230f8a5968bc4b58549e10
Binary files /dev/null and b/resources/0239.jpg differ
diff --git a/resources/0403.jpg b/resources/0403.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..693807004eb10b07d01b988815c576bf3ee9fb7b
Binary files /dev/null and b/resources/0403.jpg differ
diff --git a/resources/0412.jpg b/resources/0412.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f38a782ac2d4c93c41346682171e3208f59a8a4
Binary files /dev/null and b/resources/0412.jpg differ
diff --git a/resources/ILSVRC2012_test_00005309.jpg b/resources/ILSVRC2012_test_00005309.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f3c75f4df56a5d19913c78dbb8f139a9cd03f6e1
Binary files /dev/null and b/resources/ILSVRC2012_test_00005309.jpg differ
diff --git a/resources/ILSVRC2012_test_00012622.jpg b/resources/ILSVRC2012_test_00012622.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d3a7481ad690cdc0b77eb92c67c823fb4756452d
Binary files /dev/null and b/resources/ILSVRC2012_test_00012622.jpg differ
diff --git a/resources/ILSVRC2012_test_00022698.jpg b/resources/ILSVRC2012_test_00022698.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7cb1edc16e45fd5d61d0502ff46f448b43c82b28
Binary files /dev/null and b/resources/ILSVRC2012_test_00022698.jpg differ
diff --git a/resources/ILSVRC2012_test_00040725.jpg b/resources/ILSVRC2012_test_00040725.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..67c4035e78f6891ddc555d911e82598652bb22db
Binary files /dev/null and b/resources/ILSVRC2012_test_00040725.jpg differ
diff --git a/resources/ILSVRC2012_test_00075738.jpg b/resources/ILSVRC2012_test_00075738.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7ebdcb5281e963c42c7c23f3013831b765e7005c
Binary files /dev/null and b/resources/ILSVRC2012_test_00075738.jpg differ
diff --git a/resources/ILSVRC2012_test_00080683.jpg b/resources/ILSVRC2012_test_00080683.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7c4182ed81b25a7efc59c9870cb316a8cede8583
Binary files /dev/null and b/resources/ILSVRC2012_test_00080683.jpg differ
diff --git a/resources/ILSVRC2012_test_00085874.jpg b/resources/ILSVRC2012_test_00085874.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e4c424066ce390fe186f4757da31ebed0bd2257a
Binary files /dev/null and b/resources/ILSVRC2012_test_00085874.jpg differ
diff --git a/resources/im052.jpg b/resources/im052.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..569b41eda0305ad49b5d9a1f3cc3372c08a771ad
Binary files /dev/null and b/resources/im052.jpg differ
diff --git a/resources/sun_ainjbonxmervsvpv.jpg b/resources/sun_ainjbonxmervsvpv.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2ba94a5dfdc3dd3c63934de88d58d5625248ad4b
Binary files /dev/null and b/resources/sun_ainjbonxmervsvpv.jpg differ
diff --git a/resources/sun_alfntqzssslakmss.jpg b/resources/sun_alfntqzssslakmss.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..92decf70f2ed58211d2402504e1385fc0f56f879
Binary files /dev/null and b/resources/sun_alfntqzssslakmss.jpg differ
diff --git a/resources/sun_amnrcxhisjfrliwa.jpg b/resources/sun_amnrcxhisjfrliwa.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..08ad6f5439968c4844cf62432a2123ea943c7ce5
Binary files /dev/null and b/resources/sun_amnrcxhisjfrliwa.jpg differ
diff --git a/resources/sun_bvyxpvkouzlfwwod.jpg b/resources/sun_bvyxpvkouzlfwwod.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..463012695b73898a8fee6aa9f81805e1fbddc8a8
Binary files /dev/null and b/resources/sun_bvyxpvkouzlfwwod.jpg differ
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5442d5d5d3452f0782acdf98d23fdc0c970c80e3
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,87 @@
+from argparse import Namespace
+from typing import Optional
+import torch
+
+
+def get_model(
+ arch: str,
+ patch_size: Optional[int] = None,
+ training_method: Optional[str] = None,
+ configs: Optional[Namespace] = None,
+ **kwargs
+):
+ if arch == "maskformer":
+ assert configs is not None
+ from networks.maskformer.maskformer import MaskFormer
+ model = MaskFormer(
+ n_queries=configs.n_queries,
+ n_decoder_layers=configs.n_decoder_layers,
+ learnable_pixel_decoder=configs.learnable_pixel_decoder,
+ lateral_connection=configs.lateral_connection,
+ return_intermediate=configs.loss_every_decoder_layer,
+ scale_factor=configs.scale_factor,
+ abs_2d_pe_init=configs.abs_2d_pe_init,
+ use_binary_classifier=configs.use_binary_classifier,
+ arch=configs.arch,
+ training_method=configs.training_method,
+ patch_size=configs.patch_size
+ )
+
+ for n, p in model.encoder.named_parameters():
+ p.requires_grad_(True)
+
+ elif "vit" in arch:
+ import networks.vision_transformer as vits
+ import networks.timm_deit as timm_deit
+ if training_method == "dino":
+ arch = arch.replace("vit", "deit") if arch.find("small") != -1 else arch
+ model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
+ load_model(model, arch, patch_size)
+
+ elif training_method == "deit":
+ assert patch_size == 16
+ model = timm_deit.deit_small_distilled_patch16_224(True)
+
+ elif training_method == "supervised":
+ assert patch_size == 16
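+            # NOTE: local path to a supervised DeiT-S/16 checkpoint; adjust this to your environment.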
+ state_dict: dict = torch.load(
+ "/users/gyungin/selfmask/networks/pretrained/deit_small_patch16_224-cd65a155.pth"
+ )["model"]
+ for k in list(state_dict.keys()):
+ if k in ["head.weight", "head.bias"]: # classifier head, which is not used in our network
+ state_dict.pop(k)
+
+ model = get_model(arch="vit_small", patch_size=16, training_method="dino")
+ model.load_state_dict(state_dict=state_dict, strict=True)
+
+ else:
+ raise NotImplementedError
+ print(f"{arch}_p{patch_size}_{training_method} is built.")
+
+ elif arch == "resnet50":
+ from networks.resnet import ResNet50
+ assert training_method in ["mocov2", "swav", "supervised"]
+ model = ResNet50(training_method)
+
+ else:
+        raise ValueError(f"{arch} is not a supported arch. Choose from [maskformer, resnet50, vit_small, vit_base].")
+ return model
+
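+# Example usage (illustrative): get_model(arch="vit_small", patch_size=16, training_method="dino")
+# returns a DINO-pretrained ViT-S/16 backbone; the weights are fetched by load_model() below.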
+
+def load_model(model, arch: str, patch_size: int) -> None:
+ url = None
+ if arch == "deit_small" and patch_size == 16:
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
+ elif arch == "deit_small" and patch_size == 8:
+ # model used for visualizations in our paper
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
+ elif arch == "vit_base" and patch_size == 16:
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
+ elif arch == "vit_base" and patch_size == 8:
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
+ if url is not None:
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
+ model.load_state_dict(state_dict, strict=True)
+ else:
+ print("There is no reference weights available for this model => We use random weights.")
\ No newline at end of file