diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ba09295e29046ede81f5ea268f3f251ed750e3 --- /dev/null +++ b/app.py @@ -0,0 +1,134 @@ +from argparse import ArgumentParser,
Namespace +from typing import Dict, List, Tuple +import yaml +import numpy as np +import cv2 +from PIL import Image +import torch +import torch.nn.functional as F +from torchvision.transforms.functional import to_tensor, normalize, resize +import gradio as gr +from utils import get_model +from bilateral_solver import bilateral_solver_output +import os +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +state_dict: dict = torch.hub.load_state_dict_from_url( + "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt", + map_location=device # "cuda" if torch.cuda.is_available() else "cpu" +)["model"] + +parser = ArgumentParser("SelfMask demo") +parser.add_argument( + "--config", + type=str, + default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml" +) + +# parser.add_argument( +# "--p_state_dict", +# type=str, +# default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt", +# ) +# +# parser.add_argument( +# "--dataset_name", '-dn', type=str, default="duts", +# choices=["dut_omron", "duts", "ecssd"] +# ) + +# independent variables +# parser.add_argument("--use_gpu", type=bool, default=True) +# parser.add_argument('--seed', default=0, type=int) +# parser.add_argument("--dir_root", type=str, default="..") +# parser.add_argument("--gpu_id", type=int, default=2) +# parser.add_argument("--suffix", type=str, default='') +args: Namespace = parser.parse_args() +base_args = yaml.safe_load(open(f"{args.config}", 'r')) +base_args.pop("dataset_name") +args: dict = vars(args) +args.update(base_args) +args: Namespace = Namespace(**args) + +model = get_model(arch="maskformer", configs=args).to(device) +model.load_state_dict(state_dict) +model.eval() + + +@torch.no_grad() +def main( + image: Image.Image, + size: int = 384, + max_size: int = 512, + mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), + std: Tuple[float, float, float] = (0.229, 0.224, 0.225) +): + pil_image: Image.Image = resize(image, size=size, max_size=max_size) + image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std)) # 3 x H x W + dict_outputs = model(image[None].to(device)) + + batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"] # [0, 1] + batch_objectness: torch.Tensor = dict_outputs.get("objectness", None) # [0, 1] + + if len(batch_pred_masks.shape) == 5: + # b x n_layers x n_queries x h x w -> b x n_queries x h x w + batch_pred_masks = batch_pred_masks[:, -1, ...] # extract the output from the last decoder layer + + if batch_objectness is not None: + # b x n_layers x n_queries x 1 -> b x n_queries x 1 + batch_objectness = batch_objectness[:, -1, ...] 
+ + # resize prediction to original resolution + # note: upsampling by 4 and cutting the padded region allows for a better result + H, W = image.shape[-2:] + batch_pred_masks = F.interpolate( + batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False + )[..., :H, :W] + + # iterate over batch dimension + for batch_index, pred_masks in enumerate(batch_pred_masks): + # n_queries x 1 -> n_queries + objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1) + ranks = torch.argsort(objectness, descending=True) # n_queries + pred_mask: torch.Tensor = pred_masks[ranks[0]] # H x W + pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255 + + pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask) # float64 + pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8) + + attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB) + super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0) + return super_imposed_img + # return pred_mask_bi + +demo = gr.Interface( + fn=main, + inputs=gr.inputs.Image(type="pil"), + outputs="image", + examples=[f"resources/{fname}.jpg" for fname in [ + "0053", + "0236", + "0239", + "0403", + "0412", + "ILSVRC2012_test_00005309", + "ILSVRC2012_test_00012622", + "ILSVRC2012_test_00022698", + "ILSVRC2012_test_00040725", + "ILSVRC2012_test_00075738", + "ILSVRC2012_test_00080683", + "ILSVRC2012_test_00085874", + "im052", + "sun_ainjbonxmervsvpv", + "sun_alfntqzssslakmss", + "sun_amnrcxhisjfrliwa", + "sun_bvyxpvkouzlfwwod" + ]], + title="Unsupervised Salient Object Detection with Spectral Cluster Voting", + allow_flagging="never", + analytics_enabled=False +) + +demo.launch( + # share=True +) \ No newline at end of file diff --git a/bilateral_solver.py b/bilateral_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..6e75876397756d7ae364b746a15191423c402dd4 --- /dev/null +++ b/bilateral_solver.py @@ -0,0 +1,206 @@ +from scipy.sparse import diags +from scipy.sparse.linalg import cg +from scipy.sparse import csr_matrix +import numpy as np +from skimage.io import imread +from scipy import ndimage +import torch +import PIL.Image as Image +import os +from argparse import ArgumentParser, Namespace +from typing import Dict, Union +from collections import defaultdict +import yaml +import ujson as json +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + + +RGB_TO_YUV = np.array([ + [0.299, 0.587, 0.114], + [-0.168736, -0.331264, 0.5], + [0.5, -0.418688, -0.081312]]) +YUV_TO_RGB = np.array([ + [1.0, 0.0, 1.402], + [1.0, -0.34414, -0.71414], + [1.0, 1.772, 0.0]]) +YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1) +MAX_VAL = 255.0 + + +def rgb2yuv(im): + return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET + + +def yuv2rgb(im): + return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1])) + + +def get_valid_idx(valid, candidates): + """Find which values are present in a list and where they are located""" + locs = np.searchsorted(valid, candidates) + # Handle edge case where the candidate is larger than all valid values + locs = np.clip(locs, 0, len(valid) - 1) + # Identify which values are actually present + valid_idx = np.flatnonzero(valid[locs] == candidates) + locs = locs[valid_idx] + return valid_idx, locs + + +class BilateralGrid(object): + def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8): + im_yuv = rgb2yuv(im) + # Compute 
5-dimensional XYLUV bilateral-space coordinates + Iy, Ix = np.mgrid[:im.shape[0], :im.shape[1]] + x_coords = (Ix / sigma_spatial).astype(int) + y_coords = (Iy / sigma_spatial).astype(int) + luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int) + chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int) + coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords)) + coords_flat = coords.reshape(-1, coords.shape[-1]) + self.npixels, self.dim = coords_flat.shape + # Hacky "hash vector" for coordinates, + # Requires all scaled coordinates be < MAX_VAL + self.hash_vec = (MAX_VAL ** np.arange(self.dim)) + # Construct S and B matrix + self._compute_factorization(coords_flat) + + def _compute_factorization(self, coords_flat): + # Hash each coordinate in grid to a unique value + hashed_coords = self._hash_coords(coords_flat) + unique_hashes, unique_idx, idx = \ + np.unique(hashed_coords, return_index=True, return_inverse=True) + # Identify unique set of vertices + unique_coords = coords_flat[unique_idx] + self.nvertices = len(unique_coords) + # Construct sparse splat matrix that maps from pixels to vertices + self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels)))) + # Construct sparse blur matrices. + # Note that these represent [1 0 1] blurs, excluding the central element + self.blurs = [] + for d in range(self.dim): + blur = 0.0 + for offset in (-1, 1): + offset_vec = np.zeros((1, self.dim)) + offset_vec[:, d] = offset + neighbor_hash = self._hash_coords(unique_coords + offset_vec) + valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash) + blur = blur + csr_matrix((np.ones((len(valid_coord),)), + (valid_coord, idx)), + shape=(self.nvertices, self.nvertices)) + self.blurs.append(blur) + + def _hash_coords(self, coord): + """Hacky function to turn a coordinate into a unique value""" + return np.dot(coord.reshape(-1, self.dim), self.hash_vec) + + def splat(self, x): + return self.S.dot(x) + + def slice(self, y): + return self.S.T.dot(y) + + def blur(self, x): + """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension""" + assert x.shape[0] == self.nvertices + out = 2 * self.dim * x + for blur in self.blurs: + out = out + blur.dot(x) + return out + + def filter(self, x): + """Apply bilateral filter to an input x""" + return self.slice(self.blur(self.splat(x))) / \ + self.slice(self.blur(self.splat(np.ones_like(x)))) + + +def bistochastize(grid, maxiter=10): + """Compute diagonal matrices to bistochastize a bilateral grid""" + m = grid.splat(np.ones(grid.npixels)) + n = np.ones(grid.nvertices) + for i in range(maxiter): + n = np.sqrt(n * m / grid.blur(n)) + # Correct m to satisfy the assumption of bistochastization regardless + # of how many iterations have been run. 
+ m = n * grid.blur(n) + Dm = diags(m, 0) + Dn = diags(n, 0) + return Dn, Dm + + +class BilateralSolver(object): + def __init__(self, grid, params): + self.grid = grid + self.params = params + self.Dn, self.Dm = bistochastize(grid) + + def solve(self, x, w): + # Check that w is a vector or an n x 1 matrix + if w.ndim == 2: + assert (w.shape[1] == 1) + elif w.ndim == 1: + w = w.reshape(w.shape[0], 1) + A_smooth = (self.Dm - self.Dn.dot(self.grid.blur(self.Dn))) + w_splat = self.grid.splat(w) + A_data = diags(w_splat[:, 0], 0) + A = self.params["lam"] * A_smooth + A_data + xw = x * w + b = self.grid.splat(xw) + # Use simple Jacobi preconditioner + A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"]) + M = diags(1 / A_diag, 0) + # Flat initialization + y0 = self.grid.splat(xw) / w_splat + yhat = np.empty_like(y0) + for d in range(x.shape[-1]): + yhat[..., d], info = cg(A, b[..., d], x0=y0[..., d], M=M, maxiter=self.params["cg_maxiter"], + tol=self.params["cg_tol"]) + xhat = self.grid.slice(yhat) + return xhat + + +def bilateral_solver_output( + img: Image.Image, + target: np.ndarray, + sigma_spatial=16, + sigma_luma=16, + sigma_chroma=8 +): + reference = np.array(img) + h, w = target.shape + confidence = np.ones((h, w)) * 0.999 + + grid_params = { + 'sigma_luma': sigma_luma, # Brightness bandwidth + 'sigma_chroma': sigma_chroma, # Color bandwidth + 'sigma_spatial': sigma_spatial # Spatial bandwidth + } + + bs_params = { + 'lam': 256, # The strength of the smoothness parameter + 'A_diag_min': 1e-5, # Clamp on the diagonal of A in the Jacobi preconditioner. + 'cg_tol': 1e-5, # The tolerance on the convergence in PCG + 'cg_maxiter': 25 # The number of PCG iterations + } + + grid = BilateralGrid(reference, **grid_params) + + t = target.reshape(-1, 1).astype(np.double) + c = confidence.reshape(-1, 1).astype(np.double) + + # output of the solver, which is a soft value + output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w)) + + binary_solver = ndimage.binary_fill_holes(output_solver > 0.5) + labeled, nr_objects = ndimage.label(binary_solver) + + nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)] + pixel_order = np.argsort(nb_pixel) + try: + binary_solver = labeled == pixel_order[-2] + except: + binary_solver = np.ones((h, w), dtype=bool) + + return output_solver, binary_solver \ No newline at end of file diff --git a/duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml b/duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0cebb9c24cdb795291eb4df27ef45e1071d5aa9b --- /dev/null +++ b/duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml @@ -0,0 +1,56 @@ +# augmentations +use_copy_paste: false +scale_range: [ 0.1, 1.0 ] +repeat_image: false + +# base directories +dir_ckpt: "/users/gyungin/selfmask/ckpt" # "/work/gyungin/selfmask/ckpt" +dir_dataset: "/scratch/shared/beegfs/gyungin/datasets" + +# clustering +k: [2, 3, 4] +clustering_mode: "spectral" +use_gpu: true # if you want to use gpu-accelerated code for clustering +scale_factor: 2 # "how much you want to upsample encoder features before clustering" + +# dataset +dataset_name: "duts" +use_pseudo_masks: true +train_image_size: 224 +eval_image_size: 224 +n_percent: 100 +n_copy_pastes: null +pseudo_masks_fp: "/users/gyungin/selfmask/datasets/swav_mocov2_dino_p16_k234.json" + +# dataloader: +batch_size: 8 +num_workers: 4 +pin_memory: true + +# networks +abs_2d_pe_init: false +arch: "vit_small" +lateral_connection: false
+learnable_pixel_decoder: false # if False, use the bilinear interpolation +use_binary_classifier: true # if True, use a binary classifier to get an objectness for each query from transformer decoder +n_decoder_layers: 6 +n_queries: 20 +num_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +patch_size: 8 +training_method: "dino" # "supervised", "deit", "dino", "mocov2", "swav" + +# objective +loss_every_decoder_layer: true +weight_dice_loss: 1.0 +weight_focal_loss: 0.0 + +# optimizer +lr: 0.000006 # default: 0.00006 +lr_warmup_duration: 0 # 5 +momentum: 0.9 +n_epochs: 12 +weight_decay: 0.01 +optimizer_type: "adamw" + +# validation +benchmarks: null \ No newline at end of file diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/networks/maskformer/maskformer.py b/networks/maskformer/maskformer.py new file mode 100644 index 0000000000000000000000000000000000000000..003896f1f4c8d58c361dea9699625d1a29fd4782 --- /dev/null +++ b/networks/maskformer/maskformer.py @@ -0,0 +1,267 @@ +from typing import Dict, List +from math import sqrt, log +import torch +import torch.nn as nn +import torch.nn.functional as F + +from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder +from utils import get_model + + +class MaskFormer(nn.Module): + def __init__( + self, + n_queries: int = 100, + arch: str = "vit_small", + patch_size: int = 8, + training_method: str = "dino", + n_decoder_layers: int = 6, + normalize_before: bool = False, + return_intermediate: bool = False, +
learnable_pixel_decoder: bool = False, + lateral_connection: bool = False, + scale_factor: int = 2, + abs_2d_pe_init: bool = False, + use_binary_classifier: bool = False + ): + """Define a encoder and decoder along with queries to be learned through the decoder.""" + super(MaskFormer, self).__init__() + + if arch == "vit_small": + self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method) + n_dims: int = self.encoder.n_embs + n_heads: int = self.encoder.n_heads + mlp_ratio: int = self.encoder.mlp_ratio + else: + self.encoder = get_model(arch=arch, training_method=training_method) + n_dims_resnet: int = self.encoder.n_embs + n_dims: int = 384 + n_heads: int = 6 + mlp_ratio: int = 4 + self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1) + + decoder_layer = TransformerDecoderLayer( + n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before + ) + self.decoder = TransformerDecoder( + decoder_layer, + n_decoder_layers, + norm=nn.LayerNorm(n_dims), + return_intermediate=return_intermediate + ) + + self.query_embed = nn.Embedding(n_queries, n_dims).weight # initialized with gaussian(0, 1) + + if use_binary_classifier: + # self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3) + # self.linear_classifier = nn.Linear(n_dims, 1) + self.ffn = MLP(n_dims, n_dims, 1, num_layers=3) + # self.norm = nn.LayerNorm(n_dims) + else: + # self.ffn = None + # self.linear_classifier = None + # self.norm = None + self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3) + self.linear_classifier = nn.Linear(n_dims, 2) + self.norm = nn.LayerNorm(n_dims) + + self.arch = arch + self.use_binary_classifier = use_binary_classifier + self.lateral_connection = lateral_connection + self.learnable_pixel_decoder = learnable_pixel_decoder + self.scale_factor = scale_factor + + # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py + @staticmethod + def positional_encoding_2d(n_dims: int, height: int, width: int): + """ + :param n_dims: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if n_dims % 4 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(n_dims)) + pe = torch.zeros(n_dims, height, width) + # Each dimension use half of d_model + d_model = int(n_dims / 2) + div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model)) + pos_w = torch.arange(0., width).unsqueeze(1) + pos_h = torch.arange(0., height).unsqueeze(1) + pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + + return pe + + def forward_encoder(self, x: torch.Tensor): + """ + :param x: b x c x h x w + :return patch_tokens: b x depth x hw x n_dims + """ + if self.arch == "vit_small": + encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x) # [:, 1:, :] + all_patch_tokens: List[torch.Tensor] = list() + for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]: + patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :] # b x hw x n_dims + 
all_patch_tokens.append(patch_tokens) + + all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0) # depth x b x hw x n_dims + all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2) # b x depth x n_dims x hw + return all_patch_tokens + else: + encoder_outputs = self.linear_layer(self.encoder(x)[-1]) # b x n_dims x h x w + return encoder_outputs + + def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor: + """Forward transformer decoder given patch tokens from the encoder's last layer. + :param patch_tokens: b x n_dims x hw -> hw x b x n_dims + :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication + between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting + experiment. + :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims + """ + b = patch_tokens.shape[0] + patch_tokens = patch_tokens.permute(2, 0, 1) # b x n_dims x hw -> hw x b x n_dims + + # n_queries x n_dims -> n_queries x b x n_dims + queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1) + queries: torch.Tensor = self.decoder.forward( + tgt=torch.zeros_like(queries), + memory=patch_tokens, + query_pos=queries + ).squeeze(dim=0) + + if len(queries.shape) == 3: + queries: torch.Tensor = queries.permute(1, 0, 2) # n_queries x b x n_dims -> b x n_queries x n_dims + elif len(queries.shape) == 4: + # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims + queries: torch.Tensor = queries.permute(2, 0, 1, 3) + return queries + + def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None): + """ Upsample patch tokens by self.scale_factor and produce mask predictions + :param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w + :param queries: b x n_queries x n_dims + :return mask_predictions: b x n_queries x h x w + """ + + if input_size is None: + # assume square shape features + hw = patch_tokens.shape[-1] + h = w = int(sqrt(hw)) + else: + # arbitrary shape features + h, w = input_size + patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w) + + assert len(patch_tokens.shape) == 4 + patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear") + return patch_tokens + + def forward(self, x, encoder_only=False, skip_decoder: bool = False): + """ + x: b x c x h x w + patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims + query_emb: n_queries x n_dims -> n_queries x b x n_dims + """ + dict_outputs: dict = dict() + + # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50) + features: torch.Tensor = self.forward_encoder(x) + + if self.arch == "vit_small": + # extract the last layer for decoder input + last_layer_features: torch.Tensor = features[:, -1, ...] 
# b x n_dims x hw + else: + # transform the shape of the features to the one compatible with transformer decoder + b, n_dims, h, w = features.shape + last_layer_features: torch.Tensor = features.view(b, n_dims, h * w) # b x n_dims x hw + + if encoder_only: + _h, _w = self.encoder.make_input_divisible(x).shape[-2:] + _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size + + b, n_dims, hw = last_layer_features.shape + dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)}) + return dict_outputs + + # transformer decoder forward + queries: torch.Tensor = self.forward_transformer_decoder( + last_layer_features, + skip_decoder=skip_decoder + ) # b x n_queries x n_dims or b x n_layers x n_queries x n_dims + + # pixel decoder forward (upsampling the patch tokens by self.scale_factor) + if self.arch == "vit_small": + _h, _w = self.encoder.make_input_divisible(x).shape[-2:] + _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size + else: + _h, _w = h, w + features: torch.Tensor = self.forward_pixel_decoder( + patch_tokens=features if self.lateral_connection else last_layer_features, + input_size=(_h, _w) + ) # b x n_dims x h x w + + # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims + # features: b x n_dims x h x w + # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w + if len(queries.shape) == 3: + mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features) + else: + if self.use_binary_classifier: + mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features)) + else: + mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features)) + + if self.use_binary_classifier: + # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims + queries = queries.permute(1, 0, 2, 3) + objectness: List[torch.Tensor] = list() + for n_layer, queries_per_layer in enumerate(queries): # queries_per_layer: b x n_queries x n_dims + # objectness_per_layer = self.linear_classifier( + # self.ffn(self.norm(queries_per_layer)) + # ) # b x n_queries x 1 + objectness_per_layer = self.ffn(queries_per_layer) # b x n_queries x 1 + objectness.append(objectness_per_layer) + # n_layers x b x n_queries x 1 -> # b x n_layers x n_queries x 1 + objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3) + dict_outputs.update({ + "objectness": torch.sigmoid(objectness), + "mask_pred": mask_pred + }) + + return dict_outputs + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class UpsampleBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2): + super(UpsampleBlock, self).__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding), + nn.GroupNorm(n_groups, out_channels), + nn.ReLU() + ) + self.scale_factor = scale_factor + + def forward(self, x): + return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear") \ No newline at end of file diff --git a/networks/maskformer/positional_embedding.py 
b/networks/maskformer/positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..69d117e05e5a133424c9f9f7b883da129b9c4d45 --- /dev/null +++ b/networks/maskformer/positional_embedding.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py +""" +Various positional encodings for the transformer. +""" +import math + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos \ No newline at end of file diff --git a/networks/maskformer/transformer_decoder.py b/networks/maskformer/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6111be730f7f063281fab0199f6dd413ba50e9ba --- /dev/null +++ b/networks/maskformer/transformer_decoder.py @@ -0,0 +1,376 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py +""" +Transformer class. 
+Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", # noel - dino used GeLU + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if mask is not None: + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder( + tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed + ) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers: nn.ModuleList = _get_clones(decoder_layer, num_layers) + self.num_layers: int = num_layers + self.norm = norm + self.return_intermediate: bool = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if 
self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: 
Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + + tgt2 = self.self_attn( + q, + k, + value=tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/networks/module_helper.py b/networks/module_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9a55e55b71f1d4f6951c625d90367a77d8f6d9 --- /dev/null +++ b/networks/module_helper.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Author: Donny You (youansheng@gmail.com) +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + + +class FixedBatchNorm(nn.BatchNorm2d): + def forward(self, input): + return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps) + + +class ModuleHelper(object): + 
@staticmethod + def BNReLU(num_features, norm_type=None, **kwargs): + if norm_type == 'batchnorm': + return nn.Sequential( + nn.BatchNorm2d(num_features, **kwargs), + nn.ReLU() + ) + elif norm_type == 'encsync_batchnorm': + from encoding.nn import BatchNorm2d + return nn.Sequential( + BatchNorm2d(num_features, **kwargs), + nn.ReLU() + ) + elif norm_type == 'instancenorm': + return nn.Sequential( + nn.InstanceNorm2d(num_features, **kwargs), + nn.ReLU() + ) + elif norm_type == 'fixed_batchnorm': + return nn.Sequential( + FixedBatchNorm(num_features, **kwargs), + nn.ReLU() + ) + else: + raise ValueError('Not support BN type: {}.'.format(norm_type)) + + @staticmethod + def BatchNorm3d(norm_type=None, ret_cls=False): + if norm_type == 'batchnorm': + return nn.BatchNorm3d + elif norm_type == 'encsync_batchnorm': + from encoding.nn import BatchNorm3d + return BatchNorm3d + elif norm_type == 'instancenorm': + return nn.InstanceNorm3d + else: + raise ValueError('Not support BN type: {}.'.format(norm_type)) + + @staticmethod + def BatchNorm2d(norm_type=None, ret_cls=False): + if norm_type == 'batchnorm': + return nn.BatchNorm2d + elif norm_type == 'encsync_batchnorm': + from encoding.nn import BatchNorm2d + return BatchNorm2d + + elif norm_type == 'instancenorm': + return nn.InstanceNorm2d + else: + raise ValueError('Not support BN type: {}.'.format(norm_type)) + + @staticmethod + def BatchNorm1d(norm_type=None, ret_cls=False): + if norm_type == 'batchnorm': + return nn.BatchNorm1d + elif norm_type == 'encsync_batchnorm': + from encoding.nn import BatchNorm1d + return BatchNorm1d + elif norm_type == 'instancenorm': + return nn.InstanceNorm1d + else: + raise ValueError('Not support BN type: {}.'.format(norm_type)) + + @staticmethod + def load_model(model, pretrained=None, all_match=True, map_location='cpu'): + if pretrained is None: + return model + + if not os.path.exists(pretrained): + pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation") + if os.path.exists(pretrained): + pass + else: + raise FileNotFoundError('{} not exists.'.format(pretrained)) + + print('Loading pretrained model:{}'.format(pretrained)) + if all_match: + pretrained_dict = torch.load(pretrained, map_location=map_location) + model_dict = model.state_dict() + load_dict = dict() + for k, v in pretrained_dict.items(): + if 'prefix.{}'.format(k) in model_dict: + load_dict['prefix.{}'.format(k)] = v + else: + load_dict[k] = v + model.load_state_dict(load_dict) + + else: + pretrained_dict = torch.load(pretrained) + model_dict = model.state_dict() + load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + print('Matched Keys: {}'.format(load_dict.keys())) + model_dict.update(load_dict) + model.load_state_dict(model_dict) + + return model + + @staticmethod + def load_url(url, map_location=None): + model_dir = os.path.join('~', '.TorchCV', 'model') + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + print('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + + print('Loading pretrained model:{}'.format(cached_file)) + return torch.load(cached_file, map_location=map_location) + + @staticmethod + def constant_init(module, val, bias=0): + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def xavier_init(module, gain=1, 
bias=0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def normal_init(module, mean=0, std=1, bias=0): + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def uniform_init(module, a=0, b=1, bias=0): + nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def kaiming_init(module, + mode='fan_in', + nonlinearity='leaky_relu', + bias=0, + distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + diff --git a/networks/resnet.py b/networks/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..49b8e791d2248c0affdacc8f64af03094a8c9137 --- /dev/null +++ b/networks/resnet.py @@ -0,0 +1,60 @@ +import os + +import torch +import torch.nn as nn +from .resnet_backbone import ResNetBackbone + + +class ResNet50(nn.Module): + def __init__( + self, + weight_type: str = "supervised", + use_dilated_resnet: bool = True + ): + super(ResNet50, self).__init__() + self.network = ResNetBackbone(backbone=f"resnet50{'_dilated8' if use_dilated_resnet else ''}", pretrained=None) + self.n_embs = self.network.num_features + self.use_dilated_resnet = use_dilated_resnet + self._load_pretrained(weight_type) + + def _load_pretrained(self, training_method: str) -> None: + curr_state_dict = self.network.state_dict() + if training_method == "mocov2": + state_dict = torch.load("/users/gyungin/sos/networks/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"] + + for k in list(state_dict.keys()): + if any([k.find(w) != -1 for w in ("fc.0", "fc.2")]): + state_dict.pop(k) + + elif training_method == "swav": + state_dict = torch.load("/users/gyungin/sos/networks/pretrained/swav_800ep_pretrain.pth.tar") + for k in list(state_dict.keys()): + if any([k.find(w) != -1 for w in ("projection_head", "prototypes")]): + state_dict.pop(k) + + elif training_method == "supervised": + # Note - pytorch resnet50 model doesn't have num_batches_tracked layers. Need to know why. 
+ # for k in list(curr_state_dict.keys()): + # if k.find("num_batches_tracked") != -1: + # curr_state_dict.pop(k) + # state_dict = torch.load("../networks/pretrained/resnet50-pytorch.pth") + + from torchvision.models.resnet import resnet50 + resnet50_supervised = resnet50(True, True) + state_dict = resnet50_supervised.state_dict() + for k in list(state_dict.keys()): + if any([k.find(w) != -1 for w in ("fc.weight", "fc.bias")]): + state_dict.pop(k) + + assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}" + for k_curr, k in zip(curr_state_dict.keys(), state_dict.keys()): + curr_state_dict[k_curr].copy_(state_dict[k]) + print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} intialised with {training_method} weights is loaded.") + return + + def forward(self, x): + return self.network(x) + + +if __name__ == '__main__': + resnet = ResNet50("mocov2") diff --git a/networks/resnet_backbone.py b/networks/resnet_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..498ae808f02f60f109eb2c2302e2783360eb6db4 --- /dev/null +++ b/networks/resnet_backbone.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Author: Donny You(youansheng@gmail.com) + + +import torch.nn as nn +from networks.resnet_models import * + + +class NormalResnetBackbone(nn.Module): + def __init__(self, orig_resnet): + super(NormalResnetBackbone, self).__init__() + + self.num_features = 2048 + # take pretrained resnet, except AvgPool and FC + self.prefix = orig_resnet.prefix + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def get_num_features(self): + return self.num_features + + def forward(self, x): + tuple_features = list() + x = self.prefix(x) + x = self.maxpool(x) + x = self.layer1(x) + tuple_features.append(x) + x = self.layer2(x) + tuple_features.append(x) + x = self.layer3(x) + tuple_features.append(x) + x = self.layer4(x) + tuple_features.append(x) + + return tuple_features + + +class DilatedResnetBackbone(nn.Module): + def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)): + super(DilatedResnetBackbone, self).__init__() + + self.num_features = 2048 + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) + if multi_grid is None: + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) + else: + for i, r in enumerate(multi_grid): + orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r))) + + elif dilate_scale == 16: + if multi_grid is None: + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) + else: + for i, r in enumerate(multi_grid): + orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r))) + + # Take pretrained resnet, except AvgPool and FC + self.prefix = orig_resnet.prefix + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate // 2, dilate // 2) + m.padding = (dilate // 2, dilate // 2) + # other convoluions + else: + if m.kernel_size == (3, 3): + 
m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def get_num_features(self): + return self.num_features + + def forward(self, x): + tuple_features = list() + + x = self.prefix(x) + x = self.maxpool(x) + + x = self.layer1(x) + tuple_features.append(x) + x = self.layer2(x) + tuple_features.append(x) + x = self.layer3(x) + tuple_features.append(x) + x = self.layer4(x) + tuple_features.append(x) + + return tuple_features + + +def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'): + arch = backbone + + if arch == 'resnet18': + orig_resnet = resnet18(pretrained=pretrained) + arch_net = NormalResnetBackbone(orig_resnet) + arch_net.num_features = 512 + + elif arch == 'resnet18_dilated8': + orig_resnet = resnet18(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid) + arch_net.num_features = 512 + + elif arch == 'resnet34': + orig_resnet = resnet34(pretrained=pretrained) + arch_net = NormalResnetBackbone(orig_resnet) + arch_net.num_features = 512 + + elif arch == 'resnet34_dilated8': + orig_resnet = resnet34(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid) + arch_net.num_features = 512 + + elif arch == 'resnet34_dilated16': + orig_resnet = resnet34(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid) + arch_net.num_features = 512 + + elif arch == 'resnet50': + orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier) + arch_net = NormalResnetBackbone(orig_resnet) + + elif arch == 'resnet50_dilated8': + orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid) + + elif arch == 'resnet50_dilated16': + orig_resnet = resnet50(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid) + + elif arch == 'deepbase_resnet50': + if pretrained: + pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth' + orig_resnet = deepbase_resnet50(pretrained=pretrained) + arch_net = NormalResnetBackbone(orig_resnet) + + elif arch == 'deepbase_resnet50_dilated8': + if pretrained: + pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth' + # pretrained = "/home/gishin/Projects/DeepLearning/Oxford/cct/models/backbones/pretrained/3x3resnet50-imagenet.pth" + orig_resnet = deepbase_resnet50(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid) + + elif arch == 'deepbase_resnet50_dilated16': + orig_resnet = deepbase_resnet50(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid) + + elif arch == 'resnet101': + orig_resnet = resnet101(pretrained=pretrained) + arch_net = NormalResnetBackbone(orig_resnet) + + elif arch == 'resnet101_dilated8': + orig_resnet = resnet101(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid) + + elif arch == 'resnet101_dilated16': + orig_resnet = resnet101(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid) + + elif arch == 'deepbase_resnet101': + orig_resnet = deepbase_resnet101(pretrained=pretrained) + arch_net = NormalResnetBackbone(orig_resnet) + + elif arch == 'deepbase_resnet101_dilated8': + if pretrained: + pretrained = 
'backbones/backbones/pretrained/3x3resnet101-imagenet.pth' + orig_resnet = deepbase_resnet101(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid) + + elif arch == 'deepbase_resnet101_dilated16': + orig_resnet = deepbase_resnet101(pretrained=pretrained) + arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid) + + else: + raise Exception('Architecture undefined!') + + return arch_net diff --git a/networks/resnet_models.py b/networks/resnet_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e5fe1285a0eb657cdc6f865edcf95023db946dcb --- /dev/null +++ b/networks/resnet_models.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Author: Donny You(youansheng@gmail.com) +import math +import torch.nn as nn +from collections import OrderedDict +from .module_helper import ModuleHelper + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/backbones/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/backbones/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/backbones/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/backbones/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/backbones/resnet152-b121ed2d.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers, width_multiplier=1.0, num_classes=1000, deep_base=False, norm_type=None): + super(ResNet, self).__init__() + self.inplanes = 128 if deep_base 
else int(64 * width_multiplier) + self.width_multiplier = width_multiplier + if deep_base: + self.prefix = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)), + ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)), + ('relu1', nn.ReLU(inplace=False)), + ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)), + ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)), + ('relu2', nn.ReLU(inplace=False)), + ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)), + ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)), + ('relu3', nn.ReLU(inplace=False))] + )) + else: + self.prefix = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)), + ('relu', nn.ReLU(inplace=False))] + )) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) # change. + + self.layer1 = self._make_layer(block, int(64 * width_multiplier), layers[0], norm_type=norm_type) + self.layer2 = self._make_layer(block, int(128 * width_multiplier), layers[1], stride=2, norm_type=norm_type) + self.layer3 = self._make_layer(block, int(256 * width_multiplier), layers[2], stride=2, norm_type=norm_type) + self.layer4 = self._make_layer(block, int(512 * width_multiplier), layers[3], stride=2, norm_type=norm_type) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(int(512 * block.expansion * width_multiplier), num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, norm_type=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + ModuleHelper.BatchNorm2d(norm_type=norm_type)(int(planes * block.expansion * self.width_multiplier)), + ) + + layers = [] + layers.append(block(self.inplanes, planes, + stride, downsample, norm_type=norm_type)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, norm_type=norm_type)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.prefix(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + norm_type (str): choose norm type + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type, + width_multiplier=kwargs["width_multiplier"]) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-152 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model + + +def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs): + """Constructs a ResNet-152 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type) + model = ModuleHelper.load_model(model, pretrained=pretrained) + return model diff --git a/networks/timm_deit.py b/networks/timm_deit.py new file mode 100644 index 0000000000000000000000000000000000000000..95754e99af6f3c18ae649e807f9d0ef800d466f5 --- /dev/null +++ b/networks/timm_deit.py @@ -0,0 +1,254 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import math +import torch +import torch.nn as nn +from functools import partial + +from networks.timm_vit import VisionTransformer, _cfg +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_ + + +__all__ = [ + 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', + 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', + 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', + 'deit_base_distilled_patch16_384', +] + + +class DistilledVisionTransformer(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + trunc_normal_(self.dist_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + self.head_dist.apply(self._init_weights) + + def forward_features(self, x): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0], x[:, 1] + + def forward(self, x): + x, x_dist = self.forward_features(x) + x = self.head(x) + x_dist = self.head_dist(x_dist) + if self.training: + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 + + def interpolate_pos_encoding(self, x, pos_embed): + """Interpolate the learnable positional encoding to match the number of patches. + + x: B x (1 + 1 + N patches) x dim_embedding + pos_embed: B x (1 + 1 + N patches) x dim_embedding + + return interpolated positional embedding + """ + + npatch = x.shape[1] - 2 # (H // patch_size * W // patch_size) + N = pos_embed.shape[1] - 2 # 784 (= 28 x 28) + + if npatch == N: + return pos_embed + + class_emb, distil_token, pos_embed = pos_embed[:, 0], pos_embed[:, 1], pos_embed[:, 2:] # a learnable CLS token, learnable position embeddings + + dim = x.shape[-1] # dimension of embeddings + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28 + scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer. 
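+ # The extra 1e-5 on the scale factor compensates for floating-point rounding:
+ # with recompute_scale_factor=True the output size is floored from input_size * scale_factor,
+ # so an exact sqrt(npatch / N) could otherwise yield a grid one row/column smaller than intended.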
+ recompute_scale_factor=True, + mode='bicubic' + ) + # print("pos_embed", pos_embed.shape, npatch, N, math.sqrt(npatch/N), math.sqrt(npatch/N) * int(math.sqrt(N))) + # exit(12) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + pos_embed = torch.cat((class_emb.unsqueeze(0), distil_token.unsqueeze(0), pos_embed), dim=1) + return pos_embed + + def get_tokens( + self, + x, + layers: list, + patch_tokens: bool = False, + norm: bool = True, + input_tokens: bool = False, + post_pe: bool = False + ): + """Return intermediate tokens.""" + list_tokens: list = [] + + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + dist_token = self.dist_token.expand(B, -1, -1) + + x = torch.cat((cls_tokens, dist_token, x), dim=1) + + if input_tokens: + list_tokens.append(x) + + pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) + x = x + pos_embed + + if post_pe: + list_tokens.append(x) + + x = self.pos_drop(x) + + for i, blk in enumerate(self.blocks): + x = blk(x) # B x # patches x dim + if layers is None or i in layers: + list_tokens.append(self.norm(x) if norm else x) + + tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim + + if not patch_tokens: + return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim + + else: + return torch.cat((tokens[:, :, 0, :].unsqueeze(dim=2), tokens[:, :, 2:, :]), dim=2) # exclude distil token. + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return 
model + + +@register_model +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model diff --git a/networks/timm_vit.py b/networks/timm_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..fe8aea2eb1fc17066c0d9b1a6541461dff12a46b --- /dev/null +++ b/networks/timm_vit.py @@ -0,0 +1,819 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg +from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.models.registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + ), + + # patch models (weights ported from official Google JAX impl) + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_base_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_base_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_base_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + + # patch models, imagenet21k (weights ported from official Google JAX impl) + 'vit_base_patch16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_base_patch32_224_in21k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_huge_patch14_224_in21k': _cfg( + hf_hub='timm/vit_huge_patch14_224_in21k', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + + # deit models (FB weights) + 'vit_deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + 'vit_deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + 'vit_deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), + 'vit_deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + classifier=('head', 'head_dist')), + 'vit_deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + classifier=('head', 'head_dist')), + 'vit_deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + classifier=('head', 'head_dist')), + 'vit_deit_base_distilled_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + + # ViT ImageNet-21K-P pretraining + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), +} + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + 
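+# Shape walk-through for the Attention module above (a minimal sketch; the concrete numbers
+# assume ViT-Base at 224x224 resolution, i.e. B=2, N=197 tokens, dim=768, num_heads=12, head_dim=64):
+#   x:                      (2, 197, 768)
+#   self.qkv(x):            (2, 197, 2304) -> reshape to (2, 197, 3, 12, 64) -> permute to (3, 2, 12, 197, 64)
+#   q, k, v:                each (2, 12, 197, 64)
+#   attn = q @ k^T * scale: (2, 12, 197, 197), softmax over the last dimension
+#   attn @ v:               (2, 12, 197, 64) -> transpose/reshape to (2, 197, 768) -> proj -> (2, 197, 768)
+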
+class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, + act_layer=None, weight_init='', + # noel + img_size_eval: int = 224): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, 
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + # Weight init + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if weight_init.startswith('jax'): + # leave cls token as zeros to match jax impl + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + # noel + self.depth = depth + self.distilled = distilled + self.patch_size = patch_size + self.patch_embed.img_size = (img_size_eval, img_size_eval) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + x = self.norm(x) + if self.dist_token is None: + return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] + + # def forward(self, x): + # x = self.forward_features(x) + # if self.head_dist is not None: + # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + # if self.training and not torch.jit.is_scripting(): + # # during inference, return the average of both classifier predictions + # return x, x_dist + # else: + # return (x + x_dist) / 2 + # else: + # x = self.head(x) + # return x + + # noel - start + def make_square(self, x: torch.Tensor): + """Pad some pixels to make the input size divisible by the patch size.""" + B, _, H_0, W_0 = x.shape + pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size + pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=x.mean()) + + H_p, W_p = H_0 + pad_h, W_0 + pad_w + x = nn.functional.pad(x, (0, H_p - W_p, 0, 0) if H_p > W_p else (0, 0, 0, W_p - H_p), value=x.mean()) + return x + + def interpolate_pos_encoding(self, x, pos_embed, size): + """Interpolate the learnable positional 
encoding to match the number of patches. + + x: B x (1 + N patches) x dim_embedding + pos_embed: B x (1 + N patches) x dim_embedding + + return interpolated positional embedding + """ + npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size) + N = pos_embed.shape[1] - 1 # 784 (= 28 x 28) + if npatch == N: + return pos_embed + class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings + + dim = x.shape[-1] # dimension of embeddings + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28 + size=size, + mode='bicubic', + align_corners=False + ) + + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) + return pos_embed + + # def interpolate_pos_encoding(self, x, pos_embed): + # """Interpolate the learnable positional encoding to match the number of patches. + # + # x: B x (1 + N patches) x dim_embedding + # pos_embed: B x (1 + N patches) x dim_embedding + # + # return interpolated positional embedding + # """ + # npatch = x.shape[1] - 1 # (H // patch_size * W // patch_size) + # N = pos_embed.shape[1] - 1 # 784 (= 28 x 28) + # if npatch == N: + # return pos_embed + # class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:] # a learnable CLS token, learnable position embeddings + # + # dim = x.shape[-1] # dimension of embeddings + # pos_embed = nn.functional.interpolate( + # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), # B x dim x 28 x 28 + # scale_factor=math.sqrt(npatch / N) + 1e-5, # noel: this can be a float, but the output shape will be integer. + # recompute_scale_factor=True, + # mode='bicubic', + # align_corners=False + # ) + # + # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + # pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) + # return pos_embed + + def prepare_tokens(self, x): + B, nc, h, w = x.shape + patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w)) + return self.pos_drop(x) + + def get_tokens( + self, + x, + layers: list, + patch_tokens: bool = False, + norm: bool = True, + input_tokens: bool = False, + post_pe: bool = False + ): + """Return intermediate tokens.""" + list_tokens: list = [] + + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + + x = torch.cat((cls_tokens, x), dim=1) + + if input_tokens: + list_tokens.append(x) + + pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) + x = x + pos_embed + + if post_pe: + list_tokens.append(x) + + x = self.pos_drop(x) + + for i, blk in enumerate(self.blocks): + x = blk(x) # B x # patches x dim + if layers is None or i in layers: + list_tokens.append(self.norm(x) if norm else x) + + tokens = torch.stack(list_tokens, dim=1) # B x n_layers x (1 + # patches) x dim + + if not patch_tokens: + return tokens[:, :, 0, :] # index [CLS] tokens only, B x n_layers x dim + + else: + return tokens + + def forward(self, x, layer: str = None): + x = self.prepare_tokens(x) + + features: dict = {} + for i, blk in enumerate(self.blocks): + x = 
blk(x) + features[f"layer{i + 1}"] = self.norm(x) + + if layer is not None: + return features[layer] + else: + return features["layer12"] + # noel - end + + +def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(m, nn.Linear): + if n.startswith('head'): + nn.init.zeros_(m.weight) + nn.init.constant_(m.bias, head_bias) + elif n.startswith('pre_logits'): + lecun_normal_(m.weight) + nn.init.zeros_(m.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if 'mlp' in n: + nn.init.normal_(m.bias, std=1e-6) + else: + nn.init.zeros_(m.bias) + else: + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif jax_impl and isinstance(m, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained 
models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + default_cfg=default_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3. + NOTE: + * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6 + * this model does not have a bias for QKV (unlike the official ViT and DeiT models) + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., + qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) + if pretrained: + # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model + model_kwargs.setdefault('qk_scale', 768 ** -0.5) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + model_kwargs = dict( + patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: converted weights not currently available, too large for github release hosting. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_base_patch16_224(pretrained=False, **kwargs): + """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_base_patch16_384(pretrained=False, **kwargs): + """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer( + 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer( + 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + return model \ No newline at end of file diff --git a/networks/vision_transformer.py b/networks/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1172f4dd482cd6789ec231902e8fc979f0cd4f --- /dev/null +++ b/networks/vision_transformer.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +from typing import Optional +import math +from functools import partial + +import torch +import torch.nn as nn + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.", + stacklevel=2 + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. 
+ # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 # square root of dimension for normalisation + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape # B x (cls token + # patch tokens) x dim + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # qkv: 3 x B x Nh x (cls token + # patch tokens) x (dim // Nh) + + q, k, v = qkv[0], qkv[1], qkv[2] + # q, k, v: B x Nh x (cls token + # patch tokens) x (dim // Nh) + + # q: B x Nh x (cls token + # patch tokens) x (dim // Nh) + # k.transpose(-2, -1) = B x Nh x (dim // Nh) x (cls token + # patch tokens) + # attn: B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens) + attn = (q @ k.transpose(-2, -1)) * self.scale # @ operator is for matrix multiplication + attn = attn.softmax(dim=-1) # B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens) + attn = self.attn_drop(attn) + + # attn = B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens) + # v = B x Nh x (cls token + # patch tokens) x (dim // Nh) + # attn @ v = B x Nh x (cls token + # patch tokens) x (dim // Nh) + # (attn @ v).transpose(1, 2) = B x (cls token + # patch tokens) x Nh x (dim // Nh) + x = (attn @ 
v).transpose(1, 2).reshape(B, N, C) # B x (cls token + # patch tokens) x dim + x = self.proj(x) # B x (cls token + # patch tokens) x dim + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, + dim, num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding""" + def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) # B x (P_H * P_W) x C + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, + img_size=(224, 224), + patch_size=16, + in_chans=3, + num_classes=0, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=(224, 224), # noel: this is to load pretrained model. 
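+ # The fixed 224x224 here only determines num_patches, and therefore the shape of pos_embed,
+ # so that pretrained checkpoints can be loaded; other input sizes are handled at runtime by
+ # make_input_divisible() and interpolate_pos_encoding() in prepare_tokens().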
+ patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer + ) for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + self.depth = depth + self.embed_dim = self.n_embs = embed_dim + self.mlp_ratio = mlp_ratio + self.n_heads = num_heads + self.patch_size = patch_size + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor: + """Pad some pixels to make the input size divisible by the patch size.""" + B, _, H_0, W_0 = x.shape + pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size + pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size + + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0) + return x + + def prepare_tokens(self, x): + B, nc, h, w = x.shape + x: torch.Tensor = self.make_input_divisible(x) + patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + + x = self.patch_embed(x) # patch linear embedding + + # add positional encoding to each token + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w)) + return self.pos_drop(x) + + @staticmethod + def split_token(x, token_type: str): + if token_type == "cls": + return x[:, 0, :] + elif token_type == "patch": + return x[:, 1:, :] + else: + return x + + # noel + def forward(self, x, layer: Optional[str] = None): + x: torch.Tensor = self.prepare_tokens(x) + + features: dict = {} + for i, blk in enumerate(self.blocks): + x = blk(x) + features[f"layer{i + 1}"] = self.norm(x) + + if layer is not None: + return features[layer] + else: + return features + + # noel - for DINO's visual + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_tokens( + self, + x, + layers: list, + patch_tokens: bool = False, + norm: bool = True, + input_tokens: bool = False, + post_pe: bool = False + ): + """Return intermediate tokens.""" + list_tokens: list = [] + + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + + x = torch.cat((cls_tokens, x), dim=1) + + if input_tokens: + list_tokens.append(x) + + pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) + x = x + pos_embed + + if post_pe: 
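+            # optionally also collect the tokens right after the positional encodings are added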
+            list_tokens.append(x)
+
+        x = self.pos_drop(x)
+
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)  # B x (1 + # patches) x dim
+            if layers is None or i in layers:
+                list_tokens.append(self.norm(x) if norm else x)
+
+        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim
+
+        if not patch_tokens:
+            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim
+
+        else:
+            return tokens
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+        x = x + pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x[:, 0]
+
+    def interpolate_pos_encoding(self, x, pos_embed, size=None):
+        """Interpolate the learnable positional encoding to match the number of patches.
+
+        x: B x (1 + N patches) x dim_embedding
+        pos_embed: 1 x (1 + N pretraining patches) x dim_embedding
+        size: (height, width) of the target patch grid; if None, a square grid is assumed.
+
+        return interpolated positional embedding
+        """
+        npatch = x.shape[1] - 1  # (H // patch_size) * (W // patch_size)
+        N = pos_embed.shape[1] - 1  # number of patches at pretraining, e.g. 784 (= 28 x 28) for 224 x 224 with patch size 8
+        if npatch == N:
+            return pos_embed
+        if size is None:
+            size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))  # assume a square patch grid when no size is given
+        class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # positional embedding of the [CLS] token, positional embeddings of the patch tokens
+
+        dim = x.shape[-1]  # dimension of embeddings
+        pos_embed = nn.functional.interpolate(
+            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # 1 x dim x sqrt(N) x sqrt(N)
+            size=size,
+            mode='bicubic',
+            align_corners=False
+        )
+
+        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+        return pos_embed
+
+    def forward_selfattention(self, x, return_interm_attn=False):
+        B, nc, w, h = x.shape
+        N = self.pos_embed.shape[1] - 1
+        x = self.patch_embed(x)
+
+        # interpolate patch embeddings
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size
+        h0 = h // self.patch_embed.patch_size
+        class_pos_embed = self.pos_embed[:, 0]
+        patch_pos_embed = self.pos_embed[:, 1:]
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic'
+        )
+        if w0 != patch_pos_embed.shape[-2]:
+            helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
+            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
+        if h0 != patch_pos_embed.shape[-1]:
+            helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
+            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # self.cls_token: 1 x 1 x emb_dim -> B x 1 x emb_dim
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = x + pos_embed
+        x = self.pos_drop(x)
+
+        if return_interm_attn:
+            list_attn = []
+            for i, blk in enumerate(self.blocks):
+                attn = blk(x, return_attention=True)
+                x = blk(x)
+                list_attn.append(attn)
+            return torch.cat(list_attn, dim=0)
+
+        else:
+            for i, blk in enumerate(self.blocks):
+                if i < len(self.blocks) - 1:
+                    x = blk(x)
+                else:
+                    return blk(x, return_attention=True)
+
+    def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+
+        x = torch.cat((cls_tokens, x), dim=1)
+        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+        x = x + pos_embed
+        x = self.pos_drop(x)
+
+        # we will return the [CLS] tokens from the `n` last blocks
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if len(self.blocks) - i <= n:
+                # get only the [CLS] token (B x dim)
+                output.append(self.norm(x)[:, 0])
+        if return_patch_avgpool:
+            x = self.norm(x)
+            # In addition to the [CLS] tokens from the `n` last blocks, we also return
+            # the average-pooled patch tokens from the last block. This is useful for linear eval.
+            output.append(torch.mean(x[:, 1:], dim=1))
+        return torch.cat(output, dim=-1)
+
+    def return_patch_emb_from_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
+        """Return intermediate patch embeddings, rather than the [CLS] token, from the last n blocks."""
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+
+        x = torch.cat((cls_tokens, x), dim=1)
+        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+        x = x + pos_embed
+        x = self.pos_drop(x)
+
+        # we will return the patch tokens from the `n` last blocks
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if len(self.blocks) - i <= n:
+                output.append(self.norm(x)[:, 1:])  # get only the patch tokens (B x # patches x dim)
+
+        if return_patch_avgpool:
+            x = self.norm(x)
+            # In addition to the patch tokens from the `n` last blocks, we also return
+            # the average-pooled patch tokens from the last block. This is useful for linear eval.
+ output.append(torch.mean(x[:, 1:], dim=1)) + return torch.stack(output, dim=-1) # B x n_patches x dim x n + + +def deit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return model + + +def deit_small(patch_size=16, **kwargs): + depth = kwargs.pop("depth") if "depth" in kwargs else 12 + model = VisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=depth, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/resources/.DS_Store b/resources/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..039236944271ae9bfcc54629b05f46e51106b004 Binary files /dev/null and b/resources/.DS_Store differ diff --git a/resources/0053.jpg b/resources/0053.jpg new file mode 100644 index 0000000000000000000000000000000000000000..82abb9e649639af3906936a10cc894769681d6bc Binary files /dev/null and b/resources/0053.jpg differ diff --git a/resources/0236.jpg b/resources/0236.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b2d5c52031370e67008f36ad86f52a1b0788ef1 Binary files /dev/null and b/resources/0236.jpg differ diff --git a/resources/0239.jpg b/resources/0239.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e309e0338648952704230f8a5968bc4b58549e10 Binary files /dev/null and b/resources/0239.jpg differ diff --git a/resources/0403.jpg b/resources/0403.jpg new file mode 100644 index 0000000000000000000000000000000000000000..693807004eb10b07d01b988815c576bf3ee9fb7b Binary files /dev/null and b/resources/0403.jpg differ diff --git a/resources/0412.jpg b/resources/0412.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f38a782ac2d4c93c41346682171e3208f59a8a4 Binary files /dev/null and b/resources/0412.jpg differ diff --git a/resources/ILSVRC2012_test_00005309.jpg b/resources/ILSVRC2012_test_00005309.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..f3c75f4df56a5d19913c78dbb8f139a9cd03f6e1 Binary files /dev/null and b/resources/ILSVRC2012_test_00005309.jpg differ diff --git a/resources/ILSVRC2012_test_00012622.jpg b/resources/ILSVRC2012_test_00012622.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d3a7481ad690cdc0b77eb92c67c823fb4756452d Binary files /dev/null and b/resources/ILSVRC2012_test_00012622.jpg differ diff --git a/resources/ILSVRC2012_test_00022698.jpg b/resources/ILSVRC2012_test_00022698.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7cb1edc16e45fd5d61d0502ff46f448b43c82b28 Binary files /dev/null and b/resources/ILSVRC2012_test_00022698.jpg differ diff --git a/resources/ILSVRC2012_test_00040725.jpg b/resources/ILSVRC2012_test_00040725.jpg new file mode 100644 index 0000000000000000000000000000000000000000..67c4035e78f6891ddc555d911e82598652bb22db Binary files /dev/null and b/resources/ILSVRC2012_test_00040725.jpg differ diff --git a/resources/ILSVRC2012_test_00075738.jpg b/resources/ILSVRC2012_test_00075738.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7ebdcb5281e963c42c7c23f3013831b765e7005c Binary files /dev/null and b/resources/ILSVRC2012_test_00075738.jpg differ diff --git a/resources/ILSVRC2012_test_00080683.jpg b/resources/ILSVRC2012_test_00080683.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7c4182ed81b25a7efc59c9870cb316a8cede8583 Binary files /dev/null and b/resources/ILSVRC2012_test_00080683.jpg differ diff --git a/resources/ILSVRC2012_test_00085874.jpg b/resources/ILSVRC2012_test_00085874.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4c424066ce390fe186f4757da31ebed0bd2257a Binary files /dev/null and b/resources/ILSVRC2012_test_00085874.jpg differ diff --git a/resources/im052.jpg b/resources/im052.jpg new file mode 100644 index 0000000000000000000000000000000000000000..569b41eda0305ad49b5d9a1f3cc3372c08a771ad Binary files /dev/null and b/resources/im052.jpg differ diff --git a/resources/sun_ainjbonxmervsvpv.jpg b/resources/sun_ainjbonxmervsvpv.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2ba94a5dfdc3dd3c63934de88d58d5625248ad4b Binary files /dev/null and b/resources/sun_ainjbonxmervsvpv.jpg differ diff --git a/resources/sun_alfntqzssslakmss.jpg b/resources/sun_alfntqzssslakmss.jpg new file mode 100644 index 0000000000000000000000000000000000000000..92decf70f2ed58211d2402504e1385fc0f56f879 Binary files /dev/null and b/resources/sun_alfntqzssslakmss.jpg differ diff --git a/resources/sun_amnrcxhisjfrliwa.jpg b/resources/sun_amnrcxhisjfrliwa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..08ad6f5439968c4844cf62432a2123ea943c7ce5 Binary files /dev/null and b/resources/sun_amnrcxhisjfrliwa.jpg differ diff --git a/resources/sun_bvyxpvkouzlfwwod.jpg b/resources/sun_bvyxpvkouzlfwwod.jpg new file mode 100644 index 0000000000000000000000000000000000000000..463012695b73898a8fee6aa9f81805e1fbddc8a8 Binary files /dev/null and b/resources/sun_bvyxpvkouzlfwwod.jpg differ diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5442d5d5d3452f0782acdf98d23fdc0c970c80e3 --- /dev/null +++ b/utils.py @@ -0,0 +1,87 @@ +from argparse import Namespace +from typing import Optional +import torch + + +def get_model( + arch: str, + patch_size: Optional[int] = None, + training_method: Optional[str] = None, + configs: Optional[Namespace] = None, + **kwargs +): + if arch == 
"maskformer": + assert configs is not None + from networks.maskformer.maskformer import MaskFormer + model = MaskFormer( + n_queries=configs.n_queries, + n_decoder_layers=configs.n_decoder_layers, + learnable_pixel_decoder=configs.learnable_pixel_decoder, + lateral_connection=configs.lateral_connection, + return_intermediate=configs.loss_every_decoder_layer, + scale_factor=configs.scale_factor, + abs_2d_pe_init=configs.abs_2d_pe_init, + use_binary_classifier=configs.use_binary_classifier, + arch=configs.arch, + training_method=configs.training_method, + patch_size=configs.patch_size + ) + + for n, p in model.encoder.named_parameters(): + p.requires_grad_(True) + + elif "vit" in arch: + import networks.vision_transformer as vits + import networks.timm_deit as timm_deit + if training_method == "dino": + arch = arch.replace("vit", "deit") if arch.find("small") != -1 else arch + model = vits.__dict__[arch](patch_size=patch_size, num_classes=0) + load_model(model, arch, patch_size) + + elif training_method == "deit": + assert patch_size == 16 + model = timm_deit.deit_small_distilled_patch16_224(True) + + elif training_method == "supervised": + assert patch_size == 16 + state_dict: dict = torch.load( + "/users/gyungin/selfmask/networks/pretrained/deit_small_patch16_224-cd65a155.pth" + )["model"] + for k in list(state_dict.keys()): + if k in ["head.weight", "head.bias"]: # classifier head, which is not used in our network + state_dict.pop(k) + + model = get_model(arch="vit_small", patch_size=16, training_method="dino") + model.load_state_dict(state_dict=state_dict, strict=True) + + else: + raise NotImplementedError + print(f"{arch}_p{patch_size}_{training_method} is built.") + + elif arch == "resnet50": + from networks.resnet import ResNet50 + assert training_method in ["mocov2", "swav", "supervised"] + model = ResNet50(training_method) + + else: + raise ValueError(f"{arch} is not supported arch. Choose from [maskformer, resnet50, vit, dino]") + return model + + +def load_model(model, arch: str, patch_size: int) -> None: + url = None + if arch == "deit_small" and patch_size == 16: + url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" + elif arch == "deit_small" and patch_size == 8: + # model used for visualizations in our paper + url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" + elif arch == "vit_base" and patch_size == 16: + url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" + elif arch == "vit_base" and patch_size == 8: + url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" + if url is not None: + print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") + state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) + model.load_state_dict(state_dict, strict=True) + else: + print("There is no reference weights available for this model => We use random weights.") \ No newline at end of file