ariG23498 HF staff committed on
Commit d2ff88f • 1 Parent(s): 78907b3
app.py CHANGED
@@ -1,18 +1,8 @@
- import git
-
- git_url = "https://github.com/ariG23498/clip_dinoiser.git"
- repo_dir = "clip_dinoiser"
- git.Repo.clone_from(git_url, repo_dir)
-
- import os
-
- print(os.getcwd())
- os.chdir("clip_dinoiser/")
-
  from models.builder import build_model
- from utils.visualization import mask2rgb
+ from visualization import mask2rgb
  from segmentation.datasets import PascalVOCDataset

+ import os
  from hydra import compose, initialize
  from PIL import Image
  import matplotlib.pyplot as plt
@@ -23,6 +13,29 @@ from operator import itemgetter
  import torch
  import warnings

+ warnings.filterwarnings("ignore")
+ initialize(config_path="configs", version_base=None)
+
+ from huggingface_hub import Repository
+
+ repo = Repository(
+     local_dir="models",
+     clone_from="ariG23498/clip-dinoiser",
+     use_auth_token=os.environ.get("token")
+ )
+
+ check_path = 'models/checkpoints/last.pt'
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ check = torch.load(check_path, map_location=device)
+ dinoclip_cfg = "clip_dinoiser.yaml"
+ cfg = compose(config_name=dinoclip_cfg)
+
+ model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
+ model.clip_backbone.decode_head.use_templates = False  # switching off the imagenet templates for fast inference
+ model.load_state_dict(check['model_state_dict'], strict=False)
+ model = model.eval()
+
  import gradio as gr

  def greet(name):
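The hunk above still ends at the stock greet placeholder. A minimal sketch of how the freshly built model could back a segmentation demo instead is given below; the segment function, the 448x448 input size, the upsampling step and the colour-mapping call are illustrative assumptions, not part of this commit.

# Sketch only -- reuses model, device, mask2rgb and PascalVOCDataset from the code above.
import numpy as np
import torchvision.transforms as T

preprocess = T.Compose([T.Resize((448, 448)), T.ToTensor()])  # assumed input size

def segment(image):
    x = preprocess(image.convert("RGB")).unsqueeze(0).to(device)   # 1 x 3 x 448 x 448
    with torch.no_grad():
        logits = model(x)                                          # 1 x num_classes x h x w, patch resolution
    logits = torch.nn.functional.interpolate(logits, size=(448, 448), mode="bilinear", align_corners=False)
    mask = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    return mask2rgb(mask, PascalVOCDataset.PALETTE)                # mask2rgb signature assumed

demo = gr.Interface(fn=segment, inputs=gr.Image(type="pil"), outputs=gr.Image())
demo.launch()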
clip_dinoiser.yaml ADDED
@@ -0,0 +1,39 @@
+ _base_: "default.yml"
+ defaults:
+   - _self_
+
+ seed: 0
+ model_name: clip_dinoiser
+ model:
+   type: CLIP_DINOiser
+   clip_backbone: maskclip
+   mask_th: 0.2
+   in_dim: 256
+   certainty_th: 0.9
+   found_th: 0.5
+   feats_idx: -3
+
+ checkpoint_path: "checkpoints/last.pt"
+ output: logs
+
+ evaluate:
+   eval_only: true
+   task:
+     - voc
+     - voc20
+     - context
+     - context59
+     - coco_stuff
+     - coco_object
+     - cityscapes
+     - ade20k
+
+ # evaluation
+ voc: segmentation/configs/_base_/datasets/pascal_voc12.py
+ voc20: segmentation/configs/_base_/datasets/pascal_voc12_20.py
+ context: segmentation/configs/_base_/datasets/pascal_context.py
+ context59: segmentation/configs/_base_/datasets/pascal_context59.py
+ coco_stuff: segmentation/configs/_base_/datasets/stuff.py
+ coco_object: segmentation/configs/_base_/datasets/coco.py
+ cityscapes: segmentation/configs/_base_/datasets/cityscapes.py
+ ade20k: segmentation/configs/_base_/datasets/ade20k.py
models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .maskclip import *
+ from .clip_dinoiser import *
models/builder.py ADDED
@@ -0,0 +1,16 @@
+ # ------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # author: Monika Wysoczanska
+ # ------------------------------------------------------------------------------
+ # Modified from GroupViT (https://github.com/NVlabs/GroupViT)
+ # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ from mmcv.utils import Registry
+ from omegaconf import OmegaConf
+
+ MODELS = Registry('models')
+
+
+ def build_model(config, class_names):
+     model = MODELS.build(OmegaConf.to_container(config, resolve=True),
+                          default_args={'class_names': class_names})
+     return model
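The Registry/build pattern above is what lets clip_dinoiser.yaml select a model by its type string. A hypothetical usage sketch follows; the inline config merely mirrors the YAML values and the three class names are placeholders.

# Hypothetical usage of MODELS/build_model; config keys mirror clip_dinoiser.yaml.
from omegaconf import OmegaConf
from models.builder import build_model

cfg = OmegaConf.create({
    "type": "CLIP_DINOiser",   # looked up in the MODELS registry
    "clip_backbone": "maskclip",
    "mask_th": 0.2, "in_dim": 256, "certainty_th": 0.9, "found_th": 0.5, "feats_idx": -3,
})
# Registry.build() pops 'type', finds the registered class and calls it with the
# remaining keys plus default_args -- here the class_names list.
model = build_model(cfg, class_names=["background", "aeroplane", "bicycle"])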
models/clip_dinoiser/__init__.py ADDED
@@ -0,0 +1 @@
+ from .clip_dinoiser import *
models/clip_dinoiser/clip_dinoiser.py ADDED
@@ -0,0 +1,120 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology & Oriane Simeoni, valeo.ai
+ # ---------------------------------------------------------------------------------------------------
+ import torch.nn as nn
+ from models.builder import MODELS
+ from models.builder import build_model
+ import torch
+ import torchvision.transforms as T
+ from omegaconf import OmegaConf
+ import torch.nn.functional as F
+
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+
+
+ @MODELS.register_module()
+ class CLIP_DINOiser(nn.Module):
+     def __init__(self, clip_backbone, class_names, mask_th=None, found_th=0.5, certainty_th=0.9, apply_found=False,
+                  in_dim=256, conv_kernel=3, feats_idx=-3):
+
+         super(CLIP_DINOiser, self).__init__()
+         self.mask_th = mask_th
+         self.apply_found = apply_found
+         self.found_th = found_th
+         self.certainty_th = certainty_th
+         self.sigmoid = nn.Sigmoid()
+         maskclip_cfg = OmegaConf.load(f"configs/{clip_backbone}.yaml")
+         self.clip_backbone = build_model(maskclip_cfg["model"], class_names=class_names)
+         self.vit_patch_size = self.clip_backbone.patch_size
+         self.feats_idx = feats_idx
+         self.in_dim = [in_dim]
+         in_size = 768 if self.feats_idx != 'final' else 512
+         self.bkg_decoder = nn.Conv2d(in_size, 1, (1, 1))
+         self.obj_proj = nn.Conv2d(in_size, in_dim, (conv_kernel, conv_kernel),
+                                   padding=conv_kernel // 2, padding_mode='replicate')
+
+         # setup clip feature for training
+         if feats_idx != 'final':
+             train_feats = {}
+             def get_activation(name):
+                 def hook(model, input, output):
+                     train_feats[name] = output.detach()
+                 return hook
+             self.clip_backbone.backbone.layers[feats_idx].ln2.register_forward_hook(get_activation('clip_inter'))
+             self.train_feats = train_feats
+
+
+     def forward_pass(self, x):
+         clip_feats = self.get_clip_map(x)[0]
+         B, c_dim, h, w = clip_feats.shape
+         _, _, H, W = x.shape
+         if self.feats_idx != 'final':
+             clip_feats = self.train_feats['clip_inter']
+             c_dim = clip_feats.shape[-1]
+             clip_feats = clip_feats[:, 1:].permute(0, 2, 1).reshape(B, c_dim, h, w)
+
+         proj_feats = self.obj_proj(clip_feats).reshape(B, self.in_dim[-1], -1)
+         proj_feats = proj_feats / proj_feats.norm(dim=1, keepdim=True)
+
+         corrs = torch.matmul(proj_feats.permute(0, 2, 1), proj_feats).reshape(B, h * w, h, w)
+         output = clip_feats / clip_feats.norm(dim=1, keepdim=True)
+         output = self.bkg_decoder(output)
+
+         return output, corrs
+
+     def forward(self, x):
+         preds, corrs = self.forward_pass(x)
+         output, _, _ = self.get_clip_map(x)
+         B, C, hf, wf = output.shape
+         preds = F.interpolate(preds, (hf, wf), mode="bilinear", align_corners=False)
+
+         # Compute weighted pooling
+         if self.mask_th:
+             corrs[corrs < self.mask_th] = 0.0
+         output = self.compute_weighted_pool(output, corrs)
+         output = output.reshape(B, C, hf, wf)
+         output = self.clip_backbone.decode_head.cls_seg(output)
+
+         if self.apply_found:
+             # Compute FOUND --------------------------------------------------
+             soft_found = self.sigmoid(preds.detach())
+             r_soft_found = soft_found.reshape(-1)
+             nb_cls = output.shape[1]
+             r_hard_found = (r_soft_found > self.found_th).float()
+
+             # TODO: make it work for batch size != 1
+             uncertain = (output.max(dim=1)[0] < self.certainty_th).reshape(-1)
+             output.reshape(1, nb_cls, -1)[:, 0, uncertain & (~r_hard_found.bool())] = 1.0  # background class
+
+         return output
+
+     def predict(self, x):
+         return self(x)
+
+     @torch.no_grad()
+     def get_clip_map(self, img):
+         maskclip_map, feat, k = self.clip_backbone(img, return_feat=True)
+
+         return feat, k, maskclip_map
+
+     @torch.no_grad()
+     def compute_weighted_pool(self, clipmap, corrs):
+         # upsampling
+         B = clipmap.shape[0]
+         h_m, w_m = clipmap.shape[-2:]
+         h_w, w_w = corrs.shape[-2:]
+
+         if (h_m != h_w) or (w_m != w_w):
+             clipmap = F.interpolate(clipmap, (h_w, w_w), mode="bilinear", align_corners=False)
+             h_m, w_m = h_w, w_w
+
+         corrs[corrs < 0.0] = 0.0  # B HW H W
+         clipmap_refined = torch.einsum("bnij, bcij -> bcn", corrs, clipmap)  # B C HW
+         norm_factor = corrs.flatten(-2, -1).sum(dim=-1)[:, None]  # B 1 HW
+         clipmap_refined = clipmap_refined / (norm_factor + 1e-6)
+
+         # reshape back to 2d
+         clipmap_refined = clipmap_refined.reshape(B, -1, h_m, w_m)
+
+         return clipmap_refined
models/maskclip/__init__.py ADDED
@@ -0,0 +1 @@
+ from .maskclip import *
models/maskclip/maskclip.py ADDED
@@ -0,0 +1,221 @@
+ # ------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # author: Monika Wysoczanska, Warsaw University of Technology
+ # ------------------------------------------------------------------------------
+ # Modified from OpenMMLab https://github.com/chongzhou96/MaskCLIP
+ # ------------------------------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from mmseg.ops import resize
+ from typing import Any, List
+ from torch import Tensor
+ from mmcv.utils import print_log
+ from mmseg.utils import get_root_logger
+ from open_clip import get_tokenizer, create_model_from_pretrained
+ from models.builder import MODELS
+ from .vit import VisionTransformer
+ import torchvision.transforms as T
+ from .utils.embed import AdaptivePadding
+ from .utils.prompt_templates import imagenet_templates
+
+ OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+
+ def make_vision_transformer(backbone_cfg):
+     model = VisionTransformer(**backbone_cfg)
+     model.init_weights()
+     return model
+
+
+ @MODELS.register_module()
+ class MaskClip(nn.Module):
+     def __init__(
+         self,
+         backbone,
+         decode_head,
+         clip_model,
+         class_names
+     ):
+         super(MaskClip, self).__init__()
+
+         self.decode_head = eval(decode_head.get('type'))(clip_model, class_names, **decode_head)
+         self.backbone = make_vision_transformer(backbone)
+         self.clip_T = OPENAI_NORMALIZE
+
+         self.to_PIL = T.ToPILImage()
+         self.patch_size = backbone.get('patch_size')
+         self.padding = AdaptivePadding(self.patch_size, self.patch_size)
+
+     def extract_feat(self, inputs: Tensor) -> List[Tensor]:
+         """Extract features from images."""
+         x = self.backbone(inputs)
+         return x
+
+     def forward(self, inputs: Tensor, return_feat=False) -> Tensor:
+         """Encode images with backbone and decode into a semantic segmentation
+         map of the same size as input."""
+         inputs = self.clip_T(inputs)
+         x = self.extract_feat(inputs)
+
+         seg_logits, feats, k = self.decode_head(x, return_feat)
+
+         if return_feat:
+             return seg_logits, feats, k
+         return seg_logits
+
+
+ class MaskClipHead(nn.Module):
+     def __init__(self, clip_model, class_names, visual_projs_path=None, in_index=-1, in_channels=3, norm_cfg=None, channels=0,
+                  text_channels=512, attn_pooling=False, align_corners=False, model_prefix='hf-hub:laion', use_templates=False, **kwargs):
+         super(MaskClipHead, self).__init__()
+
+         self.text_channels = text_channels
+         self.visual_projs_path = visual_projs_path
+         self.clip_model = clip_model
+         self.class_names = class_names
+         self.in_channels = in_channels
+         self.in_index = in_index  # from base decode head default
+         self._init_inputs(in_channels, in_index, None)
+         self.channels = channels
+         self.norm_cfg = norm_cfg
+         self.align_corners = align_corners
+         self.use_templates = use_templates
+
+         self.proj = nn.Conv2d(self.in_channels, text_channels, 1, bias=False)
+         self.load_visual_projs()
+
+         self.attn_pooling = attn_pooling
+         self.tokenizer = get_tokenizer(f'{model_prefix}/{clip_model}')
+         self.hf_modelname = f'{model_prefix}/{clip_model}'
+         model, _ = create_model_from_pretrained(f'{model_prefix}/{clip_model}')
+         model.eval()
+         self.register_buffer("class_embeddings", self._get_class_embeddings(model, class_names))
+
+     @torch.no_grad()
+     def update_vocab(self, class_names):
+         model, _ = create_model_from_pretrained(self.hf_modelname)
+         model.eval()
+         self.class_embeddings = self._get_class_embeddings(model, class_names)
+
+     @torch.no_grad()
+     def _embed_label(self, text_model: torch.nn.Module, label: str) -> torch.Tensor:
+         """
+         Encode a label name into a single vector.
+         """
+         if self.use_templates:
+             templates = imagenet_templates
+         else:
+             templates = ['a photo of an {}' if label[0] in 'aeiou' else 'a photo of a {}']  # "an" before vowel-initial labels
+
+         all_prompts = [self.tokenizer(template.format(label)) for template in templates]
+         out = text_model.encode_text(torch.cat(all_prompts))
+         out /= out.norm(dim=-1, keepdim=True)
+         out = out.mean(dim=0)
+         return out
+
+     def _get_class_embeddings(self, text_model: torch.nn.Module, class_names: List[str]):
+         aug_embeddings = torch.stack([self._embed_label(text_model, label) for label in class_names])
+         # normalize vectors
+         aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
+         return aug_embeddings.squeeze(1)
+
+     def load_visual_projs(self):
+         loaded = torch.load(self.visual_projs_path, map_location='cuda')
+         attrs = ['proj']
+         for attr in attrs:
+             current_attr = getattr(self, attr)
+             state_dict = loaded[attr]
+             for key in state_dict:
+                 if 'weight' in key:
+                     state_dict[key] = state_dict[key][:, :, None, None]
+             current_attr.load_state_dict(state_dict)
+         print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())
+
+     def forward(self, inputs, return_feat=False):
+         x = self._transform_inputs(inputs)
+         q, k, v, cls_token = None, None, None, None
+         if isinstance(x, list) and len(x) == 4:
+             x, q, k, v = x
+         if isinstance(x, list) and len(x) == 2:
+             x, cls_token = x
+         if v is not None:
+             feat = self.proj(v)
+         else:
+             feat = self.proj(x)
+         output = self.cls_seg(feat)
+         if return_feat:
+             return output, feat, k
+
+         return output
+
+     def _init_inputs(self, in_channels, in_index, input_transform):
+         """Check and initialize input transforms.
+
+         The in_channels, in_index and input_transform must match.
+         Specifically, when input_transform is None, only a single feature map
+         will be selected, so in_channels and in_index must be of type int.
+         When input_transform is not None, in_channels and in_index must be a list or tuple of the same length.
+
+         Args:
+             in_channels (int|Sequence[int]): Input channels.
+             in_index (int|Sequence[int]): Input feature index.
+             input_transform (str|None): Transformation type of input features.
+                 Options: 'resize_concat', 'multiple_select', None.
+                 'resize_concat': Multiple feature maps will be resized to the
+                     same size as the first one and then concatenated together.
+                     Usually used in the FCN head of HRNet.
+                 'multiple_select': Multiple feature maps will be bundled into
+                     a list and passed into the decode head.
+                 None: Only one selected feature map is allowed.
+         """
+
+         if input_transform is not None:
+             assert input_transform in ['resize_concat', 'multiple_select']
+         self.input_transform = input_transform
+         self.in_index = in_index
+         if input_transform is not None:
+             assert isinstance(in_channels, (list, tuple))
+             assert isinstance(in_index, (list, tuple))
+             assert len(in_channels) == len(in_index)
+             if input_transform == 'resize_concat':
+                 self.in_channels = sum(in_channels)
+             else:
+                 self.in_channels = in_channels
+         else:
+             assert isinstance(in_channels, int)
+             assert isinstance(in_index, int)
+             self.in_channels = in_channels
+
+     def cls_seg(self, feat):
+         feat = feat / feat.norm(dim=1, keepdim=True)
+         output = F.conv2d(feat, self.class_embeddings[:, :, None, None])
+         output = F.softmax(output * 100, dim=1)  # softmax of similarities with temperature scaling
+
+         return output
+
+     def _transform_inputs(self, inputs):
+         """Transform inputs for the decoder.
+
+         Args:
+             inputs (list[Tensor]): List of multi-level img features.
+
+         Returns:
+             Tensor: The transformed inputs
+         """
+         if self.input_transform == 'resize_concat':
+             inputs = [inputs[i] for i in self.in_index]
+             upsampled_inputs = [
+                 resize(
+                     input=x,
+                     size=inputs[0].shape[2:],
+                     mode='bilinear',
+                     align_corners=self.align_corners) for x in inputs
+             ]
+             inputs = torch.cat(upsampled_inputs, dim=1)
+         elif self.input_transform == 'multiple_select':
+             inputs = [inputs[i] for i in self.in_index]
+         else:
+             inputs = inputs[self.in_index]
+         return inputs
models/maskclip/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .embed import PatchEmbed
+ from .prompt_templates import imagenet_templates
+
+ __all__ = ['PatchEmbed', 'imagenet_templates']
models/maskclip/utils/embed.py ADDED
@@ -0,0 +1,334 @@
1
+ # From OpenMMLab https://github.com/chongzhou96/MaskCLIP
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import math
7
+ from typing import Sequence
8
+
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from mmcv.cnn import build_conv_layer, build_norm_layer
12
+ from mmcv.runner.base_module import BaseModule
13
+ from mmcv.utils import to_2tuple
14
+
15
+
16
+ class AdaptivePadding(nn.Module):
17
+ """Applies padding to input (if needed) so that input can get fully covered
18
+ by filter you specified. It support two modes "same" and "corner". The
19
+ "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
20
+ input. The "corner" mode would pad zero to bottom right.
21
+
22
+ Args:
23
+ kernel_size (int | tuple): Size of the kernel:
24
+ stride (int | tuple): Stride of the filter. Default: 1:
25
+ dilation (int | tuple): Spacing between kernel elements.
26
+ Default: 1.
27
+ padding (str): Support "same" and "corner", "corner" mode
28
+ would pad zero to bottom right, and "same" mode would
29
+ pad zero around input. Default: "corner".
30
+ Example:
31
+ >>> kernel_size = 16
32
+ >>> stride = 16
33
+ >>> dilation = 1
34
+ >>> input = torch.rand(1, 1, 15, 17)
35
+ >>> adap_pad = AdaptivePadding(
36
+ >>> kernel_size=kernel_size,
37
+ >>> stride=stride,
38
+ >>> dilation=dilation,
39
+ >>> padding="corner")
40
+ >>> out = adap_pad(input)
41
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
42
+ >>> input = torch.rand(1, 1, 16, 17)
43
+ >>> out = adap_pad(input)
44
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
45
+ """
46
+
47
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
48
+
49
+ super(AdaptivePadding, self).__init__()
50
+
51
+ assert padding in ('same', 'corner')
52
+
53
+ kernel_size = to_2tuple(kernel_size)
54
+ stride = to_2tuple(stride)
55
+ dilation = to_2tuple(dilation)
56
+
57
+ self.padding = padding
58
+ self.kernel_size = kernel_size
59
+ self.stride = stride
60
+ self.dilation = dilation
61
+
62
+ def get_pad_shape(self, input_shape):
63
+ input_h, input_w = input_shape
64
+ kernel_h, kernel_w = self.kernel_size
65
+ stride_h, stride_w = self.stride
66
+ output_h = math.ceil(input_h / stride_h)
67
+ output_w = math.ceil(input_w / stride_w)
68
+ pad_h = max((output_h - 1) * stride_h +
69
+ (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
70
+ pad_w = max((output_w - 1) * stride_w +
71
+ (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
72
+ return pad_h, pad_w
73
+
74
+ def forward(self, x):
75
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
76
+ if pad_h > 0 or pad_w > 0:
77
+ if self.padding == 'corner':
78
+ x = F.pad(x, [0, pad_w, 0, pad_h])
79
+ elif self.padding == 'same':
80
+ x = F.pad(x, [
81
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
82
+ pad_h - pad_h // 2
83
+ ])
84
+ return x
85
+
86
+
87
+ class PatchEmbed(BaseModule):
88
+ """Image to Patch Embedding.
89
+
90
+ We use a conv layer to implement PatchEmbed.
91
+
92
+ Args:
93
+ in_channels (int): The num of input channels. Default: 3
94
+ embed_dims (int): The dimensions of embedding. Default: 768
95
+ conv_type (str): The config dict for embedding
96
+ conv layer type selection. Default: "Conv2d".
97
+ kernel_size (int): The kernel_size of embedding conv. Default: 16.
98
+ stride (int, optional): The slide stride of embedding conv.
99
+ Default: None (Would be set as `kernel_size`).
100
+ padding (int | tuple | string ): The padding length of
101
+ embedding conv. When it is a string, it means the mode
102
+ of adaptive padding, support "same" and "corner" now.
103
+ Default: "corner".
104
+ dilation (int): The dilation rate of embedding conv. Default: 1.
105
+ bias (bool): Bias of embed conv. Default: True.
106
+ norm_cfg (dict, optional): Config dict for normalization layer.
107
+ Default: None.
108
+ input_size (int | tuple | None): The size of input, which will be
109
+ used to calculate the out size. Only work when `dynamic_size`
110
+ is False. Default: None.
111
+ init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
112
+ Default: None.
113
+ """
114
+
115
+ def __init__(self,
116
+ in_channels=3,
117
+ embed_dims=768,
118
+ conv_type='Conv2d',
119
+ kernel_size=16,
120
+ stride=None,
121
+ padding='corner',
122
+ dilation=1,
123
+ bias=True,
124
+ norm_cfg=None,
125
+ input_size=None,
126
+ init_cfg=None):
127
+ super(PatchEmbed, self).__init__(init_cfg=init_cfg)
128
+
129
+ self.embed_dims = embed_dims
130
+ if stride is None:
131
+ stride = kernel_size
132
+
133
+ kernel_size = to_2tuple(kernel_size)
134
+ stride = to_2tuple(stride)
135
+ dilation = to_2tuple(dilation)
136
+
137
+ if isinstance(padding, str):
138
+ self.adap_padding = AdaptivePadding(
139
+ kernel_size=kernel_size,
140
+ stride=stride,
141
+ dilation=dilation,
142
+ padding=padding)
143
+ # disable the padding of conv
144
+ padding = 0
145
+ else:
146
+ self.adap_padding = None
147
+ padding = to_2tuple(padding)
148
+
149
+ self.projection = build_conv_layer(
150
+ dict(type=conv_type),
151
+ in_channels=in_channels,
152
+ out_channels=embed_dims,
153
+ kernel_size=kernel_size,
154
+ stride=stride,
155
+ padding=padding,
156
+ dilation=dilation,
157
+ bias=bias)
158
+
159
+ if norm_cfg is not None:
160
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
161
+ else:
162
+ self.norm = None
163
+
164
+ if input_size:
165
+ input_size = to_2tuple(input_size)
166
+ # `init_out_size` would be used outside to
167
+ # calculate the num_patches
168
+ # when `use_abs_pos_embed` outside
169
+ self.init_input_size = input_size
170
+ if self.adap_padding:
171
+ pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
172
+ input_h, input_w = input_size
173
+ input_h = input_h + pad_h
174
+ input_w = input_w + pad_w
175
+ input_size = (input_h, input_w)
176
+
177
+ # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
178
+ h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
179
+ (kernel_size[0] - 1) - 1) // stride[0] + 1
180
+ w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
181
+ (kernel_size[1] - 1) - 1) // stride[1] + 1
182
+ self.init_out_size = (h_out, w_out)
183
+ else:
184
+ self.init_input_size = None
185
+ self.init_out_size = None
186
+
187
+ def forward(self, x):
188
+ """
189
+ Args:
190
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
191
+
192
+ Returns:
193
+ tuple: Contains merged results and its spatial shape.
194
+
195
+ - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
196
+ - out_size (tuple[int]): Spatial shape of x, arrange as
197
+ (out_h, out_w).
198
+ """
199
+
200
+ if self.adap_padding:
201
+ x = self.adap_padding(x)
202
+
203
+ x = self.projection(x)
204
+ out_size = (x.shape[2], x.shape[3])
205
+ x = x.flatten(2).transpose(1, 2)
206
+ if self.norm is not None:
207
+ x = self.norm(x)
208
+ return x, out_size
209
+
210
+
211
+ class PatchMerging(BaseModule):
212
+ """Merge patch feature map.
213
+
214
+ This layer groups feature map by kernel_size, and applies norm and linear
215
+ layers to the grouped feature map. Our implementation uses `nn.Unfold` to
216
+ merge patch, which is about 25% faster than original implementation.
217
+ Instead, we need to modify pretrained models for compatibility.
218
+
219
+ Args:
220
+ in_channels (int): The num of input channels.
221
+ out_channels (int): The num of output channels.
222
+ kernel_size (int | tuple, optional): the kernel size in the unfold
223
+ layer. Defaults to 2.
224
+ stride (int | tuple, optional): the stride of the sliding blocks in the
225
+ unfold layer. Default: None. (Would be set as `kernel_size`)
226
+ padding (int | tuple | string ): The padding length of
227
+ embedding conv. When it is a string, it means the mode
228
+ of adaptive padding, support "same" and "corner" now.
229
+ Default: "corner".
230
+ dilation (int | tuple, optional): dilation parameter in the unfold
231
+ layer. Default: 1.
232
+ bias (bool, optional): Whether to add bias in linear layer or not.
233
+ Defaults: False.
234
+ norm_cfg (dict, optional): Config dict for normalization layer.
235
+ Default: dict(type='LN').
236
+ init_cfg (dict, optional): The extra config for initialization.
237
+ Default: None.
238
+ """
239
+
240
+ def __init__(self,
241
+ in_channels,
242
+ out_channels,
243
+ kernel_size=2,
244
+ stride=None,
245
+ padding='corner',
246
+ dilation=1,
247
+ bias=False,
248
+ norm_cfg=dict(type='LN'),
249
+ init_cfg=None):
250
+ super().__init__(init_cfg=init_cfg)
251
+ self.in_channels = in_channels
252
+ self.out_channels = out_channels
253
+ if stride:
254
+ stride = stride
255
+ else:
256
+ stride = kernel_size
257
+
258
+ kernel_size = to_2tuple(kernel_size)
259
+ stride = to_2tuple(stride)
260
+ dilation = to_2tuple(dilation)
261
+
262
+ if isinstance(padding, str):
263
+ self.adap_padding = AdaptivePadding(
264
+ kernel_size=kernel_size,
265
+ stride=stride,
266
+ dilation=dilation,
267
+ padding=padding)
268
+ # disable the padding of unfold
269
+ padding = 0
270
+ else:
271
+ self.adap_padding = None
272
+
273
+ padding = to_2tuple(padding)
274
+ self.sampler = nn.Unfold(
275
+ kernel_size=kernel_size,
276
+ dilation=dilation,
277
+ padding=padding,
278
+ stride=stride)
279
+
280
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
281
+
282
+ if norm_cfg is not None:
283
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
284
+ else:
285
+ self.norm = None
286
+
287
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
288
+
289
+ def forward(self, x, input_size):
290
+ """
291
+ Args:
292
+ x (Tensor): Has shape (B, H*W, C_in).
293
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
294
+ Default: None.
295
+
296
+ Returns:
297
+ tuple: Contains merged results and its spatial shape.
298
+
299
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
300
+ - out_size (tuple[int]): Spatial shape of x, arrange as
301
+ (Merged_H, Merged_W).
302
+ """
303
+ B, L, C = x.shape
304
+ assert isinstance(input_size, Sequence), f'Expect ' \
305
+ f'input_size is ' \
306
+ f'`Sequence` ' \
307
+ f'but get {input_size}'
308
+
309
+ H, W = input_size
310
+ assert L == H * W, 'input feature has wrong size'
311
+
312
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
313
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
314
+ # but need to modify pretrained model for compatibility
315
+
316
+ if self.adap_padding:
317
+ x = self.adap_padding(x)
318
+ H, W = x.shape[-2:]
319
+
320
+ x = self.sampler(x)
321
+ # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
322
+
323
+ out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
324
+ (self.sampler.kernel_size[0] - 1) -
325
+ 1) // self.sampler.stride[0] + 1
326
+ out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
327
+ (self.sampler.kernel_size[1] - 1) -
328
+ 1) // self.sampler.stride[1] + 1
329
+
330
+ output_size = (out_h, out_w)
331
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
332
+ x = self.norm(x) if self.norm else x
333
+ x = self.reduction(x)
334
+ return x, output_size
models/maskclip/utils/prompt_templates.py ADDED
@@ -0,0 +1,82 @@
+ imagenet_templates = [
+     'a bad photo of a {}.',
+     'a photo of many {}.',
+     'a sculpture of a {}.',
+     'a photo of the hard to see {}.',
+     'a low resolution photo of the {}.',
+     'a rendering of a {}.',
+     'graffiti of a {}.',
+     'a bad photo of the {}.',
+     'a cropped photo of the {}.',
+     'a tattoo of a {}.',
+     'the embroidered {}.',
+     'a photo of a hard to see {}.',
+     'a bright photo of a {}.',
+     'a photo of a clean {}.',
+     'a photo of a dirty {}.',
+     'a dark photo of the {}.',
+     'a drawing of a {}.',
+     'a photo of my {}.',
+     'the plastic {}.',
+     'a photo of the cool {}.',
+     'a close-up photo of a {}.',
+     'a black and white photo of the {}.',
+     'a painting of the {}.',
+     'a painting of a {}.',
+     'a pixelated photo of the {}.',
+     'a sculpture of the {}.',
+     'a bright photo of the {}.',
+     'a cropped photo of a {}.',
+     'a plastic {}.',
+     'a photo of the dirty {}.',
+     'a jpeg corrupted photo of a {}.',
+     'a blurry photo of the {}.',
+     'a photo of the {}.',
+     'a good photo of the {}.',
+     'a rendering of the {}.',
+     'a {} in a video game.',
+     'a photo of one {}.',
+     'a doodle of a {}.',
+     'a close-up photo of the {}.',
+     'a photo of a {}.',
+     'the origami {}.',
+     'the {} in a video game.',
+     'a sketch of a {}.',
+     'a doodle of the {}.',
+     'a origami {}.',
+     'a low resolution photo of a {}.',
+     'the toy {}.',
+     'a rendition of the {}.',
+     'a photo of the clean {}.',
+     'a photo of a large {}.',
+     'a rendition of a {}.',
+     'a photo of a nice {}.',
+     'a photo of a weird {}.',
+     'a blurry photo of a {}.',
+     'a cartoon {}.',
+     'art of a {}.',
+     'a sketch of the {}.',
+     'a embroidered {}.',
+     'a pixelated photo of a {}.',
+     'itap of the {}.',
+     'a jpeg corrupted photo of the {}.',
+     'a good photo of a {}.',
+     'a plushie {}.',
+     'a photo of the nice {}.',
+     'a photo of the small {}.',
+     'a photo of the weird {}.',
+     'the cartoon {}.',
+     'art of the {}.',
+     'a drawing of the {}.',
+     'a photo of the large {}.',
+     'a black and white photo of a {}.',
+     'the plushie {}.',
+     'a dark photo of a {}.',
+     'itap of a {}.',
+     'graffiti of the {}.',
+     'a toy {}.',
+     'itap of my {}.',
+     'a photo of a cool {}.',
+     'a photo of a small {}.',
+     'a tattoo of the {}.',
+ ]
models/maskclip/vit.py ADDED
@@ -0,0 +1,470 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ import warnings
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmcv.cnn import build_norm_layer
8
+ from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
9
+ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
10
+ trunc_normal_)
11
+ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
12
+ from torch.nn.modules.batchnorm import _BatchNorm
13
+ from torch.nn.modules.utils import _pair as to_2tuple
14
+ import torch.nn.functional as F
15
+
16
+ from mmseg.ops import resize
17
+ from mmseg.utils import get_root_logger
18
+
19
+ from models.maskclip.utils import PatchEmbed
20
+
21
+
22
+ class TransformerEncoderLayer(BaseModule):
23
+ """Implements one encoder layer in Vision Transformer.
24
+
25
+ Args:
26
+ embed_dims (int): The feature dimension.
27
+ num_heads (int): Parallel attention heads.
28
+ feedforward_channels (int): The hidden dimension for FFNs.
29
+ drop_rate (float): Probability of an element to be zeroed
30
+ after the feed forward layer. Default: 0.0.
31
+ attn_drop_rate (float): The drop out rate for attention layer.
32
+ Default: 0.0.
33
+ drop_path_rate (float): stochastic depth rate. Default 0.0.
34
+ num_fcs (int): The number of fully-connected layers for FFNs.
35
+ Default: 2.
36
+ qkv_bias (bool): enable bias for qkv if True. Default: True
37
+ act_cfg (dict): The activation config for FFNs.
38
+ Default: dict(type='GELU').
39
+ norm_cfg (dict): Config dict for normalization layer.
40
+ Default: dict(type='LN').
41
+ batch_first (bool): Key, Query and Value are shape of
42
+ (batch, n, embed_dim)
43
+ or (n, batch, embed_dim). Default: True.
44
+ """
45
+
46
+ def __init__(self,
47
+ embed_dims,
48
+ num_heads,
49
+ feedforward_channels,
50
+ drop_rate=0.,
51
+ attn_drop_rate=0.,
52
+ drop_path_rate=0.,
53
+ num_fcs=2,
54
+ qkv_bias=True,
55
+ act_cfg=dict(type='GELU'),
56
+ norm_cfg=dict(type='LN'),
57
+ batch_first=True):
58
+ super(TransformerEncoderLayer, self).__init__()
59
+
60
+ self.norm1_name, norm1 = build_norm_layer(
61
+ norm_cfg, embed_dims, postfix=1)
62
+ self.add_module(self.norm1_name, norm1)
63
+
64
+ self.attn = MultiheadAttention(
65
+ embed_dims=embed_dims,
66
+ num_heads=num_heads,
67
+ attn_drop=attn_drop_rate,
68
+ proj_drop=drop_rate,
69
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
70
+ batch_first=batch_first,
71
+ bias=qkv_bias)
72
+
73
+ self.norm2_name, norm2 = build_norm_layer(
74
+ norm_cfg, embed_dims, postfix=2)
75
+ self.add_module(self.norm2_name, norm2)
76
+
77
+ self.ffn = FFN(
78
+ embed_dims=embed_dims,
79
+ feedforward_channels=feedforward_channels,
80
+ num_fcs=num_fcs,
81
+ ffn_drop=drop_rate,
82
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
83
+ act_cfg=act_cfg)
84
+
85
+ @property
86
+ def norm1(self):
87
+ return getattr(self, self.norm1_name)
88
+
89
+ @property
90
+ def norm2(self):
91
+ return getattr(self, self.norm2_name)
92
+
93
+ def forward(self, x, return_qkv=False):
94
+ q, k, v = None, None, None
95
+ if return_qkv:
96
+ y = self.norm1(x)
97
+ y = F.linear(y, self.attn.attn.in_proj_weight, self.attn.attn.in_proj_bias)
98
+ N, L, C = y.shape
99
+ y = y.view(N, L, 3, C // 3).permute(2, 0, 1, 3).reshape(3 * N, L, C // 3)
100
+ y = F.linear(y, self.attn.attn.out_proj.weight, self.attn.attn.out_proj.bias)
101
+ q, k, v = y.tensor_split(3, dim=0)
102
+ v += x
103
+ v = self.ffn(self.norm2(v), identity=v)
104
+
105
+ x = self.attn(self.norm1(x), identity=x)
106
+ x = self.ffn(self.norm2(x), identity=x)
107
+ return x, q, k, v
108
+
109
+
110
+ class VisionTransformer(BaseModule):
111
+ """Vision Transformer.
112
+
113
+ This backbone is the implementation of `An Image is Worth 16x16 Words:
114
+ Transformers for Image Recognition at
115
+ Scale <https://arxiv.org/abs/2010.11929>`_.
116
+
117
+ Args:
118
+ img_size (int | tuple): Input image size. Default: 224.
119
+ patch_size (int): The patch size. Default: 16.
120
+ in_channels (int): Number of input channels. Default: 3.
121
+ embed_dims (int): embedding dimension. Default: 768.
122
+ num_layers (int): depth of transformer. Default: 12.
123
+ num_heads (int): number of attention heads. Default: 12.
124
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
125
+ Default: 4.
126
+ out_indices (list | tuple | int): Output from which stages.
127
+ Default: -1.
128
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
129
+ drop_rate (float): Probability of an element to be zeroed.
130
+ Default 0.0
131
+ attn_drop_rate (float): The drop out rate for attention layer.
132
+ Default 0.0
133
+ drop_path_rate (float): stochastic depth rate. Default 0.0
134
+ with_cls_token (bool): Whether concatenating class token into image
135
+ tokens as transformer input. Default: True.
136
+ output_cls_token (bool): Whether output the cls_token. If set True,
137
+ `with_cls_token` must be True. Default: False.
138
+ norm_cfg (dict): Config dict for normalization layer.
139
+ Default: dict(type='LN')
140
+ act_cfg (dict): The activation config for FFNs.
141
+ Default: dict(type='GELU').
142
+ patch_norm (bool): Whether to add a norm in PatchEmbed Block.
143
+ Default: False.
144
+ final_norm (bool): Whether to add a additional layer to normalize
145
+ final feature map. Default: False.
146
+ interpolate_mode (str): Select the interpolate mode for position
147
+ embeding vector resize. Default: bicubic.
148
+ num_fcs (int): The number of fully-connected layers for FFNs.
149
+ Default: 2.
150
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
151
+ freeze running stats (mean and var). Note: Effect on Batch Norm
152
+ and its variants only. Default: False.
153
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
154
+ some memory while slowing down the training speed. Default: False.
155
+ pretrained (str, optional): model pretrained path. Default: None.
156
+ init_cfg (dict or list[dict], optional): Initialization config dict.
157
+ Default: None.
158
+ """
159
+
160
+ def __init__(self,
161
+ img_size=224,
162
+ patch_size=16,
163
+ patch_bias=True,
164
+ in_channels=3,
165
+ embed_dims=768,
166
+ num_layers=12,
167
+ num_heads=12,
168
+ mlp_ratio=4,
169
+ out_indices=-1,
170
+ qkv_bias=True,
171
+ drop_rate=0.,
172
+ attn_drop_rate=0.,
173
+ drop_path_rate=0.,
174
+ with_cls_token=True,
175
+ output_cls_token=False,
176
+ norm_cfg=dict(type='LN'),
177
+ act_cfg=dict(type='GELU'),
178
+ patch_norm=False,
179
+ pre_norm=False,
180
+ final_norm=False,
181
+ return_qkv=False,
182
+ skip_last_attn=False,
183
+ interpolate_mode='bicubic',
184
+ num_fcs=2,
185
+ norm_eval=False,
186
+ with_cp=False,
187
+ pretrained=None,
188
+ init_cfg=None):
189
+ super(VisionTransformer, self).__init__(init_cfg=init_cfg)
190
+
191
+ if isinstance(img_size, int):
192
+ img_size = to_2tuple(img_size)
193
+ elif isinstance(img_size, tuple):
194
+ if len(img_size) == 1:
195
+ img_size = to_2tuple(img_size[0])
196
+ assert len(img_size) == 2, \
197
+ f'The size of image should have length 1 or 2, ' \
198
+ f'but got {len(img_size)}'
199
+
200
+ if output_cls_token:
201
+ assert with_cls_token is True, f'with_cls_token must be True if' \
202
+ f'set output_cls_token to True, but got {with_cls_token}'
203
+
204
+ assert not (init_cfg and pretrained), \
205
+ 'init_cfg and pretrained cannot be set at the same time'
206
+ if isinstance(pretrained, str):
207
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
208
+ 'please use "init_cfg" instead')
209
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
210
+ elif pretrained is not None:
211
+ raise TypeError('pretrained must be a str or None')
212
+
213
+ self.img_size = img_size
214
+ self.patch_size = patch_size
215
+ self.interpolate_mode = interpolate_mode
216
+ self.norm_eval = norm_eval
217
+ self.with_cp = with_cp
218
+ self.pretrained = pretrained
219
+
220
+ self.patch_embed = PatchEmbed(
221
+ in_channels=in_channels,
222
+ embed_dims=embed_dims,
223
+ conv_type='Conv2d',
224
+ kernel_size=patch_size,
225
+ stride=patch_size,
226
+ padding='corner',
227
+ bias=patch_bias,
228
+ norm_cfg=norm_cfg if patch_norm else None,
229
+ init_cfg=None,
230
+ )
231
+
232
+ num_patches = (img_size[0] // patch_size) * \
233
+ (img_size[1] // patch_size)
234
+
235
+ self.with_cls_token = with_cls_token
236
+ self.output_cls_token = output_cls_token
237
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
238
+ self.pos_embed = nn.Parameter(
239
+ torch.zeros(1, num_patches + 1, embed_dims))
240
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
241
+
242
+ if isinstance(out_indices, int):
243
+ if out_indices == -1:
244
+ out_indices = num_layers - 1
245
+ self.out_indices = [out_indices]
246
+ elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
247
+ self.out_indices = out_indices
248
+ else:
249
+ raise TypeError('out_indices must be type of int, list or tuple')
250
+
251
+ dpr = [
252
+ x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
253
+ ] # stochastic depth decay rule
254
+
255
+ self.layers = ModuleList()
256
+ for i in range(num_layers):
257
+ self.layers.append(
258
+ TransformerEncoderLayer(
259
+ embed_dims=embed_dims,
260
+ num_heads=num_heads,
261
+ feedforward_channels=mlp_ratio * embed_dims,
262
+ attn_drop_rate=attn_drop_rate,
263
+ drop_rate=drop_rate,
264
+ drop_path_rate=dpr[i],
265
+ num_fcs=num_fcs,
266
+ qkv_bias=qkv_bias,
267
+ act_cfg=act_cfg,
268
+ norm_cfg=norm_cfg,
269
+ batch_first=True))
270
+
271
+ self.pre_norm = pre_norm
272
+ if pre_norm:
273
+ self.norm0_name, norm0 = build_norm_layer(
274
+ norm_cfg, embed_dims, postfix=0)
275
+ self.add_module(self.norm0_name, norm0)
276
+
277
+ self.final_norm = final_norm
278
+ if final_norm:
279
+ self.norm1_name, norm1 = build_norm_layer(
280
+ norm_cfg, embed_dims, postfix=1)
281
+ self.add_module(self.norm1_name, norm1)
282
+
283
+ self.return_qkv = [False] * num_layers
284
+ if isinstance(return_qkv, bool):
285
+ for out_i in self.out_indices:
286
+ self.return_qkv[out_i] = return_qkv
287
+ elif isinstance(return_qkv, list) or isinstance(return_qkv, tuple):
288
+ for i, out_i in enumerate(self.out_indices):
289
+ self.return_qkv[out_i] = return_qkv[i]
290
+ else:
291
+ raise TypeError('return_qkv must be type of bool, list or tuple')
292
+
293
+ self.skip_last_attn = skip_last_attn
294
+
295
+ @property
296
+ def norm0(self):
297
+ return getattr(self, self.norm0_name)
298
+
299
+ @property
300
+ def norm1(self):
301
+ return getattr(self, self.norm1_name)
302
+
303
+ def init_weights(self):
304
+ if (isinstance(self.init_cfg, dict)
305
+ and self.init_cfg.get('type') == 'Pretrained'):
306
+ logger = get_root_logger()
307
+
308
+ checkpoint = _load_checkpoint(
309
+ self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
310
+
311
+ if 'state_dict' in checkpoint:
312
+ state_dict = checkpoint['state_dict']
313
+ else:
314
+ state_dict = checkpoint
315
+
316
+ if 'pos_embed' in state_dict.keys():
317
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
318
+ logger.info(msg=f'Resize the pos_embed shape from '
319
+ f'{state_dict["pos_embed"].shape} to '
320
+ f'{self.pos_embed.shape}')
321
+ h, w = self.img_size
322
+ pos_size = int(
323
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
324
+ state_dict['pos_embed'] = self.resize_pos_embed(
325
+ state_dict['pos_embed'],
326
+ (h // self.patch_size, w // self.patch_size),
327
+ (pos_size, pos_size), self.interpolate_mode)
328
+
329
+ print(self.load_state_dict(state_dict, False))
330
+ elif self.init_cfg is not None:
331
+ super(VisionTransformer, self).init_weights()
332
+ else:
333
+ # We only implement the 'jax_impl' initialization implemented at
334
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
335
+ trunc_normal_(self.pos_embed, std=.02)
336
+ trunc_normal_(self.cls_token, std=.02)
337
+ for n, m in self.named_modules():
338
+ if isinstance(m, nn.Linear):
339
+ trunc_normal_(m.weight, std=.02)
340
+ if m.bias is not None:
341
+ if 'ffn' in n:
342
+ nn.init.normal_(m.bias, mean=0., std=1e-6)
343
+ else:
344
+ nn.init.constant_(m.bias, 0)
345
+ elif isinstance(m, nn.Conv2d):
346
+ kaiming_init(m, mode='fan_in', bias=0.)
347
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
348
+ constant_init(m, val=1.0, bias=0.)
349
+
350
+ def _pos_embeding(self, patched_img, hw_shape, pos_embed):
351
+ """Positiong embeding method.
352
+
353
+ Resize the pos_embed, if the input image size doesn't match
354
+ the training size.
355
+ Args:
356
+ patched_img (torch.Tensor): The patched image, it should be
357
+ shape of [B, L1, C].
358
+ hw_shape (tuple): The downsampled image resolution.
359
+ pos_embed (torch.Tensor): The pos_embed weighs, it should be
360
+ shape of [B, L2, c].
361
+ Return:
362
+ torch.Tensor: The pos encoded image feature.
363
+ """
364
+ assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
365
+ 'the shapes of patched_img and pos_embed must be [B, L, C]'
366
+ x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
367
+ if x_len != pos_len:
368
+ if pos_len == (self.img_size[0] // self.patch_size) * (
369
+ self.img_size[1] // self.patch_size) + 1:
370
+ pos_h = self.img_size[0] // self.patch_size
371
+ pos_w = self.img_size[1] // self.patch_size
372
+ else:
373
+ raise ValueError(
374
+ 'Unexpected shape of pos_embed, got {}.'.format(
375
+ pos_embed.shape))
376
+ pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
377
+ (pos_h, pos_w),
378
+ self.interpolate_mode)
379
+ return self.drop_after_pos(patched_img + pos_embed)
380
+
381
+ @staticmethod
382
+ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
383
+ """Resize pos_embed weights.
384
+
385
+ Resize pos_embed using bicubic interpolate method.
386
+ Args:
387
+ pos_embed (torch.Tensor): Position embedding weights.
388
+ input_shpae (tuple): Tuple for (downsampled input image height,
389
+ downsampled input image width).
390
+ pos_shape (tuple): The resolution of downsampled origin training
391
+ image.
392
+ mode (str): Algorithm used for upsampling:
393
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
394
+ ``'trilinear'``. Default: ``'nearest'``
395
+ Return:
396
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
397
+ """
398
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
399
+ pos_h, pos_w = pos_shape
400
+ cls_token_weight = pos_embed[:, 0]
401
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
402
+ pos_embed_weight = pos_embed_weight.reshape(
403
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
404
+ pos_embed_weight = resize(
405
+ pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
406
+ cls_token_weight = cls_token_weight.unsqueeze(1)
407
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
408
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
409
+ return pos_embed
410
+
411
+ def forward(self, inputs):
412
+ B = inputs.shape[0]
413
+
414
+ x, hw_shape = self.patch_embed(inputs)
415
+
416
+ # stole cls_tokens impl from Phil Wang, thanks
417
+ cls_tokens = self.cls_token.expand(B, -1, -1)
418
+ x = torch.cat((cls_tokens, x), dim=1)
419
+ x = self._pos_embeding(x, hw_shape, self.pos_embed)
420
+
421
+ if not self.with_cls_token:
422
+ # Remove class token for transformer encoder input
423
+ x = x[:, 1:]
424
+
425
+ if self.pre_norm:
426
+ x = self.norm0(x)
427
+
428
+ outs = []
429
+ for i, layer in enumerate(self.layers):
430
+ x, q, k, v = layer(x, self.return_qkv[i] \
431
+ or (i == len(self.layers) - 1 and self.skip_last_attn))
432
+ if i == len(self.layers) - 1:
433
+ if self.final_norm:
434
+ x = self.norm1(x)
435
+ if self.return_qkv[i]:
436
+ v = self.norm1(v)
437
+ if self.skip_last_attn:
438
+ if self.with_cls_token:
439
+ x[:, 1:] = v[:, 1:]
440
+ else:
441
+ x = v
442
+ if i in self.out_indices:
443
+ if self.with_cls_token:
444
+ # Remove class token and reshape token for decoder head
445
+ out = x[:, 1:]
446
+ else:
447
+ out = x
448
+ B, _, C = out.shape
449
+ out = out.reshape(B, hw_shape[0], hw_shape[1],
450
+ C).permute(0, 3, 1, 2).contiguous()
451
+ if self.output_cls_token:
452
+ out = [out, x[:, 0]]
453
+ if self.return_qkv[i]:
454
+ if self.with_cls_token:
455
+ q = q[:, 1:]
456
+ k = k[:, 1:]
457
+ v = v[:, 1:]
458
+ v = v.reshape(B, hw_shape[0], hw_shape[1],
459
+ C).permute(0, 3, 1, 2).contiguous()
460
+ out = [out, q, k, v]
461
+ outs.append(out)
462
+
463
+ return tuple(outs)
464
+
465
+ def train(self, mode=True):
466
+ super(VisionTransformer, self).train(mode)
467
+ if mode and self.norm_eval:
468
+ for m in self.modules():
469
+ if isinstance(m, nn.LayerNorm):
470
+ m.eval()
segmentation/configs/_base_/custom_import.py ADDED
@@ -0,0 +1,12 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology
+ # ----------------------------------------------------------------------------------------------------
+ # Modified from TCL
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+
+ custom_imports = dict(
+     imports=["segmentation.datasets.coco_object", "segmentation.datasets.pascal_voc", "datasets.transforms", "segmentation.datasets.pascal_voc20"],
+     allow_failed_imports=False,
+ )
segmentation/configs/_base_/datasets/ade20k.py ADDED
@@ -0,0 +1,58 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology
+ # ----------------------------------------------------------------------------------------------------
+ # Modified from TCL
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ _base_ = ["../custom_import.py"]
+ # dataset settings
+ dataset_type = "ADE20KDataset"
+ data_root = "./data"
+
+ train_pipeline = [
+     dict(type="LoadImageFromFile"),
+     dict(type='ToRGB'),
+     dict(
+         type="MultiScaleFlipAug",
+         img_scale=(2048, 448),
+         flip=True,
+         transforms=[
+             dict(type='LoadImageFromFile'),
+             dict(type='ToRGB'),
+             dict(type='Resize', img_scale=(2048, 448)),
+             dict(type='RandomCrop', crop_size=(448, 448)),
+             dict(type='RandomFlip', prob=0.5),
+             dict(type='PhotoMetricDistortion'),
+             dict(type="ImageToTensorV2", keys=["img"]),
+             dict(type='Collect', keys=['img'], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
+         ],
+     ),
+ ]
+
+ test_pipeline = [
+     dict(type="LoadImageFromFile"),
+     dict(type='ToRGB'),
+     dict(
+         type="MultiScaleFlipAug",
+         img_scale=(2048, 448),
+         flip=False,
+         transforms=[
+             dict(type="Resize", keep_ratio=True),
+             dict(type="RandomFlip"),
+             dict(type="ImageToTensorV2", keys=["img"]),
+             dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
+         ],
+     ),
+ ]
+ data = dict(
+     test=dict(
+         type=dataset_type,
+         data_root=data_root,
+         img_dir="ADEChallengeData2016/images/validation",
+         ann_dir="ADEChallengeData2016/annotations/validation",
+         pipeline=test_pipeline,
+     )
+ )
+
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/cityscapes.py ADDED
@@ -0,0 +1,37 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology
+ # ----------------------------------------------------------------------------------------------------
+ # Modified from TCL
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ _base_ = ["../custom_import.py"]
+ # dataset settings
+ dataset_type = "CityscapesDataset"
+ data_root = "./data/cityscapes"
+ test_pipeline = [
+     dict(type="LoadImageFromFile"),
+     dict(type='ToRGB'),
+     dict(
+         type="MultiScaleFlipAug",
+         img_scale=(2048, 448),
+         flip=False,
+         transforms=[
+             dict(type="Resize", keep_ratio=True),
+             dict(type="RandomFlip"),
+             dict(type="ImageToTensorV2", keys=["img"]),
+             dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
+         ],
+     ),
+ ]
+ data = dict(
+     test=dict(
+         type=dataset_type,
+         data_root=data_root,
+         img_dir="leftImg8bit/val",
+         ann_dir="gtFine/val",
+         pipeline=test_pipeline,
+     )
+ )
+
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/coco.py ADDED
@@ -0,0 +1,39 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology
+ # ----------------------------------------------------------------------------------------------------
+ # Modified from TCL
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ _base_ = ["../custom_import.py"]
+ # dataset settings
+ dataset_type = "COCOObjectDataset"
+ data_root = "./data/coco_stuff164k"
+
+ test_pipeline = [
+     dict(type="LoadImageFromFile"),
+     dict(type='ToRGB'),
+     dict(
+         type="MultiScaleFlipAug",
+         img_scale=(2048, 448),
+         flip=False,
+         transforms=[
+             dict(type="Resize", keep_ratio=True),
+             dict(type="RandomFlip"),
+             dict(type="ImageToTensorV2", keys=["img"]),
+             dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
+         ],
+     ),
+ ]
+ data = dict(
+
+     test=dict(
+         type=dataset_type,
+         data_root=data_root,
+         img_dir="images/val2017",
+         ann_dir="annotations/val2017",
+         pipeline=test_pipeline,
+     )
+ )
+
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_context.py ADDED
@@ -0,0 +1,38 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology
+ # ----------------------------------------------------------------------------------------------------
+ # Modified from TCL
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ _base_ = ["../custom_import.py"]
+ # dataset settings
+ dataset_type = "PascalContextDataset"
+ data_root = "./data/VOCdevkit/VOC2010"
+ test_pipeline = [
+     dict(type="LoadImageFromFile"),
+     dict(type='ToRGB'),
+     dict(
+         type="MultiScaleFlipAug",
+         img_scale=(2048, 448),
+         flip=False,
+         transforms=[
+             dict(type="Resize", keep_ratio=True),
+             dict(type="RandomFlip"),
+             dict(type="ImageToTensorV2", keys=["img"]),
+             dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
+         ],
+     ),
+ ]
+ data = dict(
+     test=dict(
+         type=dataset_type,
+         data_root=data_root,
+         img_dir="JPEGImages",
+         ann_dir="SegmentationClassContext",
+         split="ImageSets/SegmentationContext/val.txt",
+         pipeline=test_pipeline,
+     )
+ )
+
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_context59.py ADDED
@@ -0,0 +1,38 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # Modified from TCL
6
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+ # dataset settings
9
+ dataset_type = "PascalContextDataset59"
10
+ data_root = "./data/VOCdevkit/VOC2010"
11
+ test_pipeline = [
12
+ dict(type="LoadImageFromFile"),
13
+ dict(type='ToRGB'),
14
+ dict(
15
+ type="MultiScaleFlipAug",
16
+ img_scale=(2048, 448),
17
+ flip=False,
18
+ transforms=[
19
+ dict(type="Resize", keep_ratio=True),
20
+ dict(type="RandomFlip"),
21
+ dict(type="ImageToTensorV2", keys=["img"]),
22
+ dict(type="Collect", keys=["img"],
23
+ meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
24
+ ],
25
+ ),
26
+ ]
27
+ data = dict(
28
+ test=dict(
29
+ type=dataset_type,
30
+ data_root=data_root,
31
+ img_dir="JPEGImages",
32
+ ann_dir="SegmentationClassContext",
33
+ split="ImageSets/SegmentationContext/val.txt",
34
+ pipeline=test_pipeline,
35
+ )
36
+ )
37
+
38
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_voc12.py ADDED
@@ -0,0 +1,40 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # Modified from TCL
6
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+ _base_ = ["../custom_import.py"]
9
+ # dataset settings
10
+ dataset_type = "PascalVOCDataset"
11
+ data_root = "./data/VOCdevkit/VOC2012"
12
+
13
+ test_pipeline = [
14
+ dict(type="LoadImageFromFile"),
15
+ dict(type='ToRGB'),
16
+ dict(
17
+ type="MultiScaleFlipAug",
18
+ img_scale=(2048, 448),
19
+ flip=False,
20
+ transforms=[
21
+ dict(type="Resize", keep_ratio=True),
22
+ dict(type="RandomFlip"),
23
+ dict(type="ImageToTensorV2", keys=["img"]),
24
+ dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
25
+ ],
26
+ ),
27
+ ]
28
+ data = dict(
29
+
30
+ test=dict(
31
+ type=dataset_type,
32
+ data_root=data_root,
33
+ img_dir="JPEGImages",
34
+ ann_dir="SegmentationClass",
35
+ split="ImageSets/Segmentation/val.txt",
36
+ pipeline=test_pipeline,
37
+ )
38
+ )
39
+
40
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_voc12_20.py ADDED
@@ -0,0 +1,40 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # Modified from GroupViT (https://github.com/NVlabs/GroupViT)
6
+ # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+ _base_ = ["../custom_import.py"]
9
+ # dataset settings
10
+ dataset_type = "PascalVOCDataset20"
11
+ data_root = "./data/VOCdevkit/VOC2012"
12
+ test_pipeline = [
13
+ dict(type="LoadImageFromFile"),
14
+ dict(type='ToRGB'),
15
+ dict(
16
+ type="MultiScaleFlipAug",
17
+ img_scale=(2048, 448),
18
+ flip=False,
19
+ transforms=[
20
+ dict(type="Resize", keep_ratio=True),
21
+ dict(type="RandomFlip"),
22
+ dict(type="ImageToTensorV2", keys=["img"]),
23
+ dict(type="Collect", keys=["img"],
24
+ meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
25
+ ],
26
+ ),
27
+ ]
28
+ data = dict(
29
+ test=dict(
30
+ type=dataset_type,
31
+ data_root=data_root,
32
+ img_dir="JPEGImages",
33
+ ann_dir="SegmentationClass",
34
+ split="ImageSets/Segmentation/val.txt",
35
+ pipeline=test_pipeline,
36
+
37
+ )
38
+ )
39
+
40
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/stuff.py ADDED
@@ -0,0 +1,39 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # Modified from TCL
6
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+ _base_ = ["../custom_import.py"]
9
+ # dataset settings
10
+ dataset_type = "COCOStuffDataset"
11
+ data_root = "./data/coco_stuff164k"
12
+
13
+ test_pipeline = [
14
+ dict(type="LoadImageFromFile"),
15
+ dict(type='ToRGB'),
16
+ dict(
17
+ type="MultiScaleFlipAug",
18
+ img_scale=(2048, 448),
19
+ flip=False,
20
+ transforms=[
21
+ dict(type="Resize", keep_ratio=True),
22
+ dict(type="RandomFlip"),
23
+ dict(type="ImageToTensorV2", keys=["img"]),
24
+ dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
25
+ ],
26
+ ),
27
+ ]
28
+ data = dict(
29
+
30
+ test=dict(
31
+ type=dataset_type,
32
+ data_root=data_root,
33
+ img_dir="images/val2017",
34
+ ann_dir="annotations/val2017",
35
+ pipeline=test_pipeline,
36
+ )
37
+ )
38
+
39
+ test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/datasets/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .coco_object import *
2
+ from .pascal_voc import *
3
+ from .pascal_voc20 import *
4
+ from .pascal_context import *
5
+ from .coco_stuff import *
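This __init__ is what wires the dataset classes below into mmseg: importing the package runs their @DATASETS.register_module() decorators, so the string type names used in the configs above ("COCOObjectDataset", "PascalVOCDataset20", ...) resolve at build time. A quick sketch, assuming mmsegmentation is installed and the repository root is on the Python path:

    from mmseg.datasets import DATASETS
    import segmentation.datasets  # noqa: F401  (side effect: registers the custom datasets)

    print(DATASETS.get("PascalVOCDataset20"))   # the class defined in this commit
    print(DATASETS.get("COCOObjectDataset"))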
segmentation/datasets/coco_object.py ADDED
@@ -0,0 +1,42 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # Modified from TCL
6
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+ from mmseg.datasets import DATASETS, CustomDataset
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class COCOObjectDataset(CustomDataset):
13
+ """COCO-Object dataset.
14
+
15
+ 1 bg class + first 80 classes from the COCO-Stuff dataset.
16
+ """
17
+
18
+ CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'aeroplane', 'bus', 'train', 'truck', 'boat',
19
+ 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
20
+ 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
21
+ 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
22
+ 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
23
+ 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
24
+ 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
25
+ 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
26
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
27
+
28
+ PALETTE = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
29
+ [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
30
+ [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
31
+ [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
32
+ [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
33
+ [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
34
+ [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
35
+ [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
36
+ [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
37
+ [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
38
+ [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
39
+ [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]
40
+
41
+ def __init__(self, **kwargs):
42
+ super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)
segmentation/datasets/coco_stuff.py ADDED
@@ -0,0 +1,97 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+
5
+
6
+ from mmseg.datasets import DATASETS, CustomDataset
7
+
8
+
9
+ @DATASETS.register_module(force=True)
10
+ class COCOStuffDataset(CustomDataset):
11
+ """COCO-Stuff dataset.
12
+
13
+ In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version
14
+ are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff
15
+ 164k is from 0 to 170, where 255 is the ignore index. So, they are all 171
16
+ semantic categories. ``reduce_zero_label`` is set to True and False for the
17
+ 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg',
18
+ and ``seg_map_suffix`` is fixed to '.png'.
19
+ """
20
+ CLASSES = (
21
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
22
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
23
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
24
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
25
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
26
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
27
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
28
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
29
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
30
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
31
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
32
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
33
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
34
+ 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
35
+ 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
36
+ 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
37
+ 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
38
+ 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',
39
+ 'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',
40
+ 'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',
41
+ 'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
42
+ 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
43
+ 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
44
+ 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
45
+ 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
46
+ 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
47
+ 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
48
+ 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
49
+ 'window-blind', 'window-other', 'wood')
50
+
51
+ PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
52
+ [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
53
+ [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
54
+ [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
55
+ [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
56
+ [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
57
+ [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
58
+ [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
59
+ [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
60
+ [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
61
+ [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
62
+ [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
63
+ [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
64
+ [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
65
+ [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
66
+ [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
67
+ [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
68
+ [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
69
+ [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
70
+ [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
71
+ [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
72
+ [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
73
+ [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
74
+ [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
75
+ [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
76
+ [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
77
+ [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
78
+ [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
79
+ [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
80
+ [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
81
+ [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
82
+ [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
83
+ [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
84
+ [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
85
+ [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
86
+ [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
87
+ [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
88
+ [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
89
+ [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
90
+ [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
91
+ [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
92
+ [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
93
+ [64, 192, 96], [64, 160, 64], [64, 64, 0]]
94
+
95
+ def __init__(self, **kwargs):
96
+ super(COCOStuffDataset, self).__init__(
97
+ img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)
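The ``reduce_zero_label`` flag mentioned in the docstring shifts labels so that raw id 0 becomes the ignore index. A tiny illustration of the mapping mmseg applies when the flag is True (the values below are hypothetical and only meant to show the shift):

    import numpy as np

    raw = np.array([0, 1, 2, 171])          # raw ids as stored in the 10k annotations
    reduced = raw.astype(np.int64) - 1      # shift every id down by one
    reduced[raw == 0] = 255                 # raw 0 (ignore) maps to 255, mmseg's ignore index
    print(reduced)                          # [255, 0, 1, 170]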
segmentation/datasets/pascal_context.py ADDED
@@ -0,0 +1,108 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # from MaskCLIP
6
+ # Copyright (c) OpenMMLab. All rights reserved.
7
+ # ------------------------------------------------------------------------------
8
+ from mmseg.datasets import DATASETS, CustomDataset
9
+ import os.path as osp
10
+
11
+
12
+ @DATASETS.register_module(force=True)
13
+ class PascalContextDataset(CustomDataset):
14
+ """PascalContext dataset.
15
+
16
+ In segmentation map annotation for PascalContext, 0 stands for background,
17
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
18
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
19
+ fixed to '.png'.
20
+
21
+ Args:
22
+ split (str): Split txt file for PascalContext.
23
+ """
24
+
25
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
26
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
27
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
28
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
29
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
30
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
31
+ 'plate', 'platform', 'potted plant', 'road', 'rock', 'sheep',
32
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
33
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
34
+ 'window', 'wood')
35
+
36
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
37
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
38
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
39
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
40
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
41
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
42
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
43
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
44
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
45
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
46
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
47
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
48
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
49
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
50
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
51
+
52
+ def __init__(self, split, **kwargs):
53
+ super(PascalContextDataset, self).__init__(
54
+ img_suffix='.jpg',
55
+ seg_map_suffix='.png',
56
+ split=split,
57
+ reduce_zero_label=False,
58
+ **kwargs)
59
+ assert osp.exists(self.img_dir) and self.split is not None
60
+
61
+
62
+ @DATASETS.register_module(force=True)
63
+ class PascalContextDataset59(CustomDataset):
64
+ """PascalContext dataset.
65
+
66
+ In segmentation map annotation for PascalContext59, the background class is
67
+ not included in the 59 categories. ``reduce_zero_label`` is fixed to
68
+ True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
69
+ fixed to '.png'.
70
+
71
+ Args:
72
+ split (str): Split txt file for PascalContext.
73
+ """
74
+
75
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
76
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
77
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
78
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
79
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
80
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
81
+ 'potted plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
82
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
83
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
84
+
85
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
86
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
87
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
88
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
89
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
90
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
91
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
92
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
93
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
94
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
95
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
96
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
97
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
98
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
99
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
100
+
101
+ def __init__(self, split, **kwargs):
102
+ super(PascalContextDataset59, self).__init__(
103
+ img_suffix='.jpg',
104
+ seg_map_suffix='.png',
105
+ split=split,
106
+ reduce_zero_label=True,
107
+ **kwargs)
108
+ assert osp.exists(self.img_dir) and self.split is not None
segmentation/datasets/pascal_voc.py ADDED
@@ -0,0 +1,40 @@
1
+ # ------------------------------------------------------------------------------
2
+ # TCL
3
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------
5
+ # Modified from GroupViT (https://github.com/NVlabs/GroupViT)
6
+ # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+ import os
9
+ from mmseg.datasets import DATASETS
10
+ from mmseg.datasets import CustomDataset
11
+
12
+
13
+ @DATASETS.register_module(force=True)
14
+ class PascalVOCDataset(CustomDataset):
15
+ """Pascal VOC dataset (the background class is ignored).
16
+ Borrowed from MaskCLIP
17
+
18
+ Args:
19
+ split (str): Split txt file for Pascal VOC.
20
+ """
21
+
22
+ CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
23
+ 'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog',
24
+ 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa',
25
+ 'train', 'tvmonitor')
26
+
27
+ PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
28
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
29
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
30
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
31
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
32
+
33
+ def __init__(self, split, **kwargs):
34
+ super(PascalVOCDataset, self).__init__(
35
+ img_suffix='.jpg',
36
+ seg_map_suffix='.png',
37
+ split=split,
38
+ reduce_zero_label=False,
39
+ **kwargs)
40
+ assert os.path.exists(self.img_dir) and self.split is not None
segmentation/datasets/pascal_voc20.py ADDED
@@ -0,0 +1,40 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ----------------------------------------------------------------------------------------------------
5
+ # Modified from TCL
6
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
7
+ # ------------------------------------------------------------------------------
8
+
9
+ import os.path as osp
10
+ from mmseg.datasets import DATASETS
11
+ from mmseg.datasets import CustomDataset
12
+
13
+
14
+ @DATASETS.register_module()
15
+ class PascalVOCDataset20(CustomDataset):
16
+ """Pascal VOC dataset (the background class is ignored).
17
+
18
+ Args:
19
+ split (str): Split txt file for Pascal VOC.
20
+ """
21
+
22
+ CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
23
+ 'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog',
24
+ 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa',
25
+ 'train', 'tvmonitor')
26
+
27
+ PALETTE = [[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
28
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
29
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
30
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
31
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
32
+
33
+ def __init__(self, split, **kwargs):
34
+ super(PascalVOCDataset20, self).__init__(
35
+ img_suffix='.jpg',
36
+ seg_map_suffix='.png',
37
+ split=split,
38
+ reduce_zero_label=True,
39
+ **kwargs)
40
+ assert osp.exists(self.img_dir) and self.split is not None
segmentation/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .builder import build_seg_dataloader, build_seg_dataset, build_seg_inference
segmentation/evaluation/builder.py ADDED
@@ -0,0 +1,66 @@
1
+ # ---------------------------------------------------------------------------------------------------
2
+ # CLIP-DINOiser
3
+ # authors: Monika Wysoczanska, Warsaw University of Technology
4
+ # ---------------------------------------------------------------------------------------------------
5
+ # modified from TCL
6
+ # Copyright (c) 2023 Kakao Brain. All Rights Reserved.
7
+ # ---------------------------------------------------------------------------------------------------
8
+
9
+ import mmcv
10
+ from mmseg.datasets import build_dataloader, build_dataset
11
+ from mmcv.utils import Registry
12
+ from mmcv.cnn import MODELS as MMCV_MODELS
13
+ MODELS = Registry('models', parent=MMCV_MODELS)
14
+ SEGMENTORS = MODELS
15
+ from .clip_dinoiser_eval import DinoCLIP_Infrencer
16
+
17
+
18
+ def build_seg_dataset(config):
19
+ """Build a dataset from config."""
20
+ cfg = mmcv.Config.fromfile(config)
21
+ dataset = build_dataset(cfg.data.test)
22
+ return dataset
23
+
24
+
25
+ def build_seg_dataloader(dataset, dist=True):
26
+ # batch size is set to 1 to handle varying image size (due to different aspect ratio)
27
+ if dist:
28
+ data_loader = build_dataloader(
29
+ dataset,
30
+ samples_per_gpu=1,
31
+ workers_per_gpu=2,
32
+ dist=dist,
33
+ shuffle=False,
34
+ persistent_workers=True,
35
+ pin_memory=False,
36
+ )
37
+ else:
38
+ data_loader = build_dataloader(
39
+ dataset=dataset,
40
+ samples_per_gpu=1,
41
+ workers_per_gpu=2,
42
+ dist=dist,
43
+ shuffle=False,
44
+ persistent_workers=True,
45
+ pin_memory=False,
46
+ )
47
+ return data_loader
48
+
49
+
50
+ def build_seg_inference(
51
+ model,
52
+ dataset,
53
+ config,
54
+ seg_config,
55
+ ):
56
+ dset_cfg = mmcv.Config.fromfile(seg_config) # dataset config
57
+ classnames = dataset.CLASSES
58
+ kwargs = dict()
59
+ if hasattr(dset_cfg, "test_cfg"):
60
+ kwargs["test_cfg"] = dset_cfg.test_cfg
61
+
62
+ seg_model = DinoCLIP_Infrencer(model, num_classes=len(classnames), **kwargs, **config.evaluate)
63
+ seg_model.CLASSES = dataset.CLASSES
64
+ seg_model.PALETTE = dataset.PALETTE
65
+
66
+ return seg_model
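A hedged end-to-end sketch of how these three helpers are meant to be chained for single-GPU evaluation. Here ``model`` is a CLIP-DINOiser instance and ``cfg`` a Hydra config with an ``evaluate`` section, both assumed to be built elsewhere; exact result handling may differ slightly across mmseg versions:

    from mmcv.parallel import MMDataParallel
    from mmseg.apis import single_gpu_test
    from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference

    # `model` (CLIP-DINOiser) and `cfg` (Hydra config) are assumed to exist already.
    voc_cfg = "segmentation/configs/_base_/datasets/pascal_voc12.py"
    dataset = build_seg_dataset(voc_cfg)                 # PascalVOCDataset with the test pipeline
    loader = build_seg_dataloader(dataset, dist=False)   # batch size 1, see comment above
    seg_model = build_seg_inference(model, dataset, cfg, voc_cfg)

    results = single_gpu_test(MMDataParallel(seg_model, device_ids=[0]), loader, pre_eval=True)
    print(dataset.evaluate(results, metric="mIoU"))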
segmentation/evaluation/clip_dinoiser_eval.py ADDED
@@ -0,0 +1,34 @@
1
+ # ------------------------------------------------------------------------------
2
+ import torch
3
+ import logging
4
+ log = logging.getLogger(__name__)
5
+ from mmseg.ops import resize
6
+ from mmseg.models import EncoderDecoder
7
+
8
+ class DinoCLIP_Infrencer(EncoderDecoder):
9
+ def __init__(
10
+ self,
11
+ model,
12
+ num_classes,
13
+ test_cfg=dict(),
14
+ **kwargs,
15
+ ):
16
+ super(EncoderDecoder, self).__init__()
17
+ self.mode = test_cfg['mode']
18
+ self.num_classes = num_classes
19
+ self.model = model
20
+ self.test_cfg = test_cfg
21
+ self.align_corners = False
22
+
23
+ @torch.no_grad()
24
+ def encode_decode(self, img, meta_data):
25
+ """
26
+ """
27
+ masks = self.model(img)
28
+ masks = resize(
29
+ input=masks,
30
+ size=img.shape[-2:],
31
+ mode='bilinear',
32
+ align_corners=self.align_corners)
33
+ return masks
34
+
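For reference, a minimal sketch of what this wrapper provides: mmseg's sliding-window inference (inherited from EncoderDecoder) calls encode_decode on each crop, which here just forwards the crop through the wrapped CLIP-DINOiser model and upsamples the per-class masks. The shapes below are assumptions based on the crop_size used in the configs of this commit:

    import torch

    # `seg_model` is assumed to be a DinoCLIP_Infrencer built via build_seg_inference above.
    crop = torch.randn(1, 3, 448, 448)                     # one crop, matching crop_size=(448, 448)
    masks = seg_model.encode_decode(crop, meta_data=None)  # (1, num_classes, 448, 448) soft masks
    pred = masks.argmax(dim=1)                             # per-pixel class indices for the crop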
visualization.py ADDED
@@ -0,0 +1,7 @@
1
+ import numpy as np
2
+
3
+ def mask2rgb(mask, palette):
4
+ img = np.zeros((mask.shape[0], mask.shape[1], 3))
5
+ for l in np.unique(mask):
6
+ img[mask == int(l)] = palette[int(l)]
7
+ return img.astype(int)
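A small usage sketch for mask2rgb, pairing it with the PascalVOCDataset palette defined earlier in this commit (the predicted mask here is random, purely for illustration):

    import numpy as np
    import matplotlib.pyplot as plt
    from segmentation.datasets import PascalVOCDataset
    from visualization import mask2rgb

    pred = np.random.randint(0, len(PascalVOCDataset.CLASSES), size=(448, 448))
    rgb = mask2rgb(pred, PascalVOCDataset.PALETTE)   # (448, 448, 3) array of 0-255 ints
    plt.imshow(rgb)
    plt.axis("off")
    plt.show()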