wjf5203 committed on
Commit
2aac0e2
1 Parent(s): acefb81

Add application file

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .DS_Store +0 -0
  2. GLEE/.DS_Store +0 -0
  3. GLEE/clip_vit_base_patch32/config.json +157 -0
  4. GLEE/clip_vit_base_patch32/merges.txt +0 -0
  5. GLEE/clip_vit_base_patch32/preprocessor_config.json +19 -0
  6. GLEE/clip_vit_base_patch32/pytorch_model.bin +3 -0
  7. GLEE/clip_vit_base_patch32/special_tokens_map.json +1 -0
  8. GLEE/clip_vit_base_patch32/tokenizer.json +0 -0
  9. GLEE/clip_vit_base_patch32/tokenizer_config.json +1 -0
  10. GLEE/clip_vit_base_patch32/vocab.json +0 -0
  11. GLEE/configs/R50.yaml +71 -0
  12. GLEE/configs/SwinL.yaml +79 -0
  13. GLEE/glee/.DS_Store +0 -0
  14. GLEE/glee/__init__.py +12 -0
  15. GLEE/glee/backbone/__init__.py +7 -0
  16. GLEE/glee/backbone/backbone.py +51 -0
  17. GLEE/glee/backbone/build.py +11 -0
  18. GLEE/glee/backbone/davit.py +623 -0
  19. GLEE/glee/backbone/eva01.py +676 -0
  20. GLEE/glee/backbone/eva02-dino.py +598 -0
  21. GLEE/glee/backbone/eva02.py +647 -0
  22. GLEE/glee/backbone/eva_01_utils.py +222 -0
  23. GLEE/glee/backbone/eva_02_utils.py +356 -0
  24. GLEE/glee/backbone/internimage.py +737 -0
  25. GLEE/glee/backbone/registry.py +14 -0
  26. GLEE/glee/backbone/resnet.py +731 -0
  27. GLEE/glee/backbone/swin.py +783 -0
  28. GLEE/glee/backbone/vit.py +472 -0
  29. GLEE/glee/backbone/vit_utils.py +222 -0
  30. GLEE/glee/config.py +387 -0
  31. GLEE/glee/config_deeplab.py +28 -0
  32. GLEE/glee/models/.DS_Store +0 -0
  33. GLEE/glee/models/glee_model.py +296 -0
  34. GLEE/glee/models/pixel_decoder/__init__.py +1 -0
  35. GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-38.pyc +0 -0
  36. GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-39.pyc +0 -0
  37. GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-38.pyc +0 -0
  38. GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-39.pyc +0 -0
  39. GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-38.pyc +0 -0
  40. GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-39.pyc +0 -0
  41. GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-38.pyc +0 -0
  42. GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-39.pyc +0 -0
  43. GLEE/glee/models/pixel_decoder/early_fusion.py +230 -0
  44. GLEE/glee/models/pixel_decoder/maskdino_encoder.py +463 -0
  45. GLEE/glee/models/pixel_decoder/ops/functions/__init__.py +13 -0
  46. GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-38.pyc +0 -0
  47. GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-39.pyc +0 -0
  48. GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc +0 -0
  49. GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc +0 -0
  50. GLEE/glee/models/pixel_decoder/ops/functions/ms_deform_attn_func.py +72 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
GLEE/.DS_Store ADDED
Binary file (6.15 kB).
 
GLEE/clip_vit_base_patch32/config.json ADDED
@@ -0,0 +1,157 @@
+ {
+   "_name_or_path": "openai/clip-vit-base-patch32",
+   "architectures": [
+     "CLIPModel"
+   ],
+   "initializer_factor": 1.0,
+   "logit_scale_init_value": 2.6592,
+   "model_type": "clip",
+   "projection_dim": 512,
+   "text_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "bos_token_id": 0,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dropout": 0.0,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 2,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "quick_gelu",
+     "hidden_size": 512,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_factor": 1.0,
+     "initializer_range": 0.02,
+     "intermediate_size": 2048,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 77,
+     "min_length": 0,
+     "model_type": "clip_text_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 8,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 1,
+     "prefix": null,
+     "projection_dim": 512,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.16.0.dev0",
+     "use_bfloat16": false,
+     "vocab_size": 49408
+   },
+   "text_config_dict": null,
+   "transformers_version": null,
+   "vision_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dropout": 0.0,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "quick_gelu",
+     "hidden_size": 768,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "image_size": 224,
+     "initializer_factor": 1.0,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "clip_vision_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 12,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "patch_size": 32,
+     "prefix": null,
+     "projection_dim": 512,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.16.0.dev0",
+     "use_bfloat16": false
+   },
+   "vision_config_dict": null
+ }
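Note: the files vendored under GLEE/clip_vit_base_patch32 (config, tokenizer files, preprocessor config, pytorch_model.bin) look like a standard Hugging Face checkpoint of openai/clip-vit-base-patch32. A minimal sketch of loading this local copy with transformers, assuming the path is relative to the repo root:

from transformers import CLIPModel, CLIPTokenizer

local_dir = "GLEE/clip_vit_base_patch32"              # vendored checkpoint added by this commit
model = CLIPModel.from_pretrained(local_dir)          # reads config.json + pytorch_model.bin
tokenizer = CLIPTokenizer.from_pretrained(local_dir)  # reads vocab.json, merges.txt, tokenizer_config.json

tokens = tokenizer(["a photo of a cat"], return_tensors="pt", padding=True)
text_features = model.get_text_features(**tokens)     # shape (1, 512), matching "projection_dim": 512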
GLEE/clip_vit_base_patch32/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
GLEE/clip_vit_base_patch32/preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "crop_size": 224,
+   "do_center_crop": true,
+   "do_normalize": true,
+   "do_resize": true,
+   "feature_extractor_type": "CLIPFeatureExtractor",
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "size": 224
+ }
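This is the usual CLIP preprocessing: bicubic resize to 224 ("resample": 3 corresponds to PIL's BICUBIC), center crop to 224, then normalization with the CLIP mean/std. A quick sketch, assuming standard CLIPFeatureExtractor semantics from transformers:

from PIL import Image
from transformers import CLIPFeatureExtractor

extractor = CLIPFeatureExtractor.from_pretrained("GLEE/clip_vit_base_patch32")
batch = extractor(images=Image.new("RGB", (640, 480)), return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])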
GLEE/clip_vit_base_patch32/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a63082132ba4f97a80bea76823f544493bffa8082296d62d71581a4feff1576f
+ size 605247071
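The weight file itself is stored via Git LFS; only the pointer above (oid and size) lives in git. Once the ~605 MB blob has been pulled, it can be checked against the pointer, for example:

import hashlib

expected = "a63082132ba4f97a80bea76823f544493bffa8082296d62d71581a4feff1576f"  # oid from the pointer above
sha = hashlib.sha256()
with open("GLEE/clip_vit_base_patch32/pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == expected, "pytorch_model.bin does not match the LFS pointer"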
GLEE/clip_vit_base_patch32/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
GLEE/clip_vit_base_patch32/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
GLEE/clip_vit_base_patch32/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
GLEE/clip_vit_base_patch32/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
GLEE/configs/R50.yaml ADDED
@@ -0,0 +1,71 @@
+ MODEL:
+   META_ARCHITECTURE: "GLEE"
+   MASK_ON: True
+   BACKBONE:
+     FREEZE_AT: 0
+     NAME: "build_resnet_backbone"
+   PIXEL_MEAN: [123.675, 116.280, 103.530]
+   PIXEL_STD: [58.395, 57.120, 57.375]
+   RESNETS:
+     DEPTH: 50
+     STEM_TYPE: "basic"  # not used
+     STEM_OUT_CHANNELS: 64
+     STRIDE_IN_1X1: False
+     OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+     # NORM: "SyncBN"
+     RES5_MULTI_GRID: [1, 1, 1]  # not used
+   SEM_SEG_HEAD:
+     NAME: "MaskDINOHead"
+     IGNORE_VALUE: 255
+     NUM_CLASSES: 80
+     LOSS_WEIGHT: 1.0
+     CONVS_DIM: 256
+     MASK_DIM: 256
+     NORM: "GN"
+     # pixel decoder
+     PIXEL_DECODER_NAME: "MaskDINOEncoder"
+     DIM_FEEDFORWARD: 2048
+     NUM_FEATURE_LEVELS: 3
+     TOTAL_NUM_FEATURE_LEVELS: 4
+     IN_FEATURES: ["res2", "res3", "res4", "res5"]
+     DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+     COMMON_STRIDE: 4
+     TRANSFORMER_ENC_LAYERS: 6
+     FEATURE_ORDER: "low2high"
+   MaskDINO:
+     TRANSFORMER_DECODER_NAME: "MaskDINODecoder"
+     DEEP_SUPERVISION: True
+     NO_OBJECT_WEIGHT: 0.1
+     CLASS_WEIGHT: 4.0
+     MASK_WEIGHT: 5.0
+     DICE_WEIGHT: 5.0
+     BOX_WEIGHT: 5.0
+     GIOU_WEIGHT: 2.0
+     HIDDEN_DIM: 256
+     NUM_OBJECT_QUERIES: 300
+     NHEADS: 8
+     DROPOUT: 0.0
+     DIM_FEEDFORWARD: 2048
+     ENC_LAYERS: 0
+     PRE_NORM: False
+     ENFORCE_INPUT_PROJ: False
+     SIZE_DIVISIBILITY: 32
+     DEC_LAYERS: 9  # 9+1, 9 decoder layers, add one for the loss on learnable query
+     TRAIN_NUM_POINTS: 12544
+     OVERSAMPLE_RATIO: 3.0
+     IMPORTANCE_SAMPLE_RATIO: 0.75
+     INITIAL_PRED: True
+     TWO_STAGE: True
+     DN: "standard"
+     DN_NUM: 100
+     INITIALIZE_BOX_TYPE: "no"
+     TEST:
+       SEMANTIC_ON: False
+       INSTANCE_ON: True
+       PANOPTIC_ON: False
+       OVERLAP_THRESHOLD: 0.8
+       OBJECT_MASK_THRESHOLD: 0.25
+   TEXT:
+     ARCH: clip_teacher
+   LANGUAGE_BACKBONE:
+     LANG_DIM: 512
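A hedged sketch of how a config like this is typically consumed with detectron2, assuming the custom keys above are registered by add_glee_config / add_deeplab_config (added below under GLEE/glee/__init__.py) and that the glee package is importable from your working directory:

from detectron2.config import get_cfg
from glee import add_glee_config, add_deeplab_config  # adjust the import to match how GLEE/glee sits on your path

cfg = get_cfg()
add_deeplab_config(cfg)
add_glee_config(cfg)
cfg.merge_from_file("GLEE/configs/R50.yaml")
print(cfg.MODEL.MaskDINO.NUM_OBJECT_QUERIES)  # 300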
GLEE/configs/SwinL.yaml ADDED
@@ -0,0 +1,79 @@
+ MODEL:
+   META_ARCHITECTURE: "GLEE"
+   MASK_ON: True
+   BACKBONE:
+     NAME: "D2SwinTransformer"
+   SWIN:
+     EMBED_DIM: 192
+     DEPTHS: [2, 2, 18, 2]
+     NUM_HEADS: [6, 12, 24, 48]
+     WINDOW_SIZE: 12
+     APE: False
+     DROP_PATH_RATE: 0.3
+     PATCH_NORM: True
+     PRETRAIN_IMG_SIZE: 384
+   PIXEL_MEAN: [123.675, 116.280, 103.530]
+   PIXEL_STD: [58.395, 57.120, 57.375]
+   RESNETS:
+     DEPTH: 50
+     STEM_TYPE: "basic"  # not used
+     STEM_OUT_CHANNELS: 64
+     STRIDE_IN_1X1: False
+     OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+     # NORM: "SyncBN"
+     RES5_MULTI_GRID: [1, 1, 1]  # not used
+   SEM_SEG_HEAD:
+     NAME: "MaskDINOHead"
+     IGNORE_VALUE: 255
+     NUM_CLASSES: 80
+     LOSS_WEIGHT: 1.0
+     CONVS_DIM: 256
+     MASK_DIM: 256
+     NORM: "GN"
+     # pixel decoder
+     PIXEL_DECODER_NAME: "MaskDINOEncoder"
+     DIM_FEEDFORWARD: 2048
+     NUM_FEATURE_LEVELS: 3
+     TOTAL_NUM_FEATURE_LEVELS: 4
+     IN_FEATURES: ["res2", "res3", "res4", "res5"]
+     DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+     COMMON_STRIDE: 4
+     TRANSFORMER_ENC_LAYERS: 6
+     FEATURE_ORDER: "low2high"
+   MaskDINO:
+     TRANSFORMER_DECODER_NAME: "MaskDINODecoder"
+     DEEP_SUPERVISION: True
+     NO_OBJECT_WEIGHT: 0.1
+     CLASS_WEIGHT: 4.0
+     MASK_WEIGHT: 5.0
+     DICE_WEIGHT: 5.0
+     BOX_WEIGHT: 5.0
+     GIOU_WEIGHT: 2.0
+     HIDDEN_DIM: 256
+     NUM_OBJECT_QUERIES: 300
+     NHEADS: 8
+     DROPOUT: 0.0
+     DIM_FEEDFORWARD: 2048
+     ENC_LAYERS: 0
+     PRE_NORM: False
+     ENFORCE_INPUT_PROJ: False
+     SIZE_DIVISIBILITY: 32
+     DEC_LAYERS: 9  # 9+1, 9 decoder layers, add one for the loss on learnable query
+     TRAIN_NUM_POINTS: 12544
+     OVERSAMPLE_RATIO: 3.0
+     IMPORTANCE_SAMPLE_RATIO: 0.75
+     INITIAL_PRED: True
+     TWO_STAGE: True
+     DN: "standard"
+     DN_NUM: 100
+     INITIALIZE_BOX_TYPE: "no"
+     TEST:
+       SEMANTIC_ON: False
+       INSTANCE_ON: True
+       PANOPTIC_ON: False
+       OVERLAP_THRESHOLD: 0.8
+       OBJECT_MASK_THRESHOLD: 0.25
+   TEXT:
+     ARCH: clip_teacher
+   LANGUAGE_BACKBONE:
+     LANG_DIM: 512
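Compared with R50.yaml, the changes are confined to the backbone: D2SwinTransformer with Swin-Large settings (EMBED_DIM 192, DEPTHS [2, 2, 18, 2], NUM_HEADS [6, 12, 24, 48], window 12, pretrained at 384). Since Swin doubles the channel width at each stage, the pixel decoder sees res2-res5 widths of:

embed_dim = 192                                # MODEL.SWIN.EMBED_DIM above
print([embed_dim * 2 ** i for i in range(4)])  # [192, 384, 768, 1536] for res2..res5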
GLEE/glee/.DS_Store ADDED
Binary file (6.15 kB).
 
GLEE/glee/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+
+ from .config import add_glee_config
+ from .config_deeplab import add_deeplab_config
+ # from .GLEE import GLEE
+ # from .data import build_detection_train_loader, build_detection_test_loader
+ from .backbone.swin import D2SwinTransformer
+ from .backbone.eva02 import D2_EVA02
+
GLEE/glee/backbone/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .build import build_backbone
+
+ from .resnet import *
+ from .swin import *
+ # from .focal import *
+ # from .focal_dw import *
+ from .backbone import *
GLEE/glee/backbone/backbone.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import torch.nn as nn
+
+ from detectron2.modeling import ShapeSpec
+
+ __all__ = ["Backbone"]
+
+
+ class Backbone(nn.Module):
+     """
+     Abstract base class for network backbones.
+     """
+
+     def __init__(self):
+         """
+         The `__init__` method of any subclass can specify its own set of arguments.
+         """
+         super().__init__()
+
+     def forward(self):
+         """
+         Subclasses must override this method, but adhere to the same return type.
+
+         Returns:
+             dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
+         """
+         pass
+
+     @property
+     def size_divisibility(self) -> int:
+         """
+         Some backbones require the input height and width to be divisible by a
+         specific integer. This is typically true for encoder / decoder type networks
+         with lateral connection (e.g., FPN) for which feature maps need to match
+         dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
+         input size divisibility is required.
+         """
+         return 0
+
+     def output_shape(self):
+         """
+         Returns:
+             dict[str->ShapeSpec]
+         """
+         # this is a backward-compatible default
+         return {
+             name: ShapeSpec(
+                 channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+             )
+             for name in self._out_features
+         }
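To illustrate the contract this base class defines (illustrative only, not part of the commit): a subclass returns a name-to-tensor dict from forward() and fills the three private attributes so that output_shape() works.

import torch
import torch.nn as nn

class TinyBackbone(Backbone):  # Backbone as defined above, e.g. from glee.backbone.backbone
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)
        self._out_features = ["res2"]
        self._out_feature_channels = {"res2": 64}
        self._out_feature_strides = {"res2": 4}

    def forward(self, x):
        return {"res2": self.stem(x)}

feats = TinyBackbone()(torch.zeros(1, 3, 64, 64))
print(feats["res2"].shape)  # torch.Size([1, 64, 16, 16]) -> stride-4 "res2" feature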
GLEE/glee/backbone/build.py ADDED
@@ -0,0 +1,11 @@
+ from .registry import model_entrypoints
+ from .registry import is_model
+
+ from .backbone import *
+
+ def build_backbone(config, **kwargs):
+     model_name = config['MODEL']['BACKBONE']['NAME']
+     if not is_model(model_name):
+         raise ValueError(f'Unknown model: {model_name}')
+     model = model_entrypoints(model_name)(config, **kwargs)
+     return model
GLEE/glee/backbone/davit.py ADDED
@@ -0,0 +1,623 @@
1
+ import os
2
+ import itertools
3
+ import logging
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+ from collections import OrderedDict
10
+
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath, trunc_normal_
13
+
14
+ from detectron2.utils.file_io import PathManager
15
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
16
+
17
+ from .registry import register_backbone
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class MySequential(nn.Sequential):
23
+ def forward(self, *inputs):
24
+ for module in self._modules.values():
25
+ if type(inputs) == tuple:
26
+ inputs = module(*inputs)
27
+ else:
28
+ inputs = module(inputs)
29
+ return inputs
30
+
31
+
32
+ class PreNorm(nn.Module):
33
+ def __init__(self, norm, fn, drop_path=None):
34
+ super().__init__()
35
+ self.norm = norm
36
+ self.fn = fn
37
+ self.drop_path = drop_path
38
+
39
+ def forward(self, x, *args, **kwargs):
40
+ shortcut = x
41
+ if self.norm != None:
42
+ x, size = self.fn(self.norm(x), *args, **kwargs)
43
+ else:
44
+ x, size = self.fn(x, *args, **kwargs)
45
+
46
+ if self.drop_path:
47
+ x = self.drop_path(x)
48
+
49
+ x = shortcut + x
50
+
51
+ return x, size
52
+
53
+
54
+ class Mlp(nn.Module):
55
+ def __init__(
56
+ self,
57
+ in_features,
58
+ hidden_features=None,
59
+ out_features=None,
60
+ act_layer=nn.GELU,
61
+ ):
62
+ super().__init__()
63
+ out_features = out_features or in_features
64
+ hidden_features = hidden_features or in_features
65
+ self.net = nn.Sequential(OrderedDict([
66
+ ("fc1", nn.Linear(in_features, hidden_features)),
67
+ ("act", act_layer()),
68
+ ("fc2", nn.Linear(hidden_features, out_features))
69
+ ]))
70
+
71
+ def forward(self, x, size):
72
+ return self.net(x), size
73
+
74
+
75
+ class DepthWiseConv2d(nn.Module):
76
+ def __init__(
77
+ self,
78
+ dim_in,
79
+ kernel_size,
80
+ padding,
81
+ stride,
82
+ bias=True,
83
+ ):
84
+ super().__init__()
85
+ self.dw = nn.Conv2d(
86
+ dim_in, dim_in,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim_in,
90
+ stride=stride,
91
+ bias=bias
92
+ )
93
+
94
+ def forward(self, x, size):
95
+ B, N, C = x.shape
96
+ H, W = size
97
+ assert N == H * W
98
+
99
+ x = self.dw(x.transpose(1, 2).view(B, C, H, W))
100
+ size = (x.size(-2), x.size(-1))
101
+ x = x.flatten(2).transpose(1, 2)
102
+ return x, size
103
+
104
+
105
+ class ConvEmbed(nn.Module):
106
+ """ Image to Patch Embedding
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ patch_size=7,
112
+ in_chans=3,
113
+ embed_dim=64,
114
+ stride=4,
115
+ padding=2,
116
+ norm_layer=None,
117
+ pre_norm=True
118
+ ):
119
+ super().__init__()
120
+ self.patch_size = patch_size
121
+
122
+ self.proj = nn.Conv2d(
123
+ in_chans, embed_dim,
124
+ kernel_size=patch_size,
125
+ stride=stride,
126
+ padding=padding
127
+ )
128
+
129
+ dim_norm = in_chans if pre_norm else embed_dim
130
+ self.norm = norm_layer(dim_norm) if norm_layer else None
131
+
132
+ self.pre_norm = pre_norm
133
+
134
+ def forward(self, x, size):
135
+ H, W = size
136
+ if len(x.size()) == 3:
137
+ if self.norm and self.pre_norm:
138
+ x = self.norm(x)
139
+ x = rearrange(
140
+ x, 'b (h w) c -> b c h w',
141
+ h=H, w=W
142
+ )
143
+
144
+ x = self.proj(x)
145
+
146
+ _, _, H, W = x.shape
147
+ x = rearrange(x, 'b c h w -> b (h w) c')
148
+ if self.norm and not self.pre_norm:
149
+ x = self.norm(x)
150
+
151
+ return x, (H, W)
152
+
153
+
154
+ class ChannelAttention(nn.Module):
155
+
156
+ def __init__(self, dim, groups=8, qkv_bias=True):
157
+ super().__init__()
158
+
159
+ self.groups = groups
160
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
161
+ self.proj = nn.Linear(dim, dim)
162
+
163
+ def forward(self, x, size):
164
+ B, N, C = x.shape
165
+
166
+ qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
167
+ q, k, v = qkv[0], qkv[1], qkv[2]
168
+
169
+ q = q * (N ** -0.5)
170
+ attention = q.transpose(-1, -2) @ k
171
+ attention = attention.softmax(dim=-1)
172
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
173
+ x = x.transpose(1, 2).reshape(B, N, C)
174
+ x = self.proj(x)
175
+ return x, size
176
+
177
+
178
+ class ChannelBlock(nn.Module):
179
+
180
+ def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
181
+ drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
182
+ conv_at_attn=True, conv_at_ffn=True):
183
+ super().__init__()
184
+
185
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
186
+
187
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
188
+ self.channel_attn = PreNorm(
189
+ norm_layer(dim),
190
+ ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
191
+ drop_path
192
+ )
193
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
194
+ self.ffn = PreNorm(
195
+ norm_layer(dim),
196
+ Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
197
+ drop_path
198
+ )
199
+
200
+ def forward(self, x, size):
201
+ if self.conv1:
202
+ x, size = self.conv1(x, size)
203
+ x, size = self.channel_attn(x, size)
204
+
205
+ if self.conv2:
206
+ x, size = self.conv2(x, size)
207
+ x, size = self.ffn(x, size)
208
+
209
+ return x, size
210
+
211
+
212
+ def window_partition(x, window_size: int):
213
+ B, H, W, C = x.shape
214
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
215
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
216
+ return windows
217
+
218
+
219
+ def window_reverse(windows, window_size: int, H: int, W: int):
220
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
221
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
222
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
223
+ return x
224
+
225
+
226
+ class WindowAttention(nn.Module):
227
+ def __init__(self, dim, num_heads, window_size, qkv_bias=True):
228
+
229
+ super().__init__()
230
+ self.dim = dim
231
+ self.window_size = window_size
232
+ self.num_heads = num_heads
233
+ head_dim = dim // num_heads
234
+ self.scale = head_dim ** -0.5
235
+
236
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
237
+ self.proj = nn.Linear(dim, dim)
238
+
239
+ self.softmax = nn.Softmax(dim=-1)
240
+
241
+ def forward(self, x, size):
242
+
243
+ H, W = size
244
+ B, L, C = x.shape
245
+ assert L == H * W, "input feature has wrong size"
246
+
247
+ x = x.view(B, H, W, C)
248
+
249
+ pad_l = pad_t = 0
250
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
251
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
252
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
253
+ _, Hp, Wp, _ = x.shape
254
+
255
+ x = window_partition(x, self.window_size)
256
+ x = x.view(-1, self.window_size * self.window_size, C)
257
+
258
+ # W-MSA/SW-MSA
259
+ # attn_windows = self.attn(x_windows)
260
+
261
+ B_, N, C = x.shape
262
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
263
+ q, k, v = qkv[0], qkv[1], qkv[2]
264
+
265
+ q = q * self.scale
266
+ attn = (q @ k.transpose(-2, -1))
267
+ attn = self.softmax(attn)
268
+
269
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
270
+ x = self.proj(x)
271
+
272
+ # merge windows
273
+ x = x.view(
274
+ -1, self.window_size, self.window_size, C
275
+ )
276
+ x = window_reverse(x, self.window_size, Hp, Wp)
277
+
278
+ if pad_r > 0 or pad_b > 0:
279
+ x = x[:, :H, :W, :].contiguous()
280
+
281
+ x = x.view(B, H * W, C)
282
+
283
+ return x, size
284
+
285
+
286
+ class SpatialBlock(nn.Module):
287
+
288
+ def __init__(self, dim, num_heads, window_size,
289
+ mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
290
+ norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
291
+ super().__init__()
292
+
293
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
294
+
295
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
296
+ self.window_attn = PreNorm(
297
+ norm_layer(dim),
298
+ WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
299
+ drop_path
300
+ )
301
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
302
+ self.ffn = PreNorm(
303
+ norm_layer(dim),
304
+ Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
305
+ drop_path
306
+ )
307
+
308
+ def forward(self, x, size):
309
+ if self.conv1:
310
+ x, size = self.conv1(x, size)
311
+ x, size = self.window_attn(x, size)
312
+
313
+ if self.conv2:
314
+ x, size = self.conv2(x, size)
315
+ x, size = self.ffn(x, size)
316
+ return x, size
317
+
318
+
319
+ class DaViT(nn.Module):
320
+ """ DaViT: Dual-Attention Transformer
321
+
322
+ Args:
323
+ img_size (int): Image size, Default: 224.
324
+ in_chans (int): Number of input image channels. Default: 3.
325
+ num_classes (int): Number of classes for classification head. Default: 1000.
326
+ patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
327
+ patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
328
+ patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
329
+ patch_prenorm (tuple(bool)): If True, perform norm before convolution layer. Default: (True, False, False, False).
330
+ embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
331
+ num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (4, 8, 12, 16).
332
+ num_groups (tuple(int)): Number of channel groups in different stages. Default: (4, 8, 12, 16).
333
+ window_size (int): Window size. Default: 7.
334
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
335
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
336
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1.
337
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
338
+ enable_checkpoint (bool): If True, enable checkpointing. Default: False.
339
+ conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
340
+ conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
341
+ """
342
+
343
+ def __init__(
344
+ self,
345
+ img_size=224,
346
+ in_chans=3,
347
+ num_classes=1000,
348
+ depths=(1, 1, 3, 1),
349
+ patch_size=(7, 2, 2, 2),
350
+ patch_stride=(4, 2, 2, 2),
351
+ patch_padding=(3, 0, 0, 0),
352
+ patch_prenorm=(False, False, False, False),
353
+ embed_dims=(64, 128, 192, 256),
354
+ num_heads=(3, 6, 12, 24),
355
+ num_groups=(3, 6, 12, 24),
356
+ window_size=7,
357
+ mlp_ratio=4.,
358
+ qkv_bias=True,
359
+ drop_path_rate=0.1,
360
+ norm_layer=nn.LayerNorm,
361
+ enable_checkpoint=False,
362
+ conv_at_attn=True,
363
+ conv_at_ffn=True,
364
+ out_indices=[],
365
+ ):
366
+ super().__init__()
367
+
368
+ self.num_classes = num_classes
369
+ self.embed_dims = embed_dims
370
+ self.num_heads = num_heads
371
+ self.num_groups = num_groups
372
+ self.num_stages = len(self.embed_dims)
373
+ self.enable_checkpoint = enable_checkpoint
374
+ assert self.num_stages == len(self.num_heads) == len(self.num_groups)
375
+
376
+ num_stages = len(embed_dims)
377
+ self.img_size = img_size
378
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]
379
+
380
+
381
+ depth_offset = 0
382
+ convs = []
383
+ blocks = []
384
+ for i in range(num_stages):
385
+ conv_embed = ConvEmbed(
386
+ patch_size=patch_size[i],
387
+ stride=patch_stride[i],
388
+ padding=patch_padding[i],
389
+ in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
390
+ embed_dim=self.embed_dims[i],
391
+ norm_layer=norm_layer,
392
+ pre_norm=patch_prenorm[i]
393
+ )
394
+ convs.append(conv_embed)
395
+
396
+ print(f'=> Depth offset in stage {i}: {depth_offset}')
397
+ block = MySequential(
398
+ *[
399
+ MySequential(OrderedDict([
400
+ (
401
+ 'spatial_block', SpatialBlock(
402
+ embed_dims[i],
403
+ num_heads[i],
404
+ window_size,
405
+ drop_path_rate=dpr[depth_offset+j*2],
406
+ qkv_bias=qkv_bias,
407
+ mlp_ratio=mlp_ratio,
408
+ conv_at_attn=conv_at_attn,
409
+ conv_at_ffn=conv_at_ffn,
410
+ )
411
+ ),
412
+ (
413
+ 'channel_block', ChannelBlock(
414
+ embed_dims[i],
415
+ num_groups[i],
416
+ drop_path_rate=dpr[depth_offset+j*2+1],
417
+ qkv_bias=qkv_bias,
418
+ mlp_ratio=mlp_ratio,
419
+ conv_at_attn=conv_at_attn,
420
+ conv_at_ffn=conv_at_ffn,
421
+ )
422
+ )
423
+ ])) for j in range(depths[i])
424
+ ]
425
+ )
426
+ blocks.append(block)
427
+ depth_offset += depths[i]*2
428
+
429
+ self.convs = nn.ModuleList(convs)
430
+ self.blocks = nn.ModuleList(blocks)
431
+
432
+ self.out_indices = out_indices
433
+ # self.norms = norm_layer(self.embed_dims[-1])
434
+ # self.avgpool = nn.AdaptiveAvgPool1d(1)
435
+ # self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
436
+ self.apply(self._init_weights)
437
+
438
+ @property
439
+ def dim_out(self):
440
+ return self.embed_dims[-1]
441
+
442
+ def _init_weights(self, m):
443
+ if isinstance(m, nn.Linear):
444
+ trunc_normal_(m.weight, std=0.02)
445
+ if m.bias is not None:
446
+ nn.init.constant_(m.bias, 0)
447
+ elif isinstance(m, nn.Conv2d):
448
+ nn.init.normal_(m.weight, std=0.02)
449
+ for name, _ in m.named_parameters():
450
+ if name in ['bias']:
451
+ nn.init.constant_(m.bias, 0)
452
+ elif isinstance(m, nn.LayerNorm):
453
+ nn.init.constant_(m.weight, 1.0)
454
+ nn.init.constant_(m.bias, 0)
455
+ elif isinstance(m, nn.BatchNorm2d):
456
+ nn.init.constant_(m.weight, 1.0)
457
+ nn.init.constant_(m.bias, 0)
458
+
459
+ def _try_remap_keys(self, pretrained_dict):
460
+ remap_keys = {
461
+ "conv_embeds": "convs",
462
+ "main_blocks": "blocks",
463
+ "0.cpe.0.proj": "spatial_block.conv1.fn.dw",
464
+ "0.attn": "spatial_block.window_attn.fn",
465
+ "0.cpe.1.proj": "spatial_block.conv2.fn.dw",
466
+ "0.mlp": "spatial_block.ffn.fn.net",
467
+ "1.cpe.0.proj": "channel_block.conv1.fn.dw",
468
+ "1.attn": "channel_block.channel_attn.fn",
469
+ "1.cpe.1.proj": "channel_block.conv2.fn.dw",
470
+ "1.mlp": "channel_block.ffn.fn.net",
471
+ "0.norm1": "spatial_block.window_attn.norm",
472
+ "0.norm2": "spatial_block.ffn.norm",
473
+ "1.norm1": "channel_block.channel_attn.norm",
474
+ "1.norm2": "channel_block.ffn.norm"
475
+ }
476
+
477
+ full_key_mappings = {}
478
+ for k in pretrained_dict.keys():
479
+ old_k = k
480
+ for remap_key in remap_keys.keys():
481
+ if remap_key in k:
482
+ print(f'=> Repace {remap_key} with {remap_keys[remap_key]}')
483
+ k = k.replace(remap_key, remap_keys[remap_key])
484
+
485
+ full_key_mappings[old_k] = k
486
+
487
+ return full_key_mappings
488
+
489
+ def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
490
+ model_dict = self.state_dict()
491
+ stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
492
+ full_key_mappings = self._try_remap_keys(pretrained_dict)
493
+
494
+ pretrained_dict = {
495
+ stripped_key(full_key_mappings[k]): v for k, v in pretrained_dict.items()
496
+ if stripped_key(full_key_mappings[k]) in model_dict.keys()
497
+ }
498
+ need_init_state_dict = {}
499
+ for k, v in pretrained_dict.items():
500
+ need_init = (
501
+ k.split('.')[0] in pretrained_layers
502
+ or pretrained_layers[0] == '*'
503
+ )
504
+ if need_init:
505
+ if verbose:
506
+ print(f'=> init {k} from pretrained state dict')
507
+
508
+ need_init_state_dict[k] = v
509
+ self.load_state_dict(need_init_state_dict, strict=False)
510
+
511
+ def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
512
+ if os.path.isfile(pretrained):
513
+ print(f'=> loading pretrained model {pretrained}')
514
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
515
+
516
+ self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
517
+
518
+ def forward_features(self, x):
519
+ input_size = (x.size(2), x.size(3))
520
+
521
+ outs = {}
522
+ for i, (conv, block) in enumerate(zip(self.convs, self.blocks)):
523
+ x, input_size = conv(x, input_size)
524
+ if self.enable_checkpoint:
525
+ x, input_size = checkpoint.checkpoint(block, x, input_size)
526
+ else:
527
+ x, input_size = block(x, input_size)
528
+ if i in self.out_indices:
529
+ out = x.view(-1, *input_size, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
530
+ outs["res{}".format(i + 2)] = out
531
+
532
+ if len(self.out_indices) == 0:
533
+ outs["res5"] = x.view(-1, *input_size, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
534
+
535
+ return outs
536
+
537
+ def forward(self, x):
538
+ x = self.forward_features(x)
539
+ # x = self.head(x)
540
+ return x
541
+
542
+ class D2DaViT(DaViT, Backbone):
543
+ def __init__(self, cfg, input_shape):
544
+
545
+ spec = cfg['BACKBONE']['DAVIT']
546
+
547
+ super().__init__(
548
+ num_classes=0,
549
+ depths=spec['DEPTHS'],
550
+ embed_dims=spec['DIM_EMBED'],
551
+ num_heads=spec['NUM_HEADS'],
552
+ num_groups=spec['NUM_GROUPS'],
553
+ patch_size=spec['PATCH_SIZE'],
554
+ patch_stride=spec['PATCH_STRIDE'],
555
+ patch_padding=spec['PATCH_PADDING'],
556
+ patch_prenorm=spec['PATCH_PRENORM'],
557
+ drop_path_rate=spec['DROP_PATH_RATE'],
558
+ img_size=input_shape,
559
+ window_size=spec.get('WINDOW_SIZE', 7),
560
+ enable_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
561
+ conv_at_attn=spec.get('CONV_AT_ATTN', True),
562
+ conv_at_ffn=spec.get('CONV_AT_FFN', True),
563
+ out_indices=spec.get('OUT_INDICES', []),
564
+ )
565
+
566
+ self._out_features = cfg['BACKBONE']['DAVIT']['OUT_FEATURES']
567
+
568
+ self._out_feature_strides = {
569
+ "res2": 4,
570
+ "res3": 8,
571
+ "res4": 16,
572
+ "res5": 32,
573
+ }
574
+ self._out_feature_channels = {
575
+ "res2": self.embed_dims[0],
576
+ "res3": self.embed_dims[1],
577
+ "res4": self.embed_dims[2],
578
+ "res5": self.embed_dims[3],
579
+ }
580
+
581
+ def forward(self, x):
582
+ """
583
+ Args:
584
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
585
+ Returns:
586
+ dict[str->Tensor]: names and the corresponding features
587
+ """
588
+ assert (
589
+ x.dim() == 4
590
+ ), f"D2DaViT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
591
+ outputs = {}
592
+ y = super().forward(x)
593
+
594
+ for k in y.keys():
595
+ if k in self._out_features:
596
+ outputs[k] = y[k]
597
+ return outputs
598
+
599
+ def output_shape(self):
600
+ return {
601
+ name: ShapeSpec(
602
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
603
+ )
604
+ for name in self._out_features
605
+ }
606
+
607
+ @property
608
+ def size_divisibility(self):
609
+ return 32
610
+
611
+ @register_backbone
612
+ def get_davit_backbone(cfg):
613
+ davit = D2DaViT(cfg['MODEL'], 224)
614
+
615
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
616
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
617
+ logger.info(f'=> init from {filename}')
618
+ davit.from_pretrained(
619
+ filename,
620
+ cfg['MODEL']['BACKBONE']['DAVIT'].get('PRETRAINED_LAYERS', ['*']),
621
+ cfg['VERBOSE'])
622
+
623
+ return davit
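As a sanity check on the window helpers defined in this file (illustrative, not part of the commit): window_partition and window_reverse are exact inverses when H and W are multiples of the window size, which WindowAttention.forward guarantees by padding first. Assuming the helpers above are importable:

import torch

x = torch.randn(2, 14, 14, 96)              # (B, H, W, C)
wins = window_partition(x, window_size=7)   # (2 * 2 * 2, 7, 7, 96) = (8, 7, 7, 96)
restored = window_reverse(wins, 7, 14, 14)  # back to (2, 14, 14, 96)
assert torch.equal(restored, x)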
GLEE/glee/backbone/eva01.py ADDED
@@ -0,0 +1,676 @@
1
+ import logging
2
+ import math
3
+ from functools import partial
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor, Size
10
+ from typing import Union, List
11
+ from torch.nn.parameter import Parameter
12
+ import numbers
13
+
14
+ from detectron2.layers import CNNBlockBase, Conv2d, get_norm
15
+ from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
16
+
17
+ from fairscale.nn.checkpoint import checkpoint_wrapper
18
+ from timm.models.layers import DropPath, Mlp, trunc_normal_
19
+
20
+ # from detectron2.modeling.backbone import Backbone
21
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
22
+
23
+ from .eva_01_utils import (
24
+ PatchEmbed,
25
+ add_decomposed_rel_pos,
26
+ get_abs_pos,
27
+ window_partition,
28
+ window_unpartition,
29
+ )
30
+ from detectron2.modeling.backbone.fpn import LastLevelMaxPool
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ __all__ = ["EVAViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
36
+
37
+
38
+ _shape_t = Union[int, List[int], Size]
39
+
40
+
41
+ # steal from beit https://github.com/microsoft/unilm/tree/master/beit
42
+ class LayerNormWithForceFP32(nn.Module):
43
+ __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
44
+ normalized_shape: _shape_t
45
+ eps: float
46
+ elementwise_affine: bool
47
+
48
+ def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
49
+ super(LayerNormWithForceFP32, self).__init__()
50
+ if isinstance(normalized_shape, numbers.Integral):
51
+ normalized_shape = (normalized_shape,)
52
+ self.normalized_shape = tuple(normalized_shape)
53
+ self.eps = eps
54
+ self.elementwise_affine = elementwise_affine
55
+ if self.elementwise_affine:
56
+ self.weight = Parameter(torch.Tensor(*normalized_shape))
57
+ self.bias = Parameter(torch.Tensor(*normalized_shape))
58
+ else:
59
+ self.register_parameter('weight', None)
60
+ self.register_parameter('bias', None)
61
+ self.reset_parameters()
62
+
63
+ def reset_parameters(self) -> None:
64
+ if self.elementwise_affine:
65
+ nn.init.ones_(self.weight)
66
+ nn.init.zeros_(self.bias)
67
+
68
+ def forward(self, input: Tensor) -> Tensor:
69
+ return F.layer_norm(
70
+ input.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(input)
71
+
72
+ def extra_repr(self) -> Tensor:
73
+ return '{normalized_shape}, eps={eps}, ' \
74
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
75
+
76
+
77
+ class Attention(nn.Module):
78
+ """Multi-head Attention block with relative position embeddings."""
79
+
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ num_heads=8,
84
+ qkv_bias=True,
85
+ beit_like_qkv_bias=False,
86
+ use_rel_pos=False,
87
+ rel_pos_zero_init=True,
88
+ input_size=None,
89
+ interp_type="vitdet",
90
+ ):
91
+ """
92
+ Args:
93
+ dim (int): Number of input channels.
94
+ num_heads (int): Number of attention heads.
95
+ qkv_bias (bool: If True, add a learnable bias to query, key, value.
96
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
97
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
98
+ input_size (int or None): Input resolution for calculating the relative positional
99
+ parameter size.
100
+ """
101
+ super().__init__()
102
+ self.num_heads = num_heads
103
+ head_dim = dim // num_heads
104
+ self.scale = head_dim**-0.5
105
+
106
+ self.beit_like_qkv_bias = beit_like_qkv_bias
107
+ if beit_like_qkv_bias:
108
+ self.q_bias = nn.Parameter(torch.zeros(dim))
109
+ self.v_bias = nn.Parameter(torch.zeros(dim))
110
+
111
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
112
+ self.proj = nn.Linear(dim, dim)
113
+
114
+ self.use_rel_pos = use_rel_pos
115
+ self.interp_type = interp_type
116
+ if self.use_rel_pos:
117
+ # initialize relative positional embeddings
118
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
119
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
120
+
121
+ if not rel_pos_zero_init:
122
+ trunc_normal_(self.rel_pos_h, std=0.02)
123
+ trunc_normal_(self.rel_pos_w, std=0.02)
124
+ self.qk_float = False
125
+
126
+ def forward(self, x):
127
+ B, H, W, _ = x.shape
128
+ # qkv with shape (3, B, nHead, H * W, C)
129
+ if self.beit_like_qkv_bias:
130
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
131
+ qkv = torch.nn.functional.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
132
+ qkv = qkv.reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
133
+ else:
134
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
135
+ # q, k, v with shape (B * nHead, H * W, C)
136
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
137
+
138
+ if self.qk_float:
139
+ attn = (q.float() * self.scale) @ k.float().transpose(-2, -1)
140
+ if self.use_rel_pos:
141
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W), self.interp_type)
142
+ attn = attn.softmax(dim=-1).type_as(x)
143
+ else:
144
+ attn = (q * self.scale) @ k.transpose(-2, -1)
145
+ if self.use_rel_pos:
146
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W), self.interp_type)
147
+ attn = attn.softmax(dim=-1)
148
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
149
+ x = self.proj(x)
150
+
151
+ return x
152
+
153
+
154
+ class ResBottleneckBlock(CNNBlockBase):
155
+ """
156
+ The standard bottleneck residual block without the last activation layer.
157
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ out_channels,
164
+ bottleneck_channels,
165
+ norm="LN",
166
+ act_layer=nn.GELU,
167
+ ):
168
+ """
169
+ Args:
170
+ in_channels (int): Number of input channels.
171
+ out_channels (int): Number of output channels.
172
+ bottleneck_channels (int): number of output channels for the 3x3
173
+ "bottleneck" conv layers.
174
+ norm (str or callable): normalization for all conv layers.
175
+ See :func:`layers.get_norm` for supported format.
176
+ act_layer (callable): activation for all conv layers.
177
+ """
178
+ super().__init__(in_channels, out_channels, 1)
179
+
180
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
181
+ self.norm1 = get_norm(norm, bottleneck_channels)
182
+ self.act1 = act_layer()
183
+
184
+ self.conv2 = Conv2d(
185
+ bottleneck_channels,
186
+ bottleneck_channels,
187
+ 3,
188
+ padding=1,
189
+ bias=False,
190
+ )
191
+ self.norm2 = get_norm(norm, bottleneck_channels)
192
+ self.act2 = act_layer()
193
+
194
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
195
+ self.norm3 = get_norm(norm, out_channels)
196
+
197
+ for layer in [self.conv1, self.conv2, self.conv3]:
198
+ weight_init.c2_msra_fill(layer)
199
+ for layer in [self.norm1, self.norm2]:
200
+ layer.weight.data.fill_(1.0)
201
+ layer.bias.data.zero_()
202
+ # zero init last norm layer.
203
+ self.norm3.weight.data.zero_()
204
+ self.norm3.bias.data.zero_()
205
+
206
+ def forward(self, x):
207
+ out = x
208
+ for layer in self.children():
209
+ out = layer(out)
210
+
211
+ out = x + out
212
+ return out
213
+
214
+
215
+ class Block(nn.Module):
216
+ """Transformer blocks with support of window attention and residual propagation blocks"""
217
+
218
+ def __init__(
219
+ self,
220
+ dim,
221
+ num_heads,
222
+ mlp_ratio=4.0,
223
+ qkv_bias=True,
224
+ drop_path=0.0,
225
+ norm_layer=LayerNormWithForceFP32,
226
+ act_layer=nn.GELU,
227
+ use_rel_pos=False,
228
+ rel_pos_zero_init=True,
229
+ window_size=0,
230
+ use_residual_block=False,
231
+ input_size=None,
232
+ beit_like_qkv_bias=False,
233
+ beit_like_gamma=False,
234
+ interp_type="vitdet",
235
+ ):
236
+ """
237
+ Args:
238
+ dim (int): Number of input channels.
239
+ num_heads (int): Number of attention heads in each ViT block.
240
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
241
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
242
+ drop_path (float): Stochastic depth rate.
243
+ norm_layer (nn.Module): Normalization layer.
244
+ act_layer (nn.Module): Activation layer.
245
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
246
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
247
+ window_size (int): Window size for window attention blocks. If it equals 0, then not
248
+ use window attention.
249
+ use_residual_block (bool): If True, use a residual block after the MLP block.
250
+ input_size (int or None): Input resolution for calculating the relative positional
251
+ parameter size.
252
+ beit_like_qkv_bias (bool)
253
+ beit_like_gamma (bool)
254
+ """
255
+ super().__init__()
256
+ self.norm1 = norm_layer(dim)
257
+ self.attn = Attention(
258
+ dim,
259
+ num_heads=num_heads,
260
+ qkv_bias=qkv_bias,
261
+ use_rel_pos=use_rel_pos,
262
+ rel_pos_zero_init=rel_pos_zero_init,
263
+ input_size=input_size if window_size == 0 else (window_size, window_size),
264
+ beit_like_qkv_bias=beit_like_qkv_bias,
265
+ interp_type=interp_type,
266
+ )
267
+
268
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
269
+ self.norm2 = norm_layer(dim)
270
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
271
+
272
+ self.window_size = window_size
273
+
274
+ self.use_residual_block = use_residual_block
275
+ if use_residual_block:
276
+ # Use a residual block with bottleneck channel as dim // 2
277
+ self.residual = ResBottleneckBlock(
278
+ in_channels=dim,
279
+ out_channels=dim,
280
+ bottleneck_channels=dim // 2,
281
+ norm="LN",
282
+ act_layer=act_layer,
283
+ )
284
+
285
+ self.beit_like_gamma = beit_like_gamma
286
+ if beit_like_gamma:
287
+ self.gamma_1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
288
+ self.gamma_2 = nn.Parameter(torch.ones((dim)), requires_grad=True)
289
+
290
+ def forward(self, x):
291
+ shortcut = x
292
+ x = self.norm1(x)
293
+ # Window partition
294
+ if self.window_size > 0:
295
+ H, W = x.shape[1], x.shape[2]
296
+ x, pad_hw = window_partition(x, self.window_size)
297
+
298
+ x = self.attn(x)
299
+ # Reverse window partition
300
+ if self.window_size > 0:
301
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
302
+
303
+ if self.beit_like_gamma:
304
+ x = shortcut + self.drop_path(self.gamma_1 * x)
305
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
306
+ else:
307
+ x = shortcut + self.drop_path(x)
308
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
309
+
310
+ if self.use_residual_block:
311
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
312
+
313
+ return x
314
+
315
+
316
+ class EVAViT(Backbone):
317
+ """
318
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
319
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
320
+ https://arxiv.org/abs/2203.16527
321
+ """
322
+
323
+ def __init__(
324
+ self,
325
+ img_size=1024,
326
+ patch_size=16,
327
+ in_chans=3,
328
+ embed_dim=768,
329
+ depth=12,
330
+ num_heads=12,
331
+ mlp_ratio=4.0,
332
+ qkv_bias=True,
333
+ drop_path_rate=0.0,
334
+ norm_layer=LayerNormWithForceFP32,
335
+ act_layer=nn.GELU,
336
+ use_abs_pos=True,
337
+ use_rel_pos=False,
338
+ rel_pos_zero_init=True,
339
+ window_size=0,
340
+ window_block_indexes=(),
341
+ residual_block_indexes=(),
342
+ use_act_checkpoint=False,
343
+ pretrain_img_size=224,
344
+ pretrain_use_cls_token=True,
345
+ out_feature="last_feat",
346
+ beit_like_qkv_bias=True,
347
+ beit_like_gamma=False,
348
+ freeze_patch_embed=False,
349
+ interp_type="vitdet",
350
+ ):
351
+ """
352
+ Args:
353
+ img_size (int): Input image size.
354
+ patch_size (int): Patch size.
355
+ in_chans (int): Number of input image channels.
356
+ embed_dim (int): Patch embedding dimension.
357
+ depth (int): Depth of ViT.
358
+ num_heads (int): Number of attention heads in each ViT block.
359
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
360
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
361
+ drop_path_rate (float): Stochastic depth rate.
362
+ norm_layer (nn.Module): Normalization layer.
363
+ act_layer (nn.Module): Activation layer.
364
+ use_abs_pos (bool): If True, use absolute positional embeddings.
365
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
366
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
367
+ window_size (int): Window size for window attention blocks.
368
+ window_block_indexes (list): Indexes for blocks using window attention.
369
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
370
+ use_act_checkpoint (bool): If True, use activation checkpointing.
371
+ pretrain_img_size (int): input image size for pretraining models.
372
+ pretrain_use_cls_token (bool): If True, pretrainig models use class token.
373
+ out_feature (str): name of the feature from the last block.
374
+ beit_like_qkv_bias (bool): beit_like_model that has gamma_1 and gamma_2 in blocks and qkv_bias=False
375
+ beit_like_gamma (bool)
376
+ freeze_patch_embed (bool)
377
+ interp_type: "vitdet" for training / fine-ting, "beit" for eval (slightly improvement at a higher res)
378
+ """
379
+ super().__init__()
380
+ self.pretrain_use_cls_token = pretrain_use_cls_token
381
+
382
+ self.patch_embed = PatchEmbed(
383
+ kernel_size=(patch_size, patch_size),
384
+ stride=(patch_size, patch_size),
385
+ in_chans=in_chans,
386
+ embed_dim=embed_dim,
387
+ )
388
+
389
+ if use_abs_pos:
390
+ # Initialize absolute positional embedding with pretrain image size.
391
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
392
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
393
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
394
+ else:
395
+ self.pos_embed = None
396
+
397
+ # stochastic depth decay rule
398
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
399
+
400
+ self.blocks = nn.ModuleList()
401
+ if beit_like_qkv_bias:
402
+ qkv_bias = False
403
+ for i in range(depth):
404
+ block = Block(
405
+ dim=embed_dim,
406
+ num_heads=num_heads,
407
+ mlp_ratio=mlp_ratio,
408
+ qkv_bias=qkv_bias,
409
+ drop_path=dpr[i],
410
+ norm_layer=norm_layer,
411
+ act_layer=act_layer,
412
+ use_rel_pos=use_rel_pos,
413
+ rel_pos_zero_init=rel_pos_zero_init,
414
+ window_size=window_size if i in window_block_indexes else 0,
415
+ use_residual_block=i in residual_block_indexes,
416
+ input_size=(img_size // patch_size, img_size // patch_size),
417
+ beit_like_qkv_bias=beit_like_qkv_bias,
418
+ beit_like_gamma=beit_like_gamma,
419
+ interp_type=interp_type,
420
+ )
421
+ if use_act_checkpoint:
422
+ block = checkpoint_wrapper(block)
423
+ self.blocks.append(block)
424
+
425
+ self._out_feature_channels = {out_feature: embed_dim}
426
+ self._out_feature_strides = {out_feature: patch_size}
427
+ self._out_features = [out_feature]
428
+
429
+ if self.pos_embed is not None:
430
+ trunc_normal_(self.pos_embed, std=0.02)
431
+
432
+ self.freeze_patch_embed = freeze_patch_embed
433
+ self.apply(self._init_weights)
434
+
435
+ def _init_weights(self, m):
436
+ if isinstance(m, nn.Linear):
437
+ trunc_normal_(m.weight, std=0.02)
438
+ if isinstance(m, nn.Linear) and m.bias is not None:
439
+ nn.init.constant_(m.bias, 0)
440
+ elif isinstance(m, LayerNormWithForceFP32):
441
+ nn.init.constant_(m.bias, 0)
442
+ nn.init.constant_(m.weight, 1.0)
443
+
444
+ if self.freeze_patch_embed:
445
+ for n, p in self.patch_embed.named_parameters():
446
+ p.requires_grad = False
447
+
448
+ def forward(self, x):
449
+ x = self.patch_embed(x)
450
+ if self.pos_embed is not None:
451
+ x = x + get_abs_pos(
452
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
453
+ )
454
+
455
+ for blk in self.blocks:
456
+ x = blk(x)
457
+
458
+ outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
459
+ return outputs
460
+
461
+
462
+ class SimpleFeaturePyramid(Backbone):
463
+ """
464
+ This module implements SimpleFeaturePyramid in :paper:`vitdet`.
465
+ It creates pyramid features built on top of the input feature map.
466
+ """
467
+
468
+ def __init__(
469
+ self,
470
+ net,
471
+ in_feature,
472
+ out_channels,
473
+ scale_factors,
474
+ top_block=None,
475
+ norm="LN",
476
+ square_pad=0,
477
+ ):
478
+ """
479
+ Args:
480
+ net (Backbone): module representing the subnetwork backbone.
481
+ Must be a subclass of :class:`Backbone`.
482
+ in_feature (str): names of the input feature maps coming
483
+ from the net.
484
+ out_channels (int): number of channels in the output feature maps.
485
+ scale_factors (list[float]): list of scaling factors to upsample or downsample
486
+ the input features for creating pyramid features.
487
+ top_block (nn.Module or None): if provided, an extra operation will
488
+ be performed on the output of the last (smallest resolution)
489
+ pyramid output, and the result will extend the result list. The top_block
490
+ further downsamples the feature map. It must have an attribute
491
+ "num_levels", meaning the number of extra pyramid levels added by
492
+ this block, and "in_feature", which is a string representing
493
+ its input feature (e.g., p5).
494
+ norm (str): the normalization to use.
495
+ square_pad (int): If > 0, require input images to be padded to specific square size.
496
+ """
497
+ super(SimpleFeaturePyramid, self).__init__()
498
+ assert isinstance(net, Backbone)
499
+ self.scale_factors = scale_factors
500
+
501
+ input_shapes = net.output_shape()
502
+ strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
503
+ _assert_strides_are_log2_contiguous(strides)
504
+
505
+ dim = input_shapes[in_feature].channels
506
+ self.stages = []
507
+ use_bias = norm == ""
508
+ for idx, scale in enumerate(scale_factors):
509
+ out_dim = dim
510
+ if scale == 4.0:
511
+ layers = [
512
+ nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
513
+ get_norm(norm, dim // 2),
514
+ nn.GELU(),
515
+ nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
516
+ ]
517
+ out_dim = dim // 4
518
+ elif scale == 2.0:
519
+ layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
520
+ out_dim = dim // 2
521
+ elif scale == 1.0:
522
+ layers = []
523
+ elif scale == 0.5:
524
+ layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
525
+ else:
526
+ raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
527
+
528
+ layers.extend(
529
+ [
530
+ Conv2d(
531
+ out_dim,
532
+ out_channels,
533
+ kernel_size=1,
534
+ bias=use_bias,
535
+ norm=get_norm(norm, out_channels),
536
+ ),
537
+ Conv2d(
538
+ out_channels,
539
+ out_channels,
540
+ kernel_size=3,
541
+ padding=1,
542
+ bias=use_bias,
543
+ norm=get_norm(norm, out_channels),
544
+ ),
545
+ ]
546
+ )
547
+ layers = nn.Sequential(*layers)
548
+
549
+ stage = int(math.log2(strides[idx]))
550
+ self.add_module(f"simfp_{stage}", layers)
551
+ self.stages.append(layers)
552
+
553
+ self.net = net
554
+ self.in_feature = in_feature
555
+ self.top_block = top_block
556
+ # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
557
+ self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
558
+ # top block output feature maps.
559
+ if self.top_block is not None:
560
+ for s in range(stage, stage + self.top_block.num_levels):
561
+ self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
562
+
563
+ self._out_features = list(self._out_feature_strides.keys())
564
+ self._out_feature_channels = {k: out_channels for k in self._out_features}
565
+ self._size_divisibility = strides[-1]
566
+ self._square_pad = square_pad
567
+
568
+ @property
569
+ def padding_constraints(self):
570
+ return {
571
+ "size_divisiblity": self._size_divisibility,
572
+ "square_size": self._square_pad,
573
+ }
574
+
575
+ def forward(self, x):
576
+ """
577
+ Args:
578
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
579
+
580
+ Returns:
581
+ dict[str->Tensor]:
582
+ mapping from feature map name to pyramid feature map tensor
583
+ in high to low resolution order. Returned feature names follow the FPN
584
+ convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
585
+ ["p2", "p3", ..., "p6"].
586
+ """
587
+ bottom_up_features = self.net(x)
588
+ features = bottom_up_features[self.in_feature]
589
+ results = []
590
+
591
+ for stage in self.stages:
592
+ results.append(stage(features))
593
+
594
+ if self.top_block is not None:
595
+ if self.top_block.in_feature in bottom_up_features:
596
+ top_block_in_feature = bottom_up_features[self.top_block.in_feature]
597
+ else:
598
+ top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
599
+ results.extend(self.top_block(top_block_in_feature))
600
+ assert len(self._out_features) == len(results)
601
+ return {f: res for f, res in zip(self._out_features, results)}
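A minimal sketch (not part of the committed file) of how the scale factors translate into the returned pyramid levels for this setup; it assumes the plain ViT emits a single stride-16 feature map, as configured by D2_EVA01 below.

import math

vit_stride = 16                                                # patch_size of the plain ViT feature map
scale_factors = (2.0, 1.0, 0.5)                                # the values passed in below
strides = [int(vit_stride / s) for s in scale_factors]         # [8, 16, 32]
names = ["p{}".format(int(math.log2(s))) for s in strides]     # ['p3', 'p4', 'p5']
# LastLevelMaxPool then appends "p6" at stride 64.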
602
+
603
+
604
+
605
+ @BACKBONE_REGISTRY.register()
606
+ class D2_EVA01(SimpleFeaturePyramid):
607
+ def __init__(self, cfg, input_shape):
608
+
609
+ super().__init__(
610
+ net = EVAViT(
611
+ img_size= cfg.MODEL.EVA01.IMAGE_SIZE,
612
+ patch_size=cfg.MODEL.EVA01.PATCH_SIZE,
613
+ window_size= cfg.MODEL.EVA01.WINDOW_SIZE,
614
+ embed_dim= cfg.MODEL.EVA01.DMBED_DIM,
615
+ depth= cfg.MODEL.EVA01.DEPTH,
616
+ num_heads= cfg.MODEL.EVA01.NUM_HEADS ,
617
+ drop_path_rate= cfg.MODEL.EVA01.DROP_PATH_RATE,
618
+ mlp_ratio= cfg.MODEL.EVA01.MLP_RATIO,
619
+ qkv_bias=True,
620
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
621
+ window_block_indexes= cfg.MODEL.EVA01.WINDOW_BLOCK_INDEXES,
622
+ residual_block_indexes=[],
623
+ use_act_checkpoint = True,
624
+ use_rel_pos = True,
625
+ out_feature="last_feat",
626
+ beit_like_qkv_bias=cfg.MODEL.EVA01.BEIT_LIKE_QKV_BIAS ,
627
+ beit_like_gamma= cfg.MODEL.EVA01.BEIT_LIKE_GAMMA,
628
+ freeze_patch_embed= cfg.MODEL.EVA01.FREEZE_PATH_EMBED,
629
+ ),
630
+ in_feature = "last_feat",
631
+ out_channels=256,
632
+ scale_factors=(2.0, 1.0, 0.5), # (4.0, 2.0, 1.0, 0.5) in ViTDet
633
+ top_block=LastLevelMaxPool(),
634
+ norm="LN",
635
+ square_pad=cfg.MODEL.EVA01.IMAGE_SIZE,
636
+
637
+ )
638
+ pretrained_weight = cfg.MODEL.EVA01.PRETRAINED_WEIGHT
639
+ if pretrained_weight:
640
+ checkpoint = torch.load(pretrained_weight, map_location='cpu')
641
+ print(f'\nload pretrain weight from {pretrained_weight} \n')
642
+ self.load_state_dict(checkpoint['model'], strict=False)
643
+
644
+ def output_shape(self):
645
+ return {
646
+ name: ShapeSpec(
647
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
648
+ )
649
+ for name in self._out_features
650
+ }
651
+
652
+ @property
653
+ def size_divisibility(self):
654
+ return 32
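A hedged usage sketch (not part of the committed file): registering the class in BACKBONE_REGISTRY means it can be built through detectron2's config-driven factory. Here `cfg` is assumed to already contain the MODEL.EVA01.* keys added by this project's config code, and `images` is a placeholder (N, 3, H, W) tensor.

from detectron2.modeling import build_backbone

cfg.MODEL.BACKBONE.NAME = "D2_EVA01"
backbone = build_backbone(cfg)        # instantiates the SimpleFeaturePyramid above
features = backbone(images)           # dict with keys "p3", "p4", "p5", "p6"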
655
+
656
+
657
+
658
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
659
+ """
660
+ Calculate lr decay rate for different ViT blocks.
661
+ Args:
662
+ name (string): parameter name.
663
+ lr_decay_rate (float): base lr decay rate.
664
+ num_layers (int): number of ViT blocks.
665
+
666
+ Returns:
667
+ lr decay rate for the given parameter.
668
+ """
669
+ layer_id = num_layers + 1
670
+ if 'backbone' in name: #name.startswith("backbone"):
671
+ if ".pos_embed" in name or ".patch_embed" in name:
672
+ layer_id = 0
673
+ elif ".blocks." in name and ".residual." not in name:
674
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
675
+
676
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
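A brief usage sketch (not part of the committed file): the decay rate above is typically multiplied into per-parameter learning rates when building optimizer groups. The parameter names below are illustrative placeholders, and num_layers must match the backbone depth (e.g. cfg.MODEL.EVA01.DEPTH).

base_lr = 1e-4
depth = 12   # placeholder; use the actual ViT depth
for name in [
    "backbone.net.patch_embed.proj.weight",    # layer_id 0  -> strongest decay
    "backbone.net.blocks.0.attn.proj.weight",  # layer_id 1
    "backbone.net.blocks.11.mlp.fc1.weight",   # layer_id 12
    "transformer_decoder.query_feat.weight",   # not in backbone -> scale 1.0
]:
    scale = get_vit_lr_decay_rate(name, lr_decay_rate=0.8, num_layers=depth)
    print(name, base_lr * scale)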
GLEE/glee/backbone/eva02-dino.py ADDED
@@ -0,0 +1,598 @@
1
+ import logging
2
+ import math
3
+ from functools import partial
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from detectron2.layers import CNNBlockBase, Conv2d, get_norm
11
+ from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
12
+
13
+ from detectron2.modeling.backbone import Backbone
14
+ from .eva_02_utils import (
15
+ PatchEmbed,
16
+ add_decomposed_rel_pos,
17
+ get_abs_pos,
18
+ window_partition,
19
+ window_unpartition,
20
+ VisionRotaryEmbeddingFast,
21
+ )
22
+
23
+ try:
+ import xformers.ops as xops
+ HAS_XFORMER = True
+ except ImportError:
+ HAS_XFORMER = False
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+
35
+ __all__ = ["EVA02_ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
36
+
37
+
38
+
39
+ class SwiGLU(nn.Module):
40
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
41
+ norm_layer=nn.LayerNorm, subln=False
42
+ ):
43
+ super().__init__()
44
+ out_features = out_features or in_features
45
+ hidden_features = hidden_features or in_features
46
+
47
+ self.w1 = nn.Linear(in_features, hidden_features)
48
+ self.w2 = nn.Linear(in_features, hidden_features)
49
+
50
+ self.act = act_layer()
51
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
52
+ self.w3 = nn.Linear(hidden_features, out_features)
53
+
54
+ self.drop = nn.Dropout(drop)
55
+
56
+ def forward(self, x):
57
+ x1 = self.w1(x)
58
+ x2 = self.w2(x)
59
+ hidden = self.act(x1) * x2
60
+ x = self.ffn_ln(hidden)
61
+ x = self.w3(x)
62
+ x = self.drop(x)
63
+ return x
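A small shape-check sketch (not part of the committed file): SwiGLU is used below as a drop-in replacement for the usual GELU MLP, and the 4 * 2 / 3 mlp_ratio keeps the parameter count roughly comparable because SwiGLU carries three linear projections instead of two.

import torch

dim = 768
ffn = SwiGLU(in_features=dim, hidden_features=int(dim * 4 * 2 / 3), subln=True)
tokens = torch.randn(2, 196, dim)     # (batch, tokens, channels)
out = ffn(tokens)                     # shape is preserved: (2, 196, 768)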
64
+
65
+
66
+ class Attention(nn.Module):
67
+ def __init__(
68
+ self,
69
+ dim,
70
+ num_heads=8,
71
+ qkv_bias=True,
72
+ qk_scale=None,
73
+ attn_head_dim=None,
74
+ rope=None,
75
+ xattn=True,
76
+ ):
77
+ super().__init__()
78
+ self.num_heads = num_heads
79
+ head_dim = dim // num_heads
80
+ if attn_head_dim is not None:
81
+ head_dim = attn_head_dim
82
+ all_head_dim = head_dim * self.num_heads
83
+ self.scale = qk_scale or head_dim ** -0.5
84
+
85
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
86
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
87
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
88
+
89
+ if qkv_bias:
90
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
91
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
92
+ else:
93
+ self.q_bias = None
94
+ self.v_bias = None
95
+
96
+ self.rope = rope
97
+ self.xattn = xattn
98
+ self.proj = nn.Linear(all_head_dim, dim)
99
+
100
+ if not HAS_XFORMER:
101
+ self.xattn = False
102
+
103
+ def forward(self, x):
104
+ B, H, W, C = x.shape
105
+ x = x.view(B, -1, C)
106
+ N = H * W
107
+
108
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
109
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
110
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
111
+
112
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
113
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
114
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
115
+
116
+ ## rope
117
+ q = self.rope(q).type_as(v)
118
+ k = self.rope(k).type_as(v)
119
+
120
+ if self.xattn:
121
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
122
+ k = k.permute(0, 2, 1, 3)
123
+ v = v.permute(0, 2, 1, 3)
124
+
125
+ x = xops.memory_efficient_attention(q, k, v)
126
+ x = x.reshape(B, N, -1)
127
+ else:
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+ attn = attn.softmax(dim=-1).type_as(x)
131
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
132
+
133
+ x = self.proj(x)
134
+ x = x.view(B, H, W, C)
135
+
136
+ return x
137
+
138
+
139
+ class ResBottleneckBlock(CNNBlockBase):
140
+ """
141
+ The standard bottleneck residual block without the last activation layer.
142
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
143
+ """
144
+
145
+ def __init__(
146
+ self,
147
+ in_channels,
148
+ out_channels,
149
+ bottleneck_channels,
150
+ norm="LN",
151
+ act_layer=nn.GELU,
152
+ ):
153
+ """
154
+ Args:
155
+ in_channels (int): Number of input channels.
156
+ out_channels (int): Number of output channels.
157
+ bottleneck_channels (int): number of output channels for the 3x3
158
+ "bottleneck" conv layers.
159
+ norm (str or callable): normalization for all conv layers.
160
+ See :func:`layers.get_norm` for supported format.
161
+ act_layer (callable): activation for all conv layers.
162
+ """
163
+ super().__init__(in_channels, out_channels, 1)
164
+
165
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
166
+ self.norm1 = get_norm(norm, bottleneck_channels)
167
+ self.act1 = act_layer()
168
+
169
+ self.conv2 = Conv2d(
170
+ bottleneck_channels,
171
+ bottleneck_channels,
172
+ 3,
173
+ padding=1,
174
+ bias=False,
175
+ )
176
+ self.norm2 = get_norm(norm, bottleneck_channels)
177
+ self.act2 = act_layer()
178
+
179
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
180
+ self.norm3 = get_norm(norm, out_channels)
181
+
182
+ for layer in [self.conv1, self.conv2, self.conv3]:
183
+ weight_init.c2_msra_fill(layer)
184
+ for layer in [self.norm1, self.norm2]:
185
+ layer.weight.data.fill_(1.0)
186
+ layer.bias.data.zero_()
187
+ # zero init last norm layer.
188
+ self.norm3.weight.data.zero_()
189
+ self.norm3.bias.data.zero_()
190
+
191
+ def forward(self, x):
192
+ out = x
193
+ for layer in self.children():
194
+ out = layer(out)
195
+
196
+ out = x + out
197
+ return out
198
+
199
+
200
+ class Block(nn.Module):
201
+ """Transformer blocks with support of window attention and residual propagation blocks"""
202
+
203
+ def __init__(
204
+ self,
205
+ dim,
206
+ num_heads,
207
+ mlp_ratio=4*2/3,
208
+ qkv_bias=True,
209
+ drop_path=0.0,
210
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
211
+ window_size=0,
212
+ use_residual_block=False,
213
+ rope=None,
214
+ xattn=True,
215
+ ):
216
+ """
217
+ Args:
218
+ dim (int): Number of input channels.
219
+ num_heads (int): Number of attention heads in each ViT block.
220
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
221
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
222
+ drop_path (float): Stochastic depth rate.
223
+ norm_layer (nn.Module): Normalization layer.
224
+ window_size (int): Window size for window attention blocks. If it equals 0,
+ window attention is not used.
+ use_residual_block (bool): If True, use a residual block after the MLP block.
+ rope (nn.Module or None): rotary position embedding module applied to q and k.
+ xattn (bool): If True, use xformers memory-efficient attention when it is available.
232
+ """
233
+ super().__init__()
234
+ self.norm1 = norm_layer(dim)
235
+ self.attn = Attention(
236
+ dim,
237
+ num_heads=num_heads,
238
+ qkv_bias=qkv_bias,
239
+ rope=rope,
240
+ xattn=xattn,
241
+ )
242
+
243
+ from timm.models.layers import DropPath
244
+
245
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
246
+ self.norm2 = norm_layer(dim)
247
+ self.mlp = SwiGLU(
248
+ in_features=dim,
249
+ hidden_features=int(dim * mlp_ratio),
250
+ subln=True,
251
+ norm_layer=norm_layer,
252
+ )
253
+
254
+ self.window_size = window_size
255
+
256
+ self.use_residual_block = use_residual_block
257
+ if use_residual_block:
258
+ # Use a residual block with bottleneck channel as dim // 2
259
+ self.residual = ResBottleneckBlock(
260
+ in_channels=dim,
261
+ out_channels=dim,
262
+ bottleneck_channels=dim // 2,
263
+ norm="LN",
264
+ )
265
+
266
+ def forward(self, x):
267
+ shortcut = x
268
+ x = self.norm1(x)
269
+
270
+ # Window partition
271
+ if self.window_size > 0:
272
+ H, W = x.shape[1], x.shape[2]
273
+ x, pad_hw = window_partition(x, self.window_size)
274
+
275
+ x = self.attn(x)
276
+
277
+ # Reverse window partition
278
+ if self.window_size > 0:
279
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
280
+
281
+ x = shortcut + self.drop_path(x)
282
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
283
+
284
+ if self.use_residual_block:
285
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
286
+
287
+ return x
288
+
289
+
290
+ class EVA02_ViT(Backbone):
291
+ """
292
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
293
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
294
+ https://arxiv.org/abs/2203.16527
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ img_size=1024,
300
+ patch_size=16,
301
+ in_chans=3,
302
+ embed_dim=768,
303
+ depth=12,
304
+ num_heads=12,
305
+ mlp_ratio=4*2/3,
306
+ qkv_bias=True,
307
+ drop_path_rate=0.0,
308
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
309
+ act_layer=nn.GELU,
310
+ use_abs_pos=True,
311
+ use_rel_pos=False,
312
+ rope=True,
313
+ pt_hw_seq_len=16,
314
+ intp_freq=True,
315
+ window_size=0,
316
+ window_block_indexes=(),
317
+ residual_block_indexes=(),
318
+ use_act_checkpoint=False,
319
+ pretrain_img_size=224,
320
+ pretrain_use_cls_token=True,
321
+ out_feature="last_feat",
322
+ xattn=True,
323
+ ):
324
+ """
325
+ Args:
326
+ img_size (int): Input image size.
327
+ patch_size (int): Patch size.
328
+ in_chans (int): Number of input image channels.
329
+ embed_dim (int): Patch embedding dimension.
330
+ depth (int): Depth of ViT.
331
+ num_heads (int): Number of attention heads in each ViT block.
332
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
333
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
334
+ drop_path_rate (float): Stochastic depth rate.
335
+ norm_layer (nn.Module): Normalization layer.
336
+ act_layer (nn.Module): Activation layer.
337
+ use_abs_pos (bool): If True, use absolute positional embeddings.
338
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
339
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
340
+ window_size (int): Window size for window attention blocks.
341
+ window_block_indexes (list): Indexes for blocks using window attention.
342
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
343
+ use_act_checkpoint (bool): If True, use activation checkpointing.
344
+ pretrain_img_size (int): input image size for pretraining models.
345
+ pretrain_use_cls_token (bool): If True, pretraining models use class token.
346
+ out_feature (str): name of the feature from the last block.
347
+ """
348
+ super().__init__()
349
+ self.pretrain_use_cls_token = pretrain_use_cls_token
350
+
351
+ self.patch_embed = PatchEmbed(
352
+ kernel_size=(patch_size, patch_size),
353
+ stride=(patch_size, patch_size),
354
+ in_chans=in_chans,
355
+ embed_dim=embed_dim,
356
+ )
357
+
358
+ if use_abs_pos:
359
+ # Initialize absolute positional embedding with pretrain image size.
360
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
361
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
362
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
363
+ else:
364
+ self.pos_embed = None
365
+
366
+
367
+ half_head_dim = embed_dim // num_heads // 2
368
+ hw_seq_len = img_size // patch_size
369
+
370
+ self.rope_win = VisionRotaryEmbeddingFast(
371
+ dim=half_head_dim,
372
+ pt_seq_len=pt_hw_seq_len,
373
+ ft_seq_len=window_size if intp_freq else None,
374
+ )
375
+ self.rope_glb = VisionRotaryEmbeddingFast(
376
+ dim=half_head_dim,
377
+ pt_seq_len=pt_hw_seq_len,
378
+ ft_seq_len=hw_seq_len if intp_freq else None,
379
+ )
380
+
381
+ # stochastic depth decay rule
382
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
383
+
384
+ self.blocks = nn.ModuleList()
385
+ for i in range(depth):
386
+ block = Block(
387
+ dim=embed_dim,
388
+ num_heads=num_heads,
389
+ mlp_ratio=mlp_ratio,
390
+ qkv_bias=qkv_bias,
391
+ drop_path=dpr[i],
392
+ norm_layer=norm_layer,
393
+ window_size=window_size if i in window_block_indexes else 0,
394
+ use_residual_block=i in residual_block_indexes,
395
+ rope=self.rope_win if i in window_block_indexes else self.rope_glb,
396
+ xattn=xattn
397
+ )
398
+ if use_act_checkpoint:
399
+ # TODO: use torch.utils.checkpoint
400
+ from fairscale.nn.checkpoint import checkpoint_wrapper
401
+
402
+ block = checkpoint_wrapper(block)
403
+ self.blocks.append(block)
404
+
405
+ self._out_feature_channels = {out_feature: embed_dim}
406
+ self._out_feature_strides = {out_feature: patch_size}
407
+ self._out_features = [out_feature]
408
+
409
+ if self.pos_embed is not None:
410
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
411
+
412
+ self.apply(self._init_weights)
413
+
414
+ def _init_weights(self, m):
415
+ if isinstance(m, nn.Linear):
416
+ nn.init.trunc_normal_(m.weight, std=0.02)
417
+ if isinstance(m, nn.Linear) and m.bias is not None:
418
+ nn.init.constant_(m.bias, 0)
419
+ elif isinstance(m, nn.LayerNorm):
420
+ nn.init.constant_(m.bias, 0)
421
+ nn.init.constant_(m.weight, 1.0)
422
+
423
+ def forward(self, x):
424
+ x = self.patch_embed(x)
425
+ if self.pos_embed is not None:
426
+ x = x + get_abs_pos(
427
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
428
+ )
429
+
430
+ for blk in self.blocks:
431
+ x = blk(x)
432
+
433
+ outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
434
+ return outputs
435
+
436
+
437
+ class SimpleFeaturePyramid(Backbone):
438
+ """
439
+ This module implements SimpleFeaturePyramid in :paper:`vitdet`.
440
+ It creates pyramid features built on top of the input feature map.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ net,
446
+ in_feature,
447
+ out_channels,
448
+ scale_factors,
449
+ top_block=None,
450
+ norm="LN",
451
+ square_pad=0,
452
+ ):
453
+ """
454
+ Args:
455
+ net (Backbone): module representing the subnetwork backbone.
456
+ Must be a subclass of :class:`Backbone`.
457
+ in_feature (str): names of the input feature maps coming
458
+ from the net.
459
+ out_channels (int): number of channels in the output feature maps.
460
+ scale_factors (list[float]): list of scaling factors to upsample or downsample
461
+ the input features for creating pyramid features.
462
+ top_block (nn.Module or None): if provided, an extra operation will
463
+ be performed on the output of the last (smallest resolution)
464
+ pyramid output, and the result will extend the result list. The top_block
465
+ further downsamples the feature map. It must have an attribute
466
+ "num_levels", meaning the number of extra pyramid levels added by
467
+ this block, and "in_feature", which is a string representing
468
+ its input feature (e.g., p5).
469
+ norm (str): the normalization to use.
470
+ square_pad (int): If > 0, require input images to be padded to specific square size.
471
+ """
472
+ super(SimpleFeaturePyramid, self).__init__()
473
+ assert isinstance(net, Backbone)
474
+
475
+ self.scale_factors = scale_factors
476
+
477
+ input_shapes = net.output_shape()
478
+ strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
479
+ _assert_strides_are_log2_contiguous(strides)
480
+
481
+ dim = input_shapes[in_feature].channels
482
+ self.stages = []
483
+ use_bias = norm == ""
484
+ for idx, scale in enumerate(scale_factors):
485
+ out_dim = dim
486
+ if scale == 4.0:
487
+ layers = [
488
+ nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
489
+ get_norm(norm, dim // 2),
490
+ nn.GELU(),
491
+ nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
492
+ ]
493
+ out_dim = dim // 4
494
+ elif scale == 2.0:
495
+ layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
496
+ out_dim = dim // 2
497
+ elif scale == 1.0:
498
+ layers = []
499
+ elif scale == 0.5:
500
+ layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
501
+ else:
502
+ raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
503
+
504
+ layers.extend(
505
+ [
506
+ Conv2d(
507
+ out_dim,
508
+ out_channels,
509
+ kernel_size=1,
510
+ bias=use_bias,
511
+ norm=get_norm(norm, out_channels),
512
+ ),
513
+ Conv2d(
514
+ out_channels,
515
+ out_channels,
516
+ kernel_size=3,
517
+ padding=1,
518
+ bias=use_bias,
519
+ norm=get_norm(norm, out_channels),
520
+ ),
521
+ ]
522
+ )
523
+ layers = nn.Sequential(*layers)
524
+
525
+ stage = int(math.log2(strides[idx]))
526
+ self.add_module(f"simfp_{stage}", layers)
527
+ self.stages.append(layers)
528
+
529
+ self.net = net
530
+ self.in_feature = in_feature
531
+ self.top_block = top_block
532
+ # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
533
+ self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
534
+ # top block output feature maps.
535
+ if self.top_block is not None:
536
+ for s in range(stage, stage + self.top_block.num_levels):
537
+ self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
538
+
539
+ self._out_features = list(self._out_feature_strides.keys())
540
+ self._out_feature_channels = {k: out_channels for k in self._out_features}
541
+ self._size_divisibility = strides[-1]
542
+ self._square_pad = square_pad
543
+
544
+ @property
545
+ def padding_constraints(self):
546
+ return {
547
+ "size_divisiblity": self._size_divisibility,
548
+ "square_size": self._square_pad,
549
+ }
550
+
551
+ def forward(self, x):
552
+ """
553
+ Args:
554
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
555
+
556
+ Returns:
557
+ dict[str->Tensor]:
558
+ mapping from feature map name to pyramid feature map tensor
559
+ in high to low resolution order. Returned feature names follow the FPN
560
+ convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
561
+ ["p2", "p3", ..., "p6"].
562
+ """
563
+ bottom_up_features = self.net(x)
564
+ features = bottom_up_features[self.in_feature]
565
+ results = []
566
+
567
+ for stage in self.stages:
568
+ results.append(stage(features))
569
+
570
+ if self.top_block is not None:
571
+ if self.top_block.in_feature in bottom_up_features:
572
+ top_block_in_feature = bottom_up_features[self.top_block.in_feature]
573
+ else:
574
+ top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
575
+ results.extend(self.top_block(top_block_in_feature))
576
+ assert len(self._out_features) == len(results)
577
+ return {f: res for f, res in zip(self._out_features, results)}
578
+
579
+
580
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
581
+ """
582
+ Calculate lr decay rate for different ViT blocks.
583
+ Args:
584
+ name (string): parameter name.
585
+ lr_decay_rate (float): base lr decay rate.
586
+ num_layers (int): number of ViT blocks.
587
+
588
+ Returns:
589
+ lr decay rate for the given parameter.
590
+ """
591
+ layer_id = num_layers + 1
592
+ if name.startswith("backbone"):
593
+ if ".pos_embed" in name or ".patch_embed" in name:
594
+ layer_id = 0
595
+ elif ".blocks." in name and ".residual." not in name:
596
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
597
+
598
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
GLEE/glee/backbone/eva02.py ADDED
@@ -0,0 +1,647 @@
1
+ # --------------------------------------------------------
2
+ # EVA02
3
+ # --------------------------------------------------------
4
+ import logging
5
+ import math
6
+ from functools import partial
7
+
8
+ import fvcore.nn.weight_init as weight_init
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from detectron2.layers import CNNBlockBase, Conv2d, get_norm
14
+ from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
15
+
16
+ from detectron2.modeling.backbone import Backbone
17
+ from timm.models.layers import DropPath, Mlp, trunc_normal_
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+
21
+ from .eva_02_utils import (
22
+ PatchEmbed,
23
+ add_decomposed_rel_pos,
24
+ get_abs_pos,
25
+ window_partition,
26
+ window_unpartition,
27
+ VisionRotaryEmbeddingFast,
28
+ )
29
+ from detectron2.modeling.backbone.fpn import LastLevelMaxPool
30
+
31
+
32
+ try:
+ import xformers.ops as xops
+ HAS_XFORMER = True
+ except ImportError:
+ HAS_XFORMER = False
+
+ try:
+ from apex.normalization import FusedLayerNorm
+ except ImportError:
+ pass
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+
48
+ __all__ = ["EVA02_ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
49
+
50
+
51
+
52
+ class SwiGLU(nn.Module):
53
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
54
+ norm_layer=nn.LayerNorm, subln=False
55
+ ):
56
+ super().__init__()
57
+ out_features = out_features or in_features
58
+ hidden_features = hidden_features or in_features
59
+
60
+ self.w1 = nn.Linear(in_features, hidden_features)
61
+ self.w2 = nn.Linear(in_features, hidden_features)
62
+
63
+ self.act = act_layer()
64
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
65
+ self.w3 = nn.Linear(hidden_features, out_features)
66
+
67
+ self.drop = nn.Dropout(drop)
68
+
69
+ def forward(self, x):
70
+ x1 = self.w1(x)
71
+ x2 = self.w2(x)
72
+ hidden = self.act(x1) * x2
73
+ x = self.ffn_ln(hidden)
74
+ x = self.w3(x)
75
+ x = self.drop(x)
76
+ return x
77
+
78
+
79
+ class Attention(nn.Module):
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ num_heads=8,
84
+ qkv_bias=True,
85
+ qk_scale=None,
86
+ attn_head_dim=None,
87
+ rope=None,
88
+ xattn=True,
89
+ ):
90
+ super().__init__()
91
+ self.num_heads = num_heads
92
+ head_dim = dim // num_heads
93
+ if attn_head_dim is not None:
94
+ head_dim = attn_head_dim
95
+ all_head_dim = head_dim * self.num_heads
96
+ self.scale = qk_scale or head_dim ** -0.5
97
+
98
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
99
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
100
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
101
+
102
+ if qkv_bias:
103
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
104
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
105
+ else:
106
+ self.q_bias = None
107
+ self.v_bias = None
108
+
109
+ self.rope = rope
110
+ self.xattn = xattn
111
+ self.proj = nn.Linear(all_head_dim, dim)
112
+ if not HAS_XFORMER:
113
+ self.xattn = False
114
+
115
+ def forward(self, x):
116
+ B, H, W, C = x.shape
117
+ x = x.view(B, -1, C)
118
+ N = H * W
119
+
120
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
121
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
122
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
123
+
124
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
125
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
126
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
127
+
128
+ ## rope
129
+ q = self.rope(q).type_as(v)
130
+ k = self.rope(k).type_as(v)
131
+
132
+ if self.xattn:
133
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
134
+ k = k.permute(0, 2, 1, 3)
135
+ v = v.permute(0, 2, 1, 3)
136
+
137
+ x = xops.memory_efficient_attention(q, k, v)
138
+ x = x.reshape(B, N, -1)
139
+ else:
140
+ q = q * self.scale
141
+ attn = (q @ k.transpose(-2, -1))
142
+ attn = attn.softmax(dim=-1).type_as(x)
143
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
144
+
145
+ x = self.proj(x)
146
+ x = x.view(B, H, W, C)
147
+
148
+ return x
149
+
150
+
151
+ class ResBottleneckBlock(CNNBlockBase):
152
+ """
153
+ The standard bottleneck residual block without the last activation layer.
154
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ in_channels,
160
+ out_channels,
161
+ bottleneck_channels,
162
+ norm="LN",
163
+ act_layer=nn.GELU,
164
+ ):
165
+ """
166
+ Args:
167
+ in_channels (int): Number of input channels.
168
+ out_channels (int): Number of output channels.
169
+ bottleneck_channels (int): number of output channels for the 3x3
170
+ "bottleneck" conv layers.
171
+ norm (str or callable): normalization for all conv layers.
172
+ See :func:`layers.get_norm` for supported format.
173
+ act_layer (callable): activation for all conv layers.
174
+ """
175
+ super().__init__(in_channels, out_channels, 1)
176
+
177
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
178
+ self.norm1 = get_norm(norm, bottleneck_channels)
179
+ self.act1 = act_layer()
180
+
181
+ self.conv2 = Conv2d(
182
+ bottleneck_channels,
183
+ bottleneck_channels,
184
+ 3,
185
+ padding=1,
186
+ bias=False,
187
+ )
188
+ self.norm2 = get_norm(norm, bottleneck_channels)
189
+ self.act2 = act_layer()
190
+
191
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
192
+ self.norm3 = get_norm(norm, out_channels)
193
+
194
+ for layer in [self.conv1, self.conv2, self.conv3]:
195
+ weight_init.c2_msra_fill(layer)
196
+ for layer in [self.norm1, self.norm2]:
197
+ layer.weight.data.fill_(1.0)
198
+ layer.bias.data.zero_()
199
+ # zero init last norm layer.
200
+ self.norm3.weight.data.zero_()
201
+ self.norm3.bias.data.zero_()
202
+
203
+ def forward(self, x):
204
+ out = x
205
+ for layer in self.children():
206
+ out = layer(out)
207
+
208
+ out = x + out
209
+ return out
210
+
211
+
212
+ class Block(nn.Module):
213
+ """Transformer blocks with support of window attention and residual propagation blocks"""
214
+
215
+ def __init__(
216
+ self,
217
+ dim,
218
+ num_heads,
219
+ mlp_ratio=4*2/3,
220
+ qkv_bias=True,
221
+ drop_path=0.0,
222
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
223
+ window_size=0,
224
+ use_residual_block=False,
225
+ rope=None,
226
+ xattn=True,
227
+ ):
228
+ """
229
+ Args:
230
+ dim (int): Number of input channels.
231
+ num_heads (int): Number of attention heads in each ViT block.
232
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
233
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
234
+ drop_path (float): Stochastic depth rate.
235
+ norm_layer (nn.Module): Normalization layer.
236
+ window_size (int): Window size for window attention blocks. If it equals 0,
+ window attention is not used.
+ use_residual_block (bool): If True, use a residual block after the MLP block.
+ rope (nn.Module or None): rotary position embedding module applied to q and k.
+ xattn (bool): If True, use xformers memory-efficient attention when it is available.
244
+ """
245
+ super().__init__()
246
+ self.norm1 = norm_layer(dim)
247
+ self.attn = Attention(
248
+ dim,
249
+ num_heads=num_heads,
250
+ qkv_bias=qkv_bias,
251
+ rope=rope,
252
+ xattn=xattn,
253
+ )
254
+
255
+ from timm.models.layers import DropPath
256
+
257
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
258
+ self.norm2 = norm_layer(dim)
259
+ self.mlp = SwiGLU(
260
+ in_features=dim,
261
+ hidden_features=int(dim * mlp_ratio),
262
+ subln=True,
263
+ norm_layer=norm_layer,
264
+ )
265
+
266
+ self.window_size = window_size
267
+
268
+ self.use_residual_block = use_residual_block
269
+ if use_residual_block:
270
+ # Use a residual block with bottleneck channel as dim // 2
271
+ self.residual = ResBottleneckBlock(
272
+ in_channels=dim,
273
+ out_channels=dim,
274
+ bottleneck_channels=dim // 2,
275
+ norm="LN",
276
+ )
277
+
278
+ def forward(self, x):
279
+ shortcut = x
280
+ x = self.norm1(x)
281
+
282
+ # Window partition
283
+ if self.window_size > 0:
284
+ H, W = x.shape[1], x.shape[2]
285
+ x, pad_hw = window_partition(x, self.window_size)
286
+
287
+ x = self.attn(x)
288
+
289
+ # Reverse window partition
290
+ if self.window_size > 0:
291
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
292
+
293
+ x = shortcut + self.drop_path(x)
294
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
295
+
296
+ if self.use_residual_block:
297
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
298
+
299
+ return x
300
+
301
+
302
+ class EVA02_ViT(Backbone):
303
+ """
304
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
305
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
306
+ https://arxiv.org/abs/2203.16527
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ img_size=1024,
312
+ patch_size=16,
313
+ in_chans=3,
314
+ embed_dim=768,
315
+ depth=12,
316
+ num_heads=12,
317
+ mlp_ratio=4*2/3,
318
+ qkv_bias=True,
319
+ drop_path_rate=0.0,
320
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
321
+ act_layer=nn.GELU,
322
+ use_abs_pos=True,
323
+ use_rel_pos=False,
324
+ rope=True,
325
+ pt_hw_seq_len=16,
326
+ intp_freq=True,
327
+ window_size=0,
328
+ window_block_indexes=(),
329
+ residual_block_indexes=(),
330
+ use_act_checkpoint=False,
331
+ pretrain_img_size=224,
332
+ pretrain_use_cls_token=True,
333
+ out_feature="last_feat",
334
+ xattn=True,
335
+ ):
336
+ """
337
+ Args:
338
+ img_size (int): Input image size.
339
+ patch_size (int): Patch size.
340
+ in_chans (int): Number of input image channels.
341
+ embed_dim (int): Patch embedding dimension.
342
+ depth (int): Depth of ViT.
343
+ num_heads (int): Number of attention heads in each ViT block.
344
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
345
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
346
+ drop_path_rate (float): Stochastic depth rate.
347
+ norm_layer (nn.Module): Normalization layer.
348
+ act_layer (nn.Module): Activation layer.
349
+ use_abs_pos (bool): If True, use absolute positional embeddings.
350
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
351
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
352
+ window_size (int): Window size for window attention blocks.
353
+ window_block_indexes (list): Indexes for blocks using window attention.
354
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
355
+ use_act_checkpoint (bool): If True, use activation checkpointing.
356
+ pretrain_img_size (int): input image size for pretraining models.
357
+ pretrain_use_cls_token (bool): If True, pretraining models use class token.
358
+ out_feature (str): name of the feature from the last block.
359
+ """
360
+ super().__init__()
361
+ self.pretrain_use_cls_token = pretrain_use_cls_token
362
+
363
+ self.patch_embed = PatchEmbed(
364
+ kernel_size=(patch_size, patch_size),
365
+ stride=(patch_size, patch_size),
366
+ in_chans=in_chans,
367
+ embed_dim=embed_dim,
368
+ )
369
+
370
+ if use_abs_pos:
371
+ # Initialize absolute positional embedding with pretrain image size.
372
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
373
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
374
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
375
+ else:
376
+ self.pos_embed = None
377
+
378
+
379
+ half_head_dim = embed_dim // num_heads // 2
380
+ hw_seq_len = img_size // patch_size
381
+
382
+ self.rope_win = VisionRotaryEmbeddingFast(
383
+ dim=half_head_dim,
384
+ pt_seq_len=pt_hw_seq_len,
385
+ ft_seq_len=window_size if intp_freq else None,
386
+ )
387
+ self.rope_glb = VisionRotaryEmbeddingFast(
388
+ dim=half_head_dim,
389
+ pt_seq_len=pt_hw_seq_len,
390
+ ft_seq_len=hw_seq_len if intp_freq else None,
391
+ )
392
+
393
+ # stochastic depth decay rule
394
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
395
+
396
+ self.blocks = nn.ModuleList()
397
+ for i in range(depth):
398
+ block = Block(
399
+ dim=embed_dim,
400
+ num_heads=num_heads,
401
+ mlp_ratio=mlp_ratio,
402
+ qkv_bias=qkv_bias,
403
+ drop_path=dpr[i],
404
+ norm_layer=norm_layer,
405
+ window_size=window_size if i in window_block_indexes else 0,
406
+ use_residual_block=i in residual_block_indexes,
407
+ rope=self.rope_win if i in window_block_indexes else self.rope_glb,
408
+ xattn=xattn
409
+ )
410
+ if use_act_checkpoint:
411
+ # TODO: use torch.utils.checkpoint
412
+ from fairscale.nn.checkpoint import checkpoint_wrapper
413
+
414
+ block = checkpoint_wrapper(block)
415
+ self.blocks.append(block)
416
+
417
+ self._out_feature_channels = {out_feature: embed_dim}
418
+ self._out_feature_strides = {out_feature: patch_size}
419
+ self._out_features = [out_feature]
420
+
421
+ if self.pos_embed is not None:
422
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
423
+
424
+ self.apply(self._init_weights)
425
+
426
+ def _init_weights(self, m):
427
+ if isinstance(m, nn.Linear):
428
+ nn.init.trunc_normal_(m.weight, std=0.02)
429
+ if isinstance(m, nn.Linear) and m.bias is not None:
430
+ nn.init.constant_(m.bias, 0)
431
+ elif isinstance(m, nn.LayerNorm):
432
+ nn.init.constant_(m.bias, 0)
433
+ nn.init.constant_(m.weight, 1.0)
434
+
435
+ def forward(self, x):
436
+ x = self.patch_embed(x)
437
+ if self.pos_embed is not None:
438
+ x = x + get_abs_pos(
439
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
440
+ )
441
+
442
+ for blk in self.blocks:
443
+ x = blk(x)
444
+
445
+ outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
446
+ return outputs
447
+
448
+
449
+ class SimpleFeaturePyramid(Backbone):
450
+ """
451
+ This module implements SimpleFeaturePyramid in :paper:`vitdet`.
452
+ It creates pyramid features built on top of the input feature map.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ net,
458
+ in_feature,
459
+ out_channels,
460
+ scale_factors,
461
+ top_block=None,
462
+ norm="LN",
463
+ square_pad=0,
464
+ ):
465
+ """
466
+ Args:
467
+ net (Backbone): module representing the subnetwork backbone.
468
+ Must be a subclass of :class:`Backbone`.
469
+ in_feature (str): names of the input feature maps coming
470
+ from the net.
471
+ out_channels (int): number of channels in the output feature maps.
472
+ scale_factors (list[float]): list of scaling factors to upsample or downsample
473
+ the input features for creating pyramid features.
474
+ top_block (nn.Module or None): if provided, an extra operation will
475
+ be performed on the output of the last (smallest resolution)
476
+ pyramid output, and the result will extend the result list. The top_block
477
+ further downsamples the feature map. It must have an attribute
478
+ "num_levels", meaning the number of extra pyramid levels added by
479
+ this block, and "in_feature", which is a string representing
480
+ its input feature (e.g., p5).
481
+ norm (str): the normalization to use.
482
+ square_pad (int): If > 0, require input images to be padded to specific square size.
483
+ """
484
+ super(SimpleFeaturePyramid, self).__init__()
485
+ assert isinstance(net, Backbone)
486
+
487
+ self.scale_factors = scale_factors
488
+
489
+ input_shapes = net.output_shape()
490
+ strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
491
+ _assert_strides_are_log2_contiguous(strides)
492
+
493
+ dim = input_shapes[in_feature].channels
494
+ self.stages = []
495
+ use_bias = norm == ""
496
+ for idx, scale in enumerate(scale_factors):
497
+ out_dim = dim
498
+ if scale == 4.0:
499
+ layers = [
500
+ nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
501
+ get_norm(norm, dim // 2),
502
+ nn.GELU(),
503
+ nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
504
+ ]
505
+ out_dim = dim // 4
506
+ elif scale == 2.0:
507
+ layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
508
+ out_dim = dim // 2
509
+ elif scale == 1.0:
510
+ layers = []
511
+ elif scale == 0.5:
512
+ layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
513
+ else:
514
+ raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
515
+
516
+ layers.extend(
517
+ [
518
+ Conv2d(
519
+ out_dim,
520
+ out_channels,
521
+ kernel_size=1,
522
+ bias=use_bias,
523
+ norm=get_norm(norm, out_channels),
524
+ ),
525
+ Conv2d(
526
+ out_channels,
527
+ out_channels,
528
+ kernel_size=3,
529
+ padding=1,
530
+ bias=use_bias,
531
+ norm=get_norm(norm, out_channels),
532
+ ),
533
+ ]
534
+ )
535
+ layers = nn.Sequential(*layers)
536
+
537
+ stage = int(math.log2(strides[idx]))
538
+ self.add_module(f"simfp_{stage}", layers)
539
+ self.stages.append(layers)
540
+
541
+ self.net = net
542
+ self.in_feature = in_feature
543
+ self.top_block = top_block
544
+ # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
545
+ self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
546
+ # top block output feature maps.
547
+ if self.top_block is not None:
548
+ for s in range(stage, stage + self.top_block.num_levels):
549
+ self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
550
+
551
+ self._out_features = list(self._out_feature_strides.keys())
552
+ self._out_feature_channels = {k: out_channels for k in self._out_features}
553
+ self._size_divisibility = strides[-1]
554
+ self._square_pad = square_pad
555
+
556
+ @property
557
+ def padding_constraints(self):
558
+ return {
559
+ "size_divisiblity": self._size_divisibility,
560
+ "square_size": self._square_pad,
561
+ }
562
+
563
+ def forward(self, x):
564
+ """
565
+ Args:
566
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
567
+
568
+ Returns:
569
+ dict[str->Tensor]:
570
+ mapping from feature map name to pyramid feature map tensor
571
+ in high to low resolution order. Returned feature names follow the FPN
572
+ convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
573
+ ["p2", "p3", ..., "p6"].
574
+ """
575
+ bottom_up_features = self.net(x)
576
+ features = bottom_up_features[self.in_feature]
577
+ results = []
578
+
579
+ for stage in self.stages:
580
+ results.append(stage(features))
581
+
582
+ if self.top_block is not None:
583
+ if self.top_block.in_feature in bottom_up_features:
584
+ top_block_in_feature = bottom_up_features[self.top_block.in_feature]
585
+ else:
586
+ top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
587
+ results.extend(self.top_block(top_block_in_feature))
588
+ assert len(self._out_features) == len(results)
589
+ return {f: res for f, res in zip(self._out_features, results)}
590
+
591
+
592
+
593
+ @BACKBONE_REGISTRY.register()
594
+ class D2_EVA02(SimpleFeaturePyramid):
595
+ def __init__(self, cfg, input_shape):
596
+
597
+ super().__init__(
598
+
599
+ net = EVA02_ViT(
600
+ img_size= cfg.MODEL.EVA02.IMAGE_SIZE,
601
+ patch_size=cfg.MODEL.EVA02.PATCH_SIZE,
602
+ window_size= cfg.MODEL.EVA02.WINDOW_SIZE,
603
+ embed_dim= cfg.MODEL.EVA02.DMBED_DIM,
604
+ depth= cfg.MODEL.EVA02.DEPTH,
605
+ num_heads= cfg.MODEL.EVA02.NUM_HEADS ,
606
+ drop_path_rate= cfg.MODEL.EVA02.DROP_PATH_RATE,
607
+ mlp_ratio= cfg.MODEL.EVA02.MLP_RATIO,
608
+ # qkv_bias=True,
609
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
610
+ window_block_indexes= cfg.MODEL.EVA02.WINDOW_BLOCK_INDEXES,
611
+ # residual_block_indexes=[],
612
+ # use_rel_pos=False,
613
+ use_act_checkpoint = cfg.MODEL.EVA02.CHECKPOINT,
614
+ out_feature="last_feat",
615
+ # intp_freq=True,
616
+ ),
617
+ in_feature = "last_feat",
618
+ out_channels=256,
619
+ scale_factors=(2.0, 1.0, 0.5), # (4.0, 2.0, 1.0, 0.5) in ViTDet
620
+ top_block=LastLevelMaxPool(),
621
+ norm="LN",
622
+ square_pad=cfg.MODEL.EVA02.IMAGE_SIZE,
623
+
624
+ )
625
+
626
+ pretrained_weight = cfg.MODEL.EVA02.PRETRAINED_WEIGHT
627
+ if pretrained_weight:
628
+ checkpoint = torch.load(pretrained_weight, map_location='cpu')
629
+ print(f'\nload pretrain weight from {pretrained_weight} \n')
630
+
631
+ self.load_state_dict(checkpoint['model'], strict=False)
632
+
633
+
634
+
635
+ def output_shape(self):
636
+ return {
637
+ name: ShapeSpec(
638
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
639
+ )
640
+ for name in self._out_features
641
+ }
642
+
643
+ @property
644
+ def size_divisibility(self):
645
+ return 32
646
+
647
+
GLEE/glee/backbone/eva_01_utils.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import math
3
+ import numpy as np
4
+ from scipy import interpolate
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ __all__ = [
10
+ "window_partition",
11
+ "window_unpartition",
12
+ "add_decomposed_rel_pos",
13
+ "get_abs_pos",
14
+ "PatchEmbed",
15
+ ]
16
+
17
+
18
+ def window_partition(x, window_size):
19
+ """
20
+ Partition into non-overlapping windows with padding if needed.
21
+ Args:
22
+ x (tensor): input tokens with [B, H, W, C].
23
+ window_size (int): window size.
24
+
25
+ Returns:
26
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
27
+ (Hp, Wp): padded height and width before partition
28
+ """
29
+ B, H, W, C = x.shape
30
+
31
+ pad_h = (window_size - H % window_size) % window_size
32
+ pad_w = (window_size - W % window_size) % window_size
33
+ if pad_h > 0 or pad_w > 0:
34
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
35
+ Hp, Wp = H + pad_h, W + pad_w
36
+
37
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
38
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
39
+ return windows, (Hp, Wp)
40
+
41
+
42
+ def window_unpartition(windows, window_size, pad_hw, hw):
43
+ """
44
+ Window unpartition into original sequences and remove padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
47
+ window_size (int): window size.
48
+ pad_hw (Tuple): padded height and width (Hp, Wp).
49
+ hw (Tuple): original height and width (H, W) before padding.
50
+
51
+ Returns:
52
+ x: unpartitioned sequences with [B, H, W, C].
53
+ """
54
+ Hp, Wp = pad_hw
55
+ H, W = hw
56
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
57
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
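A quick round-trip sketch (not part of the committed file): window_partition pads the input up to a multiple of the window size and window_unpartition crops that padding away, so the pair is lossless.

import torch

x = torch.randn(2, 50, 50, 768)                         # (B, H, W, C); 50 is not a multiple of 16
windows, pad_hw = window_partition(x, window_size=16)   # (32, 16, 16, 768), pad_hw == (64, 64)
y = window_unpartition(windows, 16, pad_hw, (50, 50))
assert torch.equal(x, y)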
63
+
64
+
65
+ def get_rel_pos(q_size, k_size, rel_pos, interp_type):
66
+ """
67
+ Get relative positional embeddings according to the relative positions of
68
+ query and key sizes.
69
+ Args:
70
+ q_size (int): size of query q.
71
+ k_size (int): size of key k.
72
+ rel_pos (Tensor): relative position embeddings (L, C).
73
+
74
+ Returns:
75
+ Extracted positional embeddings according to relative positions.
76
+ """
77
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
78
+ # Interpolate rel pos if needed.
79
+ if rel_pos.shape[0] != max_rel_dist:
80
+ if interp_type == "vitdet":
81
+ # the vitdet impl:
82
+ # https://github.com/facebookresearch/detectron2/blob/96c752ce821a3340e27edd51c28a00665dd32a30/detectron2/modeling/backbone/utils.py#L77.
83
+
84
+ rel_pos_resized = F.interpolate(
85
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
86
+ size=max_rel_dist,
87
+ mode="linear",
88
+ )
89
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
90
+ elif interp_type == "beit":
91
+ # steal from beit https://github.com/microsoft/unilm/tree/master/beit
92
+ # modified by Yuxin Fang
93
+
94
+ src_size = rel_pos.shape[0]
95
+ dst_size = max_rel_dist
96
+
97
+ q = 1.0903078
98
+ dis = []
99
+
100
+ cur = 1
101
+ for i in range(src_size // 2):
102
+ dis.append(cur)
103
+ cur += q ** (i + 1)
104
+
105
+ r_ids = [-_ for _ in reversed(dis)]
106
+ x = r_ids + [0] + dis
107
+ t = dst_size // 2.0
108
+ dx = np.arange(-t, t + 0.1, 1.0)
109
+
110
+ all_rel_pos_bias = []
111
+ for i in range(rel_pos.shape[1]):
112
+ # a hack from https://github.com/baaivision/EVA/issues/8,
113
+ # could also be used in fine-tuning but the performance haven't been tested.
114
+ z = rel_pos[:, i].view(src_size).cpu().float().detach().numpy()
115
+ f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
116
+ all_rel_pos_bias.append(
117
+ torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
118
+ rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
119
+ else:
120
+ raise NotImplementedError()
121
+ else:
122
+ rel_pos_resized = rel_pos
123
+
124
+ # Scale the coords with short length if shapes for q and k are different.
125
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
126
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
127
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
128
+
129
+ return rel_pos_resized[relative_coords.long()]
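A small shape sketch (not part of the committed file): when the query and key sizes match, no interpolation happens and the function simply gathers rows of the (2*S - 1, C) table at offsets q - k + (S - 1).

import torch

S, C = 4, 8
rel_pos = torch.randn(2 * S - 1, C)                      # already the expected length
out = get_rel_pos(S, S, rel_pos, interp_type="vitdet")   # shape (S, S, C) == (4, 4, 8)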
130
+
131
+
132
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size, interp_type):
133
+ """
134
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
135
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
136
+ Args:
137
+ attn (Tensor): attention map.
138
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
139
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
140
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
141
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
142
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
143
+
144
+ Returns:
145
+ attn (Tensor): attention map with added relative positional embeddings.
146
+ """
147
+ q_h, q_w = q_size
148
+ k_h, k_w = k_size
149
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h, interp_type)
150
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w, interp_type)
151
+
152
+ B, _, dim = q.shape
153
+ r_q = q.reshape(B, q_h, q_w, dim)
154
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
155
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
156
+
157
+ attn = (
158
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
159
+ ).view(B, q_h * q_w, k_h * k_w)
160
+
161
+ return attn
162
+
163
+
164
+ def get_abs_pos(abs_pos, has_cls_token, hw):
165
+ """
166
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
167
+ dimension for the original embeddings.
168
+ Args:
169
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
170
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
171
+ hw (Tuple): size of input image tokens.
172
+
173
+ Returns:
174
+ Absolute positional embeddings after processing with shape (1, H, W, C)
175
+ """
176
+ h, w = hw
177
+ if has_cls_token:
178
+ abs_pos = abs_pos[:, 1:]
179
+ xy_num = abs_pos.shape[1]
180
+ size = int(math.sqrt(xy_num))
181
+ assert size * size == xy_num
182
+
183
+ if size != h or size != w:
184
+ new_abs_pos = F.interpolate(
185
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
186
+ size=(h, w),
187
+ mode="bicubic",
188
+ align_corners=False,
189
+ )
190
+
191
+ return new_abs_pos.permute(0, 2, 3, 1)
192
+ else:
193
+ return abs_pos.reshape(1, h, w, -1)
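A brief shape sketch (not part of the committed file): resizing a 14x14 pretraining position table (224 / 16, plus a cls token) to a 64x64 feature map, mirroring how the backbone's forward() calls this helper.

import torch

pos_embed = torch.zeros(1, 14 * 14 + 1, 768)             # (1, num_positions, C), cls token included
pos = get_abs_pos(pos_embed, has_cls_token=True, hw=(64, 64))
print(pos.shape)                                          # torch.Size([1, 64, 64, 768])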
194
+
195
+
196
+ class PatchEmbed(nn.Module):
197
+ """
198
+ Image to Patch Embedding.
199
+ """
200
+
201
+ def __init__(
202
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
203
+ ):
204
+ """
205
+ Args:
206
+ kernel_size (Tuple): kernel size of the projection layer.
207
+ stride (Tuple): stride of the projection layer.
208
+ padding (Tuple): padding size of the projection layer.
209
+ in_chans (int): Number of input image channels.
210
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
211
+ """
212
+ super().__init__()
213
+
214
+ self.proj = nn.Conv2d(
215
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
216
+ )
217
+
218
+ def forward(self, x):
219
+ x = self.proj(x)
220
+ # B C H W -> B H W C
221
+ x = x.permute(0, 2, 3, 1)
222
+ return x
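Usage note (illustrative sketch, not part of the committed eva_01_utils.py): the helpers above are normally chained inside a ViT block — PatchEmbed turns an image into a channels-last token grid and get_abs_pos resizes a pretrained absolute position table to that grid. All tensor sizes below are made up for demonstration.

import torch

# 32x32 image, 16x16 patches -> a 2x2 token grid with embed_dim channels.
patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=8)
tokens = patch_embed(torch.randn(1, 3, 32, 32))   # (1, 2, 2, 8), channels-last
B, H, W, C = tokens.shape

# Pretrained absolute position table for a 2x2 grid plus one cls token (assumed shape).
abs_pos = torch.randn(1, 1 + 4, C)
tokens = tokens + get_abs_pos(abs_pos, has_cls_token=True, hw=(H, W))   # (1, 2, 2, 8)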
GLEE/glee/backbone/eva_02_utils.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import math
3
+ import numpy as np
4
+ from scipy import interpolate
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ __all__ = [
10
+ "window_partition",
11
+ "window_unpartition",
12
+ "add_decomposed_rel_pos",
13
+ "get_abs_pos",
14
+ "PatchEmbed",
15
+ "VisionRotaryEmbeddingFast",
16
+ ]
17
+
18
+
19
+ def window_partition(x, window_size):
20
+ """
21
+ Partition into non-overlapping windows with padding if needed.
22
+ Args:
23
+ x (tensor): input tokens with [B, H, W, C].
24
+ window_size (int): window size.
25
+
26
+ Returns:
27
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
28
+ (Hp, Wp): padded height and width before partition
29
+ """
30
+ B, H, W, C = x.shape
31
+
32
+ pad_h = (window_size - H % window_size) % window_size
33
+ pad_w = (window_size - W % window_size) % window_size
34
+ if pad_h > 0 or pad_w > 0:
35
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
36
+ Hp, Wp = H + pad_h, W + pad_w
37
+
38
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
39
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
40
+ return windows, (Hp, Wp)
41
+
42
+
43
+ def window_unpartition(windows, window_size, pad_hw, hw):
44
+ """
45
+ Window unpartition into original sequences and removing padding.
46
+ Args:
47
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
48
+ window_size (int): window size.
49
+ pad_hw (Tuple): padded height and width (Hp, Wp).
50
+ hw (Tuple): original height and width (H, W) before padding.
51
+
52
+ Returns:
53
+ x: unpartitioned sequences with [B, H, W, C].
54
+ """
55
+ Hp, Wp = pad_hw
56
+ H, W = hw
57
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
58
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
59
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
60
+
61
+ if Hp > H or Wp > W:
62
+ x = x[:, :H, :W, :].contiguous()
63
+ return x
64
+
65
+
66
+ def get_rel_pos(q_size, k_size, rel_pos):
67
+ """
68
+ Get relative positional embeddings according to the relative positions of
69
+ query and key sizes.
70
+ Args:
71
+ q_size (int): size of query q.
72
+ k_size (int): size of key k.
73
+ rel_pos (Tensor): relative position embeddings (L, C).
74
+
75
+ Returns:
76
+ Extracted positional embeddings according to relative positions.
77
+ """
78
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
79
+ use_log_interpolation = True
80
+
81
+ # Interpolate rel pos if needed.
82
+ if rel_pos.shape[0] != max_rel_dist:
83
+ if not use_log_interpolation:
84
+ # Interpolate rel pos.
85
+ rel_pos_resized = F.interpolate(
86
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
87
+ size=max_rel_dist,
88
+ mode="linear",
89
+ )
90
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
91
+ else:
92
+ src_size = rel_pos.shape[0]
93
+ dst_size = max_rel_dist
94
+
95
+ # q = 1.13492
96
+ q = 1.0903078
97
+ dis = []
98
+
99
+ cur = 1
100
+ for i in range(src_size // 2):
101
+ dis.append(cur)
102
+ cur += q ** (i + 1)
103
+
104
+ r_ids = [-_ for _ in reversed(dis)]
105
+ x = r_ids + [0] + dis
106
+ t = dst_size // 2.0
107
+ dx = np.arange(-t, t + 0.1, 1.0)
108
+ # print("x = %s" % str(x))
109
+ # print("dx = %s" % str(dx))
110
+ all_rel_pos_bias = []
111
+ for i in range(rel_pos.shape[1]):
112
+ z = rel_pos[:, i].view(src_size).cpu().float().numpy()
113
+ f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
114
+ all_rel_pos_bias.append(
115
+ torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
116
+ rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
117
+ else:
118
+ rel_pos_resized = rel_pos
119
+
120
+ # Scale the coords with short length if shapes for q and k are different.
121
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
122
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
123
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
124
+
125
+ return rel_pos_resized[relative_coords.long()]
126
+
127
+
128
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
129
+ """
130
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
131
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
132
+ Args:
133
+ attn (Tensor): attention map.
134
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
135
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
136
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
137
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
138
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
139
+
140
+ Returns:
141
+ attn (Tensor): attention map with added relative positional embeddings.
142
+ """
143
+ q_h, q_w = q_size
144
+ k_h, k_w = k_size
145
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
146
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
147
+
148
+ B, _, dim = q.shape
149
+ r_q = q.reshape(B, q_h, q_w, dim)
150
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
151
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
152
+
153
+ attn = (
154
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
155
+ ).view(B, q_h * q_w, k_h * k_w)
156
+
157
+ return attn
158
+
159
+
160
+ def get_abs_pos(abs_pos, has_cls_token, hw):
161
+ """
162
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
163
+ dimension for the original embeddings.
164
+ Args:
165
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
166
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
167
+ hw (Tuple): size of input image tokens.
168
+
169
+ Returns:
170
+ Absolute positional embeddings after processing with shape (1, H, W, C)
171
+ """
172
+ h, w = hw
173
+ if has_cls_token:
174
+ abs_pos = abs_pos[:, 1:]
175
+ xy_num = abs_pos.shape[1]
176
+ size = int(math.sqrt(xy_num))
177
+ assert size * size == xy_num
178
+
179
+ if size != h or size != w:
180
+ new_abs_pos = F.interpolate(
181
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
182
+ size=(h, w),
183
+ mode="bicubic",
184
+ align_corners=False,
185
+ )
186
+
187
+ return new_abs_pos.permute(0, 2, 3, 1)
188
+ else:
189
+ return abs_pos.reshape(1, h, w, -1)
190
+
191
+
192
+ class PatchEmbed(nn.Module):
193
+ """
194
+ Image to Patch Embedding.
195
+ """
196
+
197
+ def __init__(
198
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
199
+ ):
200
+ """
201
+ Args:
202
+ kernel_size (Tuple): kernel size of the projection layer.
203
+ stride (Tuple): stride of the projection layer.
204
+ padding (Tuple): padding size of the projection layer.
205
+ in_chans (int): Number of input image channels.
206
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
207
+ """
208
+ super().__init__()
209
+
210
+ self.proj = nn.Conv2d(
211
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
212
+ )
213
+
214
+ def forward(self, x):
215
+ x = self.proj(x)
216
+ # B C H W -> B H W C
217
+ x = x.permute(0, 2, 3, 1)
218
+ return x
219
+
220
+
221
+
222
+
223
+ from math import pi
224
+
225
+ import torch
226
+ from torch import nn
227
+
228
+ from einops import rearrange, repeat
229
+
230
+
231
+
232
+ def broadcat(tensors, dim = -1):
233
+ num_tensors = len(tensors)
234
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
235
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
236
+ shape_len = list(shape_lens)[0]
237
+ dim = (dim + shape_len) if dim < 0 else dim
238
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
239
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
240
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
241
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
242
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
243
+ expanded_dims.insert(dim, (dim, dims[dim]))
244
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
245
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
246
+ return torch.cat(tensors, dim = dim)
247
+
248
+
249
+
250
+ def rotate_half(x):
251
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
252
+ x1, x2 = x.unbind(dim = -1)
253
+ x = torch.stack((-x2, x1), dim = -1)
254
+ return rearrange(x, '... d r -> ... (d r)')
255
+
256
+
257
+
258
+ class VisionRotaryEmbedding(nn.Module):
259
+ def __init__(
260
+ self,
261
+ dim,
262
+ pt_seq_len,
263
+ ft_seq_len=None,
264
+ custom_freqs = None,
265
+ freqs_for = 'lang',
266
+ theta = 10000,
267
+ max_freq = 10,
268
+ num_freqs = 1,
269
+ ):
270
+ super().__init__()
271
+ if custom_freqs:
272
+ freqs = custom_freqs
273
+ elif freqs_for == 'lang':
274
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
275
+ elif freqs_for == 'pixel':
276
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
277
+ elif freqs_for == 'constant':
278
+ freqs = torch.ones(num_freqs).float()
279
+ else:
280
+ raise ValueError(f'unknown modality {freqs_for}')
281
+
282
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
283
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
284
+
285
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
286
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
287
+
288
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
289
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
290
+
291
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
292
+
293
+ self.register_buffer("freqs_cos", freqs.cos())
294
+ self.register_buffer("freqs_sin", freqs.sin())
295
+
296
+ print('======== shape of rope freq', self.freqs_cos.shape, '========')
297
+
298
+ def forward(self, t, start_index = 0):
299
+ rot_dim = self.freqs_cos.shape[-1]
300
+ end_index = start_index + rot_dim
301
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
302
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
303
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
304
+ return torch.cat((t_left, t, t_right), dim = -1)
305
+
306
+
307
+
308
+
309
+ class VisionRotaryEmbeddingFast(nn.Module):
310
+ def __init__(
311
+ self,
312
+ dim,
313
+ pt_seq_len=16,
314
+ ft_seq_len=None,
315
+ custom_freqs = None,
316
+ freqs_for = 'lang',
317
+ theta = 10000,
318
+ max_freq = 10,
319
+ num_freqs = 1,
320
+ ):
321
+ super().__init__()
322
+ if custom_freqs:
323
+ freqs = custom_freqs
324
+ elif freqs_for == 'lang':
325
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
326
+ elif freqs_for == 'pixel':
327
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
328
+ elif freqs_for == 'constant':
329
+ freqs = torch.ones(num_freqs).float()
330
+ else:
331
+ raise ValueError(f'unknown modality {freqs_for}')
332
+
333
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
334
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
335
+
336
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
337
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
338
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
339
+
340
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
341
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
342
+
343
+ self.register_buffer("freqs_cos", freqs_cos)
344
+ self.register_buffer("freqs_sin", freqs_sin)
345
+
346
+ print('======== shape of rope freq', self.freqs_cos.shape, '========')
347
+
348
+ # def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
349
+ def forward(self, t):
350
+ if t.shape[2] != self.freqs_cos.shape[0]:
351
+ t_len = t.shape[2]
352
+ output = t * self.freqs_cos[:t_len] + rotate_half(t) * self.freqs_sin[:t_len]
353
+ else:
354
+ output = t * self.freqs_cos + rotate_half(t) * self.freqs_sin
355
+ return output
356
+
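Usage note (illustrative sketch, not part of the committed eva_02_utils.py): VisionRotaryEmbeddingFast above precomputes cos/sin tables for a pt_seq_len x pt_seq_len token grid and rotates the channel dimension of q/k instead of adding position embeddings; `dim` is half of the per-head channel count that gets rotated. Shapes below are assumptions for demonstration.

import torch

# 4x4 token grid (16 positions); head_dim = 32, so the rotary half-dim is 16.
rope = VisionRotaryEmbeddingFast(dim=16, pt_seq_len=4)

q = torch.randn(1, 2, 16, 32)   # (batch, heads, H*W tokens, head_dim)
k = torch.randn(1, 2, 16, 32)
q, k = rope(q), rope(k)         # same shapes; positions are encoded by the rotation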
GLEE/glee/backbone/internimage.py ADDED
@@ -0,0 +1,737 @@
1
+ # --------------------------------------------------------
2
+ # InternImage
3
+ # Copyright (c) 2022 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import trunc_normal_, DropPath
12
+
13
+ from detectron2.utils.logger import setup_logger
14
+ from detectron2.modeling.backbone import Backbone
15
+
16
+
17
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
18
+ from .ops_dcnv3 import modules as opsm
19
+
20
+
21
+
22
+ class to_channels_first(nn.Module):
23
+
24
+ def __init__(self):
25
+ super().__init__()
26
+
27
+ def forward(self, x):
28
+ return x.permute(0, 3, 1, 2)
29
+
30
+
31
+ class to_channels_last(nn.Module):
32
+
33
+ def __init__(self):
34
+ super().__init__()
35
+
36
+ def forward(self, x):
37
+ return x.permute(0, 2, 3, 1)
38
+
39
+
40
+ def build_norm_layer(dim,
41
+ norm_layer,
42
+ in_format='channels_last',
43
+ out_format='channels_last',
44
+ eps=1e-6):
45
+ layers = []
46
+ if norm_layer == 'BN':
47
+ if in_format == 'channels_last':
48
+ layers.append(to_channels_first())
49
+ layers.append(nn.BatchNorm2d(dim))
50
+ if out_format == 'channels_last':
51
+ layers.append(to_channels_last())
52
+ elif norm_layer == 'LN':
53
+ if in_format == 'channels_first':
54
+ layers.append(to_channels_last())
55
+ layers.append(nn.LayerNorm(dim, eps=eps))
56
+ if out_format == 'channels_first':
57
+ layers.append(to_channels_first())
58
+ else:
59
+ raise NotImplementedError(
60
+ f'build_norm_layer does not support {norm_layer}')
61
+ return nn.Sequential(*layers)
62
+
63
+
64
+ def build_act_layer(act_layer):
65
+ if act_layer == 'ReLU':
66
+ return nn.ReLU(inplace=True)
67
+ elif act_layer == 'SiLU':
68
+ return nn.SiLU(inplace=True)
69
+ elif act_layer == 'GELU':
70
+ return nn.GELU()
71
+
72
+ raise NotImplementedError(f'build_act_layer does not support {act_layer}')
73
+
74
+
75
+ class CrossAttention(nn.Module):
76
+ r""" Cross Attention Module
77
+ Args:
78
+ dim (int): Number of input channels.
79
+ num_heads (int): Number of attention heads. Default: 8
80
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
81
+ Default: False.
82
+ qk_scale (float | None, optional): Override default qk scale of
83
+ head_dim ** -0.5 if set. Default: None.
84
+ attn_drop (float, optional): Dropout ratio of attention weight.
85
+ Default: 0.0
86
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
87
+ attn_head_dim (int, optional): Dimension of attention head.
88
+ out_dim (int, optional): Dimension of output.
89
+ """
90
+
91
+ def __init__(self,
92
+ dim,
93
+ num_heads=8,
94
+ qkv_bias=False,
95
+ qk_scale=None,
96
+ attn_drop=0.,
97
+ proj_drop=0.,
98
+ attn_head_dim=None,
99
+ out_dim=None):
100
+ super().__init__()
101
+ if out_dim is None:
102
+ out_dim = dim
103
+ self.num_heads = num_heads
104
+ head_dim = dim // num_heads
105
+ if attn_head_dim is not None:
106
+ head_dim = attn_head_dim
107
+ all_head_dim = head_dim * self.num_heads
108
+ self.scale = qk_scale or head_dim ** -0.5
109
+ assert all_head_dim == dim
110
+
111
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
112
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
113
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
114
+
115
+ if qkv_bias:
116
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
117
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
118
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
119
+ else:
120
+ self.q_bias = None
121
+ self.k_bias = None
122
+ self.v_bias = None
123
+
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(all_head_dim, out_dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ def forward(self, x, k=None, v=None):
129
+ B, N, C = x.shape
130
+ N_k = k.shape[1]
131
+ N_v = v.shape[1]
132
+
133
+ q_bias, k_bias, v_bias = None, None, None
134
+ if self.q_bias is not None:
135
+ q_bias = self.q_bias
136
+ k_bias = self.k_bias
137
+ v_bias = self.v_bias
138
+
139
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
140
+ q = q.reshape(B, N, 1, self.num_heads,
141
+ -1).permute(2, 0, 3, 1,
142
+ 4).squeeze(0) # (B, N_head, N_q, dim)
143
+
144
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
145
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
146
+ 4).squeeze(0)
147
+
148
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
149
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
150
+ 4).squeeze(0)
151
+
152
+ q = q * self.scale
153
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
154
+
155
+ attn = attn.softmax(dim=-1)
156
+ attn = self.attn_drop(attn)
157
+
158
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
159
+ x = self.proj(x)
160
+ x = self.proj_drop(x)
161
+
162
+ return x
163
+
164
+
165
+ class AttentiveBlock(nn.Module):
166
+ r"""Attentive Block
167
+ Args:
168
+ dim (int): Number of input channels.
169
+ num_heads (int): Number of attention heads. Default: 8
170
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
171
+ Default: False.
172
+ qk_scale (float | None, optional): Override default qk scale of
173
+ head_dim ** -0.5 if set. Default: None.
174
+ drop (float, optional): Dropout rate. Default: 0.0.
175
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0.
176
+ drop_path (float | tuple[float], optional): Stochastic depth rate.
177
+ Default: 0.0.
178
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
179
+ attn_head_dim (int, optional): Dimension of attention head. Default: None.
180
+ out_dim (int, optional): Dimension of output. Default: None.
181
+ """
182
+
183
+ def __init__(self,
184
+ dim,
185
+ num_heads,
186
+ qkv_bias=False,
187
+ qk_scale=None,
188
+ drop=0.,
189
+ attn_drop=0.,
190
+ drop_path=0.,
191
+ norm_layer="LN",
192
+ attn_head_dim=None,
193
+ out_dim=None):
194
+ super().__init__()
195
+
196
+ self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
197
+ self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
198
+ self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
199
+ self.cross_dcn = CrossAttention(dim,
200
+ num_heads=num_heads,
201
+ qkv_bias=qkv_bias,
202
+ qk_scale=qk_scale,
203
+ attn_drop=attn_drop,
204
+ proj_drop=drop,
205
+ attn_head_dim=attn_head_dim,
206
+ out_dim=out_dim)
207
+
208
+ self.drop_path = DropPath(
209
+ drop_path) if drop_path > 0. else nn.Identity()
210
+
211
+ def forward(self,
212
+ x_q,
213
+ x_kv,
214
+ pos_q,
215
+ pos_k,
216
+ bool_masked_pos,
217
+ rel_pos_bias=None):
218
+ x_q = self.norm1_q(x_q + pos_q)
219
+ x_k = self.norm1_k(x_kv + pos_k)
220
+ x_v = self.norm1_v(x_kv)
221
+
222
+ x = self.cross_dcn(x_q, k=x_k, v=x_v)
223
+
224
+ return x
225
+
226
+
227
+ class AttentionPoolingBlock(AttentiveBlock):
228
+
229
+ def forward(self, x):
230
+ x_q = x.mean(1, keepdim=True)
231
+ x_kv = x
232
+ pos_q, pos_k = 0, 0
233
+ x = super().forward(x_q, x_kv, pos_q, pos_k,
234
+ bool_masked_pos=None,
235
+ rel_pos_bias=None)
236
+ x = x.squeeze(1)
237
+ return x
238
+
239
+
240
+ class StemLayer(nn.Module):
241
+ r""" Stem layer of InternImage
242
+ Args:
243
+ in_chans (int): number of input channels
244
+ out_chans (int): number of output channels
245
+ act_layer (str): activation layer
246
+ norm_layer (str): normalization layer
247
+ """
248
+
249
+ def __init__(self,
250
+ in_chans=3,
251
+ out_chans=96,
252
+ act_layer='GELU',
253
+ norm_layer='BN'):
254
+ super().__init__()
255
+ self.conv1 = nn.Conv2d(in_chans,
256
+ out_chans // 2,
257
+ kernel_size=3,
258
+ stride=2,
259
+ padding=1)
260
+ self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
261
+ 'channels_first', 'channels_first')
262
+ self.act = build_act_layer(act_layer)
263
+ self.conv2 = nn.Conv2d(out_chans // 2,
264
+ out_chans,
265
+ kernel_size=3,
266
+ stride=2,
267
+ padding=1)
268
+ self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
269
+ 'channels_last')
270
+
271
+ def forward(self, x):
272
+ x = self.conv1(x)
273
+ x = self.norm1(x)
274
+ x = self.act(x)
275
+ x = self.conv2(x)
276
+ x = self.norm2(x)
277
+ return x
278
+
279
+
280
+ class DownsampleLayer(nn.Module):
281
+ r""" Downsample layer of InternImage
282
+ Args:
283
+ channels (int): number of input channels
284
+ norm_layer (str): normalization layer
285
+ """
286
+
287
+ def __init__(self, channels, norm_layer='LN'):
288
+ super().__init__()
289
+ self.conv = nn.Conv2d(channels,
290
+ 2 * channels,
291
+ kernel_size=3,
292
+ stride=2,
293
+ padding=1,
294
+ bias=False)
295
+ self.norm = build_norm_layer(2 * channels, norm_layer,
296
+ 'channels_first', 'channels_last')
297
+
298
+ def forward(self, x):
299
+ x = self.conv(x.permute(0, 3, 1, 2))
300
+ x = self.norm(x)
301
+ return x
302
+
303
+
304
+ class MLPLayer(nn.Module):
305
+ r""" MLP layer of InternImage
306
+ Args:
307
+ in_features (int): number of input features
308
+ hidden_features (int): number of hidden features
309
+ out_features (int): number of output features
310
+ act_layer (str): activation layer
311
+ drop (float): dropout rate
312
+ """
313
+
314
+ def __init__(self,
315
+ in_features,
316
+ hidden_features=None,
317
+ out_features=None,
318
+ act_layer='GELU',
319
+ drop=0.):
320
+ super().__init__()
321
+ out_features = out_features or in_features
322
+ hidden_features = hidden_features or in_features
323
+ self.fc1 = nn.Linear(in_features, hidden_features)
324
+ self.act = build_act_layer(act_layer)
325
+ self.fc2 = nn.Linear(hidden_features, out_features)
326
+ self.drop = nn.Dropout(drop)
327
+
328
+ def forward(self, x):
329
+ x = self.fc1(x)
330
+ x = self.act(x)
331
+ x = self.drop(x)
332
+ x = self.fc2(x)
333
+ x = self.drop(x)
334
+ return x
335
+
336
+
337
+ class InternImageLayer(nn.Module):
338
+ r""" Basic layer of InternImage
339
+ Args:
340
+ core_op (nn.Module): core operation of InternImage
341
+ channels (int): number of input channels
342
+ groups (list): Groups of each block.
343
+ mlp_ratio (float): ratio of mlp hidden features to input channels
344
+ drop (float): dropout rate
345
+ drop_path (float): drop path rate
346
+ act_layer (str): activation layer
347
+ norm_layer (str): normalization layer
348
+ post_norm (bool): whether to use post normalization
349
+ layer_scale (float): layer scale
350
+ offset_scale (float): offset scale
351
+ with_cp (bool): whether to use checkpoint
352
+ """
353
+
354
+ def __init__(self,
355
+ core_op,
356
+ channels,
357
+ groups,
358
+ mlp_ratio=4.,
359
+ drop=0.,
360
+ drop_path=0.,
361
+ act_layer='GELU',
362
+ norm_layer='LN',
363
+ post_norm=False,
364
+ layer_scale=None,
365
+ offset_scale=1.0,
366
+ with_cp=False,
367
+ dw_kernel_size=None, # for InternImage-H/G
368
+ res_post_norm=False, # for InternImage-H/G
369
+ center_feature_scale=False): # for InternImage-H/G
370
+ super().__init__()
371
+ self.channels = channels
372
+ self.groups = groups
373
+ self.mlp_ratio = mlp_ratio
374
+ self.with_cp = with_cp
375
+
376
+ self.norm1 = build_norm_layer(channels, 'LN')
377
+ self.post_norm = post_norm
378
+ self.dcn = core_op(
379
+ channels=channels,
380
+ kernel_size=3,
381
+ stride=1,
382
+ pad=1,
383
+ dilation=1,
384
+ group=groups,
385
+ offset_scale=offset_scale,
386
+ act_layer=act_layer,
387
+ norm_layer=norm_layer,
388
+ dw_kernel_size=dw_kernel_size, # for InternImage-H/G
389
+ center_feature_scale=center_feature_scale) # for InternImage-H/G
390
+ self.drop_path = DropPath(drop_path) if drop_path > 0. \
391
+ else nn.Identity()
392
+ self.norm2 = build_norm_layer(channels, 'LN')
393
+ self.mlp = MLPLayer(in_features=channels,
394
+ hidden_features=int(channels * mlp_ratio),
395
+ act_layer=act_layer,
396
+ drop=drop)
397
+ self.layer_scale = layer_scale is not None
398
+ if self.layer_scale:
399
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
400
+ requires_grad=True)
401
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
402
+ requires_grad=True)
403
+ self.res_post_norm = res_post_norm
404
+ if res_post_norm:
405
+ self.res_post_norm1 = build_norm_layer(channels, 'LN')
406
+ self.res_post_norm2 = build_norm_layer(channels, 'LN')
407
+
408
+ def forward(self, x):
409
+
410
+ def _inner_forward(x):
411
+ if not self.layer_scale:
412
+ if self.post_norm:
413
+ x = x + self.drop_path(self.norm1(self.dcn(x)))
414
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
415
+ elif self.res_post_norm: # for InternImage-H/G
416
+ x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
417
+ x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
418
+ else:
419
+ x = x + self.drop_path(self.dcn(self.norm1(x)))
420
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
421
+ return x
422
+ if self.post_norm:
423
+ x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
424
+ x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
425
+ else:
426
+ x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
427
+ x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
428
+ return x
429
+
430
+ if self.with_cp and x.requires_grad:
431
+ x = checkpoint.checkpoint(_inner_forward, x)
432
+ else:
433
+ x = _inner_forward(x)
434
+ return x
435
+
436
+
437
+ class InternImageBlock(nn.Module):
438
+ r""" Block of InternImage
439
+ Args:
440
+ core_op (nn.Module): core operation of InternImage
441
+ channels (int): number of input channels
442
+ depths (list): Depth of each block.
443
+ groups (list): Groups of each block.
444
+ mlp_ratio (float): ratio of mlp hidden features to input channels
445
+ drop (float): dropout rate
446
+ drop_path (float): drop path rate
447
+ act_layer (str): activation layer
448
+ norm_layer (str): normalization layer
449
+ post_norm (bool): whether to use post normalization
450
+ layer_scale (float): layer scale
451
+ offset_scale (float): offset scale
452
+ with_cp (bool): whether to use checkpoint
453
+ """
454
+
455
+ def __init__(self,
456
+ core_op,
457
+ channels,
458
+ depth,
459
+ groups,
460
+ downsample=True,
461
+ mlp_ratio=4.,
462
+ drop=0.,
463
+ drop_path=0.,
464
+ act_layer='GELU',
465
+ norm_layer='LN',
466
+ post_norm=False,
467
+ offset_scale=1.0,
468
+ layer_scale=None,
469
+ with_cp=False,
470
+ dw_kernel_size=None, # for InternImage-H/G
471
+ post_norm_block_ids=None, # for InternImage-H/G
472
+ res_post_norm=False, # for InternImage-H/G
473
+ center_feature_scale=False): # for InternImage-H/G
474
+ super().__init__()
475
+ self.channels = channels
476
+ self.depth = depth
477
+ self.post_norm = post_norm
478
+ self.center_feature_scale = center_feature_scale
479
+
480
+ self.blocks = nn.ModuleList([
481
+ InternImageLayer(
482
+ core_op=core_op,
483
+ channels=channels,
484
+ groups=groups,
485
+ mlp_ratio=mlp_ratio,
486
+ drop=drop,
487
+ drop_path=drop_path[i] if isinstance(
488
+ drop_path, list) else drop_path,
489
+ act_layer=act_layer,
490
+ norm_layer=norm_layer,
491
+ post_norm=post_norm,
492
+ layer_scale=layer_scale,
493
+ offset_scale=offset_scale,
494
+ with_cp=with_cp,
495
+ dw_kernel_size=dw_kernel_size, # for InternImage-H/G
496
+ res_post_norm=res_post_norm, # for InternImage-H/G
497
+ center_feature_scale=center_feature_scale # for InternImage-H/G
498
+ ) for i in range(depth)
499
+ ])
500
+ if not self.post_norm or center_feature_scale:
501
+ self.norm = build_norm_layer(channels, 'LN')
502
+ self.post_norm_block_ids = post_norm_block_ids
503
+ if post_norm_block_ids is not None: # for InternImage-H/G
504
+ self.post_norms = nn.ModuleList(
505
+ [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
506
+ )
507
+ self.downsample = DownsampleLayer(
508
+ channels=channels, norm_layer=norm_layer) if downsample else None
509
+
510
+ def forward(self, x, return_wo_downsample=False):
511
+ for i, blk in enumerate(self.blocks):
512
+ x = blk(x)
513
+ if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
514
+ index = self.post_norm_block_ids.index(i)
515
+ x = self.post_norms[index](x) # for InternImage-H/G
516
+ if not self.post_norm or self.center_feature_scale:
517
+ x = self.norm(x)
518
+ if return_wo_downsample:
519
+ x_ = x
520
+ if self.downsample is not None:
521
+ x = self.downsample(x)
522
+
523
+ if return_wo_downsample:
524
+ return x, x_
525
+ return x
526
+
527
+ class InternImage(Backbone):
528
+ r""" InternImage
529
+ A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
530
+ https://arxiv.org/abs/2211.05778
531
+ Args:
532
+ core_op (str): Core operator. Default: 'DCNv3'
533
+ channels (int): Number of channels in the first stage. Default: 64
534
+ depths (list): Depth of each block. Default: [3, 4, 18, 5]
535
+ groups (list): Groups of each block. Default: [3, 6, 12, 24]
536
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
537
+ drop_rate (float): Probability of an element to be zeroed. Default: 0.
538
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
539
+ act_layer (str): Activation layer. Default: 'GELU'
540
+ norm_layer (str): Normalization layer. Default: 'LN'
541
+ layer_scale (bool): Whether to use layer scale. Default: False
542
+ cls_scale (bool): Whether to use class scale. Default: False
543
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False
544
+ dw_kernel_size (int): Size of the dwconv. Default: None
545
+ level2_post_norm (bool): Whether to use level2 post norm. Default: False
546
+ level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
547
+ res_post_norm (bool): Whether to use res post norm. Default: False
548
+ center_feature_scale (bool): Whether to use center feature scale. Default: False
549
+ """
550
+
551
+ def __init__(self,
552
+ core_op='DCNv3',
553
+ channels=64,
554
+ depths=[3, 4, 18, 5],
555
+ groups=[3, 6, 12, 24],
556
+ mlp_ratio=4.,
557
+ drop_rate=0.,
558
+ drop_path_rate=0.2,
559
+ drop_path_type='linear',
560
+ act_layer='GELU',
561
+ norm_layer='LN',
562
+ layer_scale=None,
563
+ offset_scale=1.0,
564
+ post_norm=False,
565
+ with_cp=False,
566
+ dw_kernel_size=None, # for InternImage-H/G
567
+ level2_post_norm=False, # for InternImage-H/G
568
+ level2_post_norm_block_ids=None, # for InternImage-H/G
569
+ res_post_norm=False, # for InternImage-H/G
570
+ center_feature_scale=False, # for InternImage-H/G
571
+ out_indices=(0, 1, 2, 3),
572
+ init_cfg=None,
573
+ **kwargs):
574
+ super().__init__()
575
+ self.core_op = core_op
576
+ self.num_levels = len(depths)
577
+ self.depths = depths
578
+ self.channels = channels
579
+ self.num_features = int(channels * 2**(self.num_levels - 1))
580
+ self.post_norm = post_norm
581
+ self.mlp_ratio = mlp_ratio
582
+ self.init_cfg = init_cfg
583
+ self.out_indices = out_indices
584
+ self.level2_post_norm_block_ids = level2_post_norm_block_ids
585
+ logger = setup_logger(name="InternImage")
586
+ logger.info(f'using core type: {core_op}')
587
+ logger.info(f'using activation layer: {act_layer}')
588
+ logger.info(f'using main norm layer: {norm_layer}')
589
+ logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
590
+ logger.info(f"level2_post_norm: {level2_post_norm}")
591
+ logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
592
+ logger.info(f"res_post_norm: {res_post_norm}")
593
+
594
+ in_chans = 3
595
+ self.patch_embed = StemLayer(in_chans=in_chans,
596
+ out_chans=channels,
597
+ act_layer=act_layer,
598
+ norm_layer=norm_layer)
599
+ self.pos_drop = nn.Dropout(p=drop_rate)
600
+
601
+ dpr = [
602
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
603
+ ]
604
+ if drop_path_type == 'uniform':
605
+ for i in range(len(dpr)):
606
+ dpr[i] = drop_path_rate
607
+
608
+ self.levels = nn.ModuleList()
609
+ for i in range(self.num_levels):
610
+ post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
611
+ i == 2) else None # for InternImage-H/G
612
+ level = InternImageBlock(
613
+ core_op=getattr(opsm, core_op),
614
+ channels=int(channels * 2**i),
615
+ depth=depths[i],
616
+ groups=groups[i],
617
+ mlp_ratio=self.mlp_ratio,
618
+ drop=drop_rate,
619
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
620
+ act_layer=act_layer,
621
+ norm_layer=norm_layer,
622
+ post_norm=post_norm,
623
+ downsample=(i < self.num_levels - 1),
624
+ layer_scale=layer_scale,
625
+ offset_scale=offset_scale,
626
+ with_cp=with_cp,
627
+ dw_kernel_size=dw_kernel_size, # for InternImage-H/G
628
+ post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
629
+ res_post_norm=res_post_norm, # for InternImage-H/G
630
+ center_feature_scale=center_feature_scale # for InternImage-H/G
631
+ )
632
+ self.levels.append(level)
633
+
634
+ self.num_layers = len(depths)
635
+ self.apply(self._init_weights)
636
+ self.apply(self._init_deform_weights)
637
+
638
+ # add basic info for d2 backbone
639
+ self._out_features = ["res{}".format(i+2) for i in self.out_indices]
640
+ self._out_feature_channels = {
641
+ "res{}".format(i+2): self.channels * 2**i for i in self.out_indices
642
+ }
643
+ self._out_feature_strides = {"res{}".format(i+2): 2 ** (i + 2) for i in self.out_indices}
644
+ self._size_divisibility = 32
645
+
646
+
647
+ def _init_weights(self, m):
648
+ if isinstance(m, nn.Linear):
649
+ trunc_normal_(m.weight, std=.02)
650
+ if isinstance(m, nn.Linear) and m.bias is not None:
651
+ nn.init.constant_(m.bias, 0)
652
+ elif isinstance(m, nn.LayerNorm):
653
+ nn.init.constant_(m.bias, 0)
654
+ nn.init.constant_(m.weight, 1.0)
655
+
656
+ def _init_deform_weights(self, m):
657
+ if isinstance(m, getattr(opsm, self.core_op)):
658
+ m._reset_parameters()
659
+
660
+ def forward(self, x):
661
+ x = self.patch_embed(x)
662
+ x = self.pos_drop(x)
663
+
664
+ # d2 need dict output
665
+ # seq_out = []
666
+ seq_out = {}
667
+ for level_idx, level in enumerate(self.levels):
668
+ x, x_ = level(x, return_wo_downsample=True)
669
+ if level_idx in self.out_indices:
670
+ # seq_out.append(x_.permute(0, 3, 1, 2).contiguous())
671
+ seq_out["res{}".format(level_idx+2)] = x_.permute(0, 3, 1, 2).contiguous()
672
+ return seq_out
673
+
674
+ @BACKBONE_REGISTRY.register()
675
+ class D2InternImage(InternImage):
676
+ def __init__(self, cfg, input_shape):
677
+
678
+ super().__init__(
679
+ core_op= cfg.MODEL.INTERNIMAGE.CORE_OP ,
680
+ channels=cfg.MODEL.INTERNIMAGE.CHANNELS,
681
+ depths=cfg.MODEL.INTERNIMAGE.DEPTHS,
682
+ groups=cfg.MODEL.INTERNIMAGE.GROUPS,
683
+ mlp_ratio= cfg.MODEL.INTERNIMAGE.MLP_RATIO ,
684
+ drop_path_rate=cfg.MODEL.INTERNIMAGE.DROP_PATH_RATE,
685
+ norm_layer=cfg.MODEL.INTERNIMAGE.NORM_LAYER,
686
+ layer_scale=cfg.MODEL.INTERNIMAGE.LAYER_SCALE ,
687
+ offset_scale=cfg.MODEL.INTERNIMAGE.OFFSET_SCALE,
688
+ post_norm=cfg.MODEL.INTERNIMAGE.POST_NORM,
689
+ with_cp=cfg.MODEL.INTERNIMAGE.WITH_CP ,
690
+ out_indices=cfg.MODEL.INTERNIMAGE.OUT_IINDICES,
691
+ dw_kernel_size= cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE, # for InternImage-H/G
692
+ res_post_norm= cfg.MODEL.INTERNIMAGE.RES_POST_NORM, # for InternImage-H/G
693
+ level2_post_norm= cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM, # for InternImage-H/G
694
+ level2_post_norm_block_ids= cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G
695
+ center_feature_scale= cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
696
+
697
+
698
+ )
699
+
700
+
701
+ pretrained_weight = cfg.MODEL.INTERNIMAGE.PRETRAINED_WEIGHT
702
+ if pretrained_weight:
703
+ checkpoint = torch.load(pretrained_weight, map_location='cpu')
704
+ print(f'\nload pretrain weight from {pretrained_weight} \n')
705
+ self.load_state_dict(checkpoint['model'], strict=False)
706
+
707
+
708
+ def forward(self, x):
709
+ """
710
+ Args:
711
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
712
+ Returns:
713
+ dict[str->Tensor]: names and the corresponding features
714
+ """
715
+ assert (
716
+ x.dim() == 4
717
+ ), f"InternImage takes an input of shape (N, C, H, W). Got {x.shape} instead!"
718
+ outputs = {}
719
+ y = super().forward(x)
720
+ for k in y.keys():
721
+ if k in self._out_features:
722
+ outputs[k] = y[k]
723
+ return outputs
724
+
725
+ def output_shape(self):
726
+ return {
727
+ name: ShapeSpec(
728
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
729
+ )
730
+ for name in self._out_features
731
+ }
732
+
733
+ @property
734
+ def size_divisibility(self):
735
+ return 32
736
+
737
+
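Usage note (illustrative sketch, not part of the committed internimage.py): the DCNv3 core op above needs the compiled ops_dcnv3 extension, but the surrounding layers are plain PyTorch. A small sketch of the stride-4 stem and of the channels-first/channels-last handling done by build_norm_layer; shapes are illustrative.

import torch

stem = StemLayer(in_chans=3, out_chans=96, act_layer='GELU', norm_layer='BN')
x = torch.randn(2, 3, 64, 64)        # NCHW input
y = stem(x)                          # (2, 16, 16, 96): downsampled by 4, channels-last

# The rest of the network stays channels-last; build_norm_layer inserts the permutes
# needed when mixing BatchNorm (channels-first) with LayerNorm (channels-last).
ln = build_norm_layer(96, 'LN')      # LayerNorm over the last dimension
y = ln(y)                            # still (2, 16, 16, 96)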
GLEE/glee/backbone/registry.py ADDED
@@ -0,0 +1,14 @@
1
+ _model_entrypoints = {}
2
+
3
+
4
+ def register_backbone(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+ _model_entrypoints[model_name] = fn
8
+ return fn
9
+
10
+ def model_entrypoints(model_name):
11
+ return _model_entrypoints[model_name]
12
+
13
+ def is_model(model_name):
14
+ return model_name in _model_entrypoints
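Usage note (hypothetical sketch, not part of the committed registry.py): register_backbone keys the entry by the defining module's name (fn.__module__.split('.')[-1]), not by the function name — a builder decorated inside resnet.py, for example, would presumably be looked up under 'resnet'. The module name, builder name, and lookup site below are assumptions made only to illustrate the pattern.

# hypothetical file: glee/backbone/toynet.py
from .registry import register_backbone, is_model, model_entrypoints

@register_backbone                   # stored under the key 'toynet' (the module name)
def get_toynet_backbone(cfg):
    raise NotImplementedError        # placeholder builder body

# hypothetical lookup, e.g. inside a build_backbone() helper:
if is_model('toynet'):
    backbone_builder = model_entrypoints('toynet')   # returns get_toynet_backbone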
GLEE/glee/backbone/resnet.py ADDED
@@ -0,0 +1,731 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import pickle
3
+ import numpy as np
4
+ from typing import Any, Dict
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
+ from .backbone import Backbone
12
+ from .registry import register_backbone
13
+
14
+ from detectron2.layers import (
15
+ CNNBlockBase,
16
+ Conv2d,
17
+ DeformConv,
18
+ ModulatedDeformConv,
19
+ ShapeSpec,
20
+ get_norm,
21
+ )
22
+ from detectron2.utils.file_io import PathManager
23
+
24
+ __all__ = [
25
+ "ResNetBlockBase",
26
+ "BasicBlock",
27
+ "BottleneckBlock",
28
+ "DeformBottleneckBlock",
29
+ "BasicStem",
30
+ "ResNet",
31
+ "make_stage",
32
+ "get_resnet_backbone",
33
+ ]
34
+
35
+
36
+ class BasicBlock(CNNBlockBase):
37
+ """
38
+ The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
39
+ with two 3x3 conv layers and a projection shortcut if needed.
40
+ """
41
+
42
+ def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
43
+ """
44
+ Args:
45
+ in_channels (int): Number of input channels.
46
+ out_channels (int): Number of output channels.
47
+ stride (int): Stride for the first conv.
48
+ norm (str or callable): normalization for all conv layers.
49
+ See :func:`layers.get_norm` for supported format.
50
+ """
51
+ super().__init__(in_channels, out_channels, stride)
52
+
53
+ if in_channels != out_channels:
54
+ self.shortcut = Conv2d(
55
+ in_channels,
56
+ out_channels,
57
+ kernel_size=1,
58
+ stride=stride,
59
+ bias=False,
60
+ norm=get_norm(norm, out_channels),
61
+ )
62
+ else:
63
+ self.shortcut = None
64
+
65
+ self.conv1 = Conv2d(
66
+ in_channels,
67
+ out_channels,
68
+ kernel_size=3,
69
+ stride=stride,
70
+ padding=1,
71
+ bias=False,
72
+ norm=get_norm(norm, out_channels),
73
+ )
74
+
75
+ self.conv2 = Conv2d(
76
+ out_channels,
77
+ out_channels,
78
+ kernel_size=3,
79
+ stride=1,
80
+ padding=1,
81
+ bias=False,
82
+ norm=get_norm(norm, out_channels),
83
+ )
84
+
85
+ for layer in [self.conv1, self.conv2, self.shortcut]:
86
+ if layer is not None: # shortcut can be None
87
+ weight_init.c2_msra_fill(layer)
88
+
89
+ def forward(self, x):
90
+ out = self.conv1(x)
91
+ out = F.relu_(out)
92
+ out = self.conv2(out)
93
+
94
+ if self.shortcut is not None:
95
+ shortcut = self.shortcut(x)
96
+ else:
97
+ shortcut = x
98
+
99
+ out += shortcut
100
+ out = F.relu_(out)
101
+ return out
102
+
103
+
104
+ class BottleneckBlock(CNNBlockBase):
105
+ """
106
+ The standard bottleneck residual block used by ResNet-50, 101 and 152
107
+ defined in :paper:`ResNet`. It contains 3 conv layers with kernels
108
+ 1x1, 3x3, 1x1, and a projection shortcut if needed.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ in_channels,
114
+ out_channels,
115
+ *,
116
+ bottleneck_channels,
117
+ stride=1,
118
+ num_groups=1,
119
+ norm="BN",
120
+ stride_in_1x1=False,
121
+ dilation=1,
122
+ ):
123
+ """
124
+ Args:
125
+ bottleneck_channels (int): number of output channels for the 3x3
126
+ "bottleneck" conv layers.
127
+ num_groups (int): number of groups for the 3x3 conv layer.
128
+ norm (str or callable): normalization for all conv layers.
129
+ See :func:`layers.get_norm` for supported format.
130
+ stride_in_1x1 (bool): when stride>1, whether to put stride in the
131
+ first 1x1 convolution or the bottleneck 3x3 convolution.
132
+ dilation (int): the dilation rate of the 3x3 conv layer.
133
+ """
134
+ super().__init__(in_channels, out_channels, stride)
135
+
136
+ if in_channels != out_channels:
137
+ self.shortcut = Conv2d(
138
+ in_channels,
139
+ out_channels,
140
+ kernel_size=1,
141
+ stride=stride,
142
+ bias=False,
143
+ norm=get_norm(norm, out_channels),
144
+ )
145
+ else:
146
+ self.shortcut = None
147
+
148
+ # The original MSRA ResNet models have stride in the first 1x1 conv
149
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
150
+ # stride in the 3x3 conv
151
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
152
+
153
+ self.conv1 = Conv2d(
154
+ in_channels,
155
+ bottleneck_channels,
156
+ kernel_size=1,
157
+ stride=stride_1x1,
158
+ bias=False,
159
+ norm=get_norm(norm, bottleneck_channels),
160
+ )
161
+
162
+ self.conv2 = Conv2d(
163
+ bottleneck_channels,
164
+ bottleneck_channels,
165
+ kernel_size=3,
166
+ stride=stride_3x3,
167
+ padding=1 * dilation,
168
+ bias=False,
169
+ groups=num_groups,
170
+ dilation=dilation,
171
+ norm=get_norm(norm, bottleneck_channels),
172
+ )
173
+
174
+ self.conv3 = Conv2d(
175
+ bottleneck_channels,
176
+ out_channels,
177
+ kernel_size=1,
178
+ bias=False,
179
+ norm=get_norm(norm, out_channels),
180
+ )
181
+
182
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
183
+ if layer is not None: # shortcut can be None
184
+ weight_init.c2_msra_fill(layer)
185
+
186
+ # Zero-initialize the last normalization in each residual branch,
187
+ # so that at the beginning, the residual branch starts with zeros,
188
+ # and each residual block behaves like an identity.
189
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
190
+ # "For BN layers, the learnable scaling coefficient γ is initialized
191
+ # to be 1, except for each residual block's last BN
192
+ # where γ is initialized to be 0."
193
+
194
+ # nn.init.constant_(self.conv3.norm.weight, 0)
195
+ # TODO this somehow hurts performance when training GN models from scratch.
196
+ # Add it as an option when we need to use this code to train a backbone.
197
+
198
+ def forward(self, x):
199
+ out = self.conv1(x)
200
+ out = F.relu_(out)
201
+
202
+ out = self.conv2(out)
203
+ out = F.relu_(out)
204
+
205
+ out = self.conv3(out)
206
+
207
+ if self.shortcut is not None:
208
+ shortcut = self.shortcut(x)
209
+ else:
210
+ shortcut = x
211
+
212
+ out += shortcut
213
+ out = F.relu_(out)
214
+ return out
215
+
216
+
217
+ class DeformBottleneckBlock(CNNBlockBase):
218
+ """
219
+ Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
220
+ in the 3x3 convolution.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ in_channels,
226
+ out_channels,
227
+ *,
228
+ bottleneck_channels,
229
+ stride=1,
230
+ num_groups=1,
231
+ norm="BN",
232
+ stride_in_1x1=False,
233
+ dilation=1,
234
+ deform_modulated=False,
235
+ deform_num_groups=1,
236
+ ):
237
+ super().__init__(in_channels, out_channels, stride)
238
+ self.deform_modulated = deform_modulated
239
+
240
+ if in_channels != out_channels:
241
+ self.shortcut = Conv2d(
242
+ in_channels,
243
+ out_channels,
244
+ kernel_size=1,
245
+ stride=stride,
246
+ bias=False,
247
+ norm=get_norm(norm, out_channels),
248
+ )
249
+ else:
250
+ self.shortcut = None
251
+
252
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
253
+
254
+ self.conv1 = Conv2d(
255
+ in_channels,
256
+ bottleneck_channels,
257
+ kernel_size=1,
258
+ stride=stride_1x1,
259
+ bias=False,
260
+ norm=get_norm(norm, bottleneck_channels),
261
+ )
262
+
263
+ if deform_modulated:
264
+ deform_conv_op = ModulatedDeformConv
265
+ # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
266
+ offset_channels = 27
267
+ else:
268
+ deform_conv_op = DeformConv
269
+ offset_channels = 18
270
+
271
+ self.conv2_offset = Conv2d(
272
+ bottleneck_channels,
273
+ offset_channels * deform_num_groups,
274
+ kernel_size=3,
275
+ stride=stride_3x3,
276
+ padding=1 * dilation,
277
+ dilation=dilation,
278
+ )
279
+ self.conv2 = deform_conv_op(
280
+ bottleneck_channels,
281
+ bottleneck_channels,
282
+ kernel_size=3,
283
+ stride=stride_3x3,
284
+ padding=1 * dilation,
285
+ bias=False,
286
+ groups=num_groups,
287
+ dilation=dilation,
288
+ deformable_groups=deform_num_groups,
289
+ norm=get_norm(norm, bottleneck_channels),
290
+ )
291
+
292
+ self.conv3 = Conv2d(
293
+ bottleneck_channels,
294
+ out_channels,
295
+ kernel_size=1,
296
+ bias=False,
297
+ norm=get_norm(norm, out_channels),
298
+ )
299
+
300
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
301
+ if layer is not None: # shortcut can be None
302
+ weight_init.c2_msra_fill(layer)
303
+
304
+ nn.init.constant_(self.conv2_offset.weight, 0)
305
+ nn.init.constant_(self.conv2_offset.bias, 0)
306
+
307
+ def forward(self, x):
308
+ out = self.conv1(x)
309
+ out = F.relu_(out)
310
+
311
+ if self.deform_modulated:
312
+ offset_mask = self.conv2_offset(out)
313
+ offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
314
+ offset = torch.cat((offset_x, offset_y), dim=1)
315
+ mask = mask.sigmoid()
316
+ out = self.conv2(out, offset, mask)
317
+ else:
318
+ offset = self.conv2_offset(out)
319
+ out = self.conv2(out, offset)
320
+ out = F.relu_(out)
321
+
322
+ out = self.conv3(out)
323
+
324
+ if self.shortcut is not None:
325
+ shortcut = self.shortcut(x)
326
+ else:
327
+ shortcut = x
328
+
329
+ out += shortcut
330
+ out = F.relu_(out)
331
+ return out
332
+
333
+
334
+ class BasicStem(CNNBlockBase):
335
+ """
336
+ The standard ResNet stem (layers before the first residual block),
337
+ with a conv, relu and max_pool.
338
+ """
339
+
340
+ def __init__(self, in_channels=3, out_channels=64, norm="BN"):
341
+ """
342
+ Args:
343
+ norm (str or callable): norm after the first conv layer.
344
+ See :func:`layers.get_norm` for supported format.
345
+ """
346
+ super().__init__(in_channels, out_channels, 4)
347
+ self.in_channels = in_channels
348
+ self.conv1 = Conv2d(
349
+ in_channels,
350
+ out_channels,
351
+ kernel_size=7,
352
+ stride=2,
353
+ padding=3,
354
+ bias=False,
355
+ norm=get_norm(norm, out_channels),
356
+ )
357
+ weight_init.c2_msra_fill(self.conv1)
358
+
359
+ def forward(self, x):
360
+ x = self.conv1(x)
361
+ x = F.relu_(x)
362
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
363
+ return x
364
+
365
+
366
+ class ResNet(Backbone):
367
+ """
368
+ Implement :paper:`ResNet`.
369
+ """
370
+
371
+ def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
372
+ """
373
+ Args:
374
+ stem (nn.Module): a stem module
375
+ stages (list[list[CNNBlockBase]]): several (typically 4) stages,
376
+ each contains multiple :class:`CNNBlockBase`.
377
+ num_classes (None or int): if None, will not perform classification.
378
+ Otherwise, will create a linear layer.
379
+ out_features (list[str]): name of the layers whose outputs should
380
+ be returned in forward. Can be anything in "stem", "linear", or "res2" ...
381
+ If None, will return the output of the last layer.
382
+ freeze_at (int): The number of stages at the beginning to freeze.
383
+ see :meth:`freeze` for detailed explanation.
384
+ """
385
+ super().__init__()
386
+ self.stem = stem
387
+ self.num_classes = num_classes
388
+
389
+ current_stride = self.stem.stride
390
+ self._out_feature_strides = {"stem": current_stride}
391
+ self._out_feature_channels = {"stem": self.stem.out_channels}
392
+
393
+ self.stage_names, self.stages = [], []
394
+
395
+ if out_features is not None:
396
+ # Avoid keeping unused layers in this module. They consume extra memory
397
+ # and may cause allreduce to fail
398
+ num_stages = max(
399
+ [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
400
+ )
401
+ stages = stages[:num_stages]
402
+ for i, blocks in enumerate(stages):
403
+ assert len(blocks) > 0, len(blocks)
404
+ for block in blocks:
405
+ assert isinstance(block, CNNBlockBase), block
406
+
407
+ name = "res" + str(i + 2)
408
+ stage = nn.Sequential(*blocks)
409
+
410
+ self.add_module(name, stage)
411
+ self.stage_names.append(name)
412
+ self.stages.append(stage)
413
+
414
+ self._out_feature_strides[name] = current_stride = int(
415
+ current_stride * np.prod([k.stride for k in blocks])
416
+ )
417
+ self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
418
+ self.stage_names = tuple(self.stage_names) # Make it static for scripting
419
+
420
+ if num_classes is not None:
421
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
422
+ self.linear = nn.Linear(curr_channels, num_classes)
423
+
424
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
425
+ # "The 1000-way fully-connected layer is initialized by
426
+ # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
427
+ nn.init.normal_(self.linear.weight, std=0.01)
428
+ name = "linear"
429
+
430
+ if out_features is None:
431
+ out_features = [name]
432
+ self._out_features = out_features
433
+ assert len(self._out_features)
434
+ children = [x[0] for x in self.named_children()]
435
+ for out_feature in self._out_features:
436
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
437
+ self.freeze(freeze_at)
438
+
439
+ def forward(self, x):
440
+ """
441
+ Args:
442
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
443
+
444
+ Returns:
445
+ dict[str->Tensor]: names and the corresponding features
446
+ """
447
+ assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
448
+ outputs = {}
449
+ x = self.stem(x)
450
+ if "stem" in self._out_features:
451
+ outputs["stem"] = x
452
+ for name, stage in zip(self.stage_names, self.stages):
453
+ x = stage(x)
454
+ if name in self._out_features:
455
+ outputs[name] = x
456
+ if self.num_classes is not None:
457
+ x = self.avgpool(x)
458
+ x = torch.flatten(x, 1)
459
+ x = self.linear(x)
460
+ if "linear" in self._out_features:
461
+ outputs["linear"] = x
462
+ return outputs
463
+
464
+ def output_shape(self):
465
+ return {
466
+ name: ShapeSpec(
467
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
468
+ )
469
+ for name in self._out_features
470
+ }
471
+
472
+ def freeze(self, freeze_at=0):
473
+ """
474
+ Freeze the first several stages of the ResNet. Commonly used in
475
+ fine-tuning.
476
+
477
+ Layers that produce the same feature map spatial size are defined as one
478
+ "stage" by :paper:`FPN`.
479
+
480
+ Args:
481
+ freeze_at (int): number of stages to freeze.
482
+ `1` means freezing the stem. `2` means freezing the stem and
483
+ one residual stage, etc.
484
+
485
+ Returns:
486
+ nn.Module: this ResNet itself
487
+ """
488
+ if freeze_at >= 1:
489
+ self.stem.freeze()
490
+ for idx, stage in enumerate(self.stages, start=2):
491
+ if freeze_at >= idx:
492
+ for block in stage.children():
493
+ block.freeze()
494
+ return self
495
+
496
+ @staticmethod
497
+ def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
498
+ """
499
+ Create a list of blocks of the same type that forms one ResNet stage.
500
+
501
+ Args:
502
+ block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
503
+ stage. A module of this type must not change spatial resolution of inputs unless its
504
+ stride != 1.
505
+ num_blocks (int): number of blocks in this stage
506
+ in_channels (int): input channels of the entire stage.
507
+ out_channels (int): output channels of **every block** in the stage.
508
+ kwargs: other arguments passed to the constructor of
509
+ `block_class`. If the argument name is "xx_per_block", the
510
+ argument is a list of values to be passed to each block in the
511
+ stage. Otherwise, the same argument is passed to every block
512
+ in the stage.
513
+
514
+ Returns:
515
+ list[CNNBlockBase]: a list of block module.
516
+
517
+ Examples:
518
+ ::
519
+ stage = ResNet.make_stage(
520
+ BottleneckBlock, 3, in_channels=16, out_channels=64,
521
+ bottleneck_channels=16, num_groups=1,
522
+ stride_per_block=[2, 1, 1],
523
+ dilations_per_block=[1, 1, 2]
524
+ )
525
+
526
+ Usually, layers that produce the same feature map spatial size are defined as one
527
+ "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
528
+ all be 1.
529
+ """
530
+ blocks = []
531
+ for i in range(num_blocks):
532
+ curr_kwargs = {}
533
+ for k, v in kwargs.items():
534
+ if k.endswith("_per_block"):
535
+ assert len(v) == num_blocks, (
536
+ f"Argument '{k}' of make_stage should have the "
537
+ f"same length as num_blocks={num_blocks}."
538
+ )
539
+ newk = k[: -len("_per_block")]
540
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
541
+ curr_kwargs[newk] = v[i]
542
+ else:
543
+ curr_kwargs[k] = v
544
+
545
+ blocks.append(
546
+ block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
547
+ )
548
+ in_channels = out_channels
549
+ return blocks
550
+
551
+ @staticmethod
552
+ def make_default_stages(depth, block_class=None, **kwargs):
553
+ """
554
+ Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152).
555
+ If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
556
+ instead for fine-grained customization.
557
+
558
+ Args:
559
+ depth (int): depth of ResNet
560
+ block_class (type): the CNN block class. Has to accept
561
+ `bottleneck_channels` argument for depth >= 50.
562
+ By default it is BasicBlock or BottleneckBlock, based on the
563
+ depth.
564
+ kwargs:
565
+ other arguments to pass to `make_stage`. Should not contain
566
+ stride and channels, as they are predefined for each depth.
567
+
568
+ Returns:
569
+ list[list[CNNBlockBase]]: modules in all stages; see arguments of
570
+ :class:`ResNet.__init__`.
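+
+ Examples:
+ ::
+ stages = ResNet.make_default_stages(50, norm="FrozenBN")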
571
+ """
572
+ num_blocks_per_stage = {
573
+ 18: [2, 2, 2, 2],
574
+ 34: [3, 4, 6, 3],
575
+ 50: [3, 4, 6, 3],
576
+ 101: [3, 4, 23, 3],
577
+ 152: [3, 8, 36, 3],
578
+ }[depth]
579
+ if block_class is None:
580
+ block_class = BasicBlock if depth < 50 else BottleneckBlock
581
+ if depth < 50:
582
+ in_channels = [64, 64, 128, 256]
583
+ out_channels = [64, 128, 256, 512]
584
+ else:
585
+ in_channels = [64, 256, 512, 1024]
586
+ out_channels = [256, 512, 1024, 2048]
587
+ ret = []
588
+ for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
589
+ if depth >= 50:
590
+ kwargs["bottleneck_channels"] = o // 4
591
+ ret.append(
592
+ ResNet.make_stage(
593
+ block_class=block_class,
594
+ num_blocks=n,
595
+ stride_per_block=[s] + [1] * (n - 1),
596
+ in_channels=i,
597
+ out_channels=o,
598
+ **kwargs,
599
+ )
600
+ )
601
+ return ret
602
+
603
+
604
+ ResNetBlockBase = CNNBlockBase
605
+ """
606
+ Alias for backward compatibility.
607
+ """
608
+
609
+
610
+ def make_stage(*args, **kwargs):
611
+ """
612
+ Deprecated alias for backward compatibility.
613
+ """
614
+ return ResNet.make_stage(*args, **kwargs)
615
+
616
+
617
+ def _convert_ndarray_to_tensor(state_dict: Dict[str, Any]) -> None:
618
+ """
619
+ In-place convert all numpy arrays in the state_dict to torch tensor.
620
+ Args:
621
+ state_dict (dict): a state-dict to be loaded to the model.
622
+ Will be modified.
623
+ """
624
+ # model could be an OrderedDict with _metadata attribute
625
+ # (as returned by Pytorch's state_dict()). We should preserve these
626
+ # properties.
627
+ for k in list(state_dict.keys()):
628
+ v = state_dict[k]
629
+ if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
630
+ raise ValueError(
631
+ "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
632
+ )
633
+ if not isinstance(v, torch.Tensor):
634
+ state_dict[k] = torch.from_numpy(v)
635
+
636
+
637
+ @register_backbone
638
+ def get_resnet_backbone(cfg):
639
+ """
640
+ Create a ResNet instance from config.
641
+
642
+ Returns:
643
+ ResNet: a :class:`ResNet` instance.
644
+ """
645
+ res_cfg = cfg['MODEL']['BACKBONE']['RESNETS']
646
+
647
+ # need registration of new blocks/stems?
648
+ norm = res_cfg['NORM']
649
+ stem = BasicStem(
650
+ in_channels=res_cfg['STEM_IN_CHANNELS'],
651
+ out_channels=res_cfg['STEM_OUT_CHANNELS'],
652
+ norm=norm,
653
+ )
654
+
655
+ # fmt: off
656
+ freeze_at = res_cfg['FREEZE_AT']
657
+ out_features = res_cfg['OUT_FEATURES']
658
+ depth = res_cfg['DEPTH']
659
+ num_groups = res_cfg['NUM_GROUPS']
660
+ width_per_group = res_cfg['WIDTH_PER_GROUP']
661
+ bottleneck_channels = num_groups * width_per_group
662
+ in_channels = res_cfg['STEM_OUT_CHANNELS']
663
+ out_channels = res_cfg['RES2_OUT_CHANNELS']
664
+ stride_in_1x1 = res_cfg['STRIDE_IN_1X1']
665
+ res5_dilation = res_cfg['RES5_DILATION']
666
+ deform_on_per_stage = res_cfg['DEFORM_ON_PER_STAGE']
667
+ deform_modulated = res_cfg['DEFORM_MODULATED']
668
+ deform_num_groups = res_cfg['DEFORM_NUM_GROUPS']
669
+ # fmt: on
670
+ assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
671
+
672
+ num_blocks_per_stage = {
673
+ 18: [2, 2, 2, 2],
674
+ 34: [3, 4, 6, 3],
675
+ 50: [3, 4, 6, 3],
676
+ 101: [3, 4, 23, 3],
677
+ 152: [3, 8, 36, 3],
678
+ }[depth]
679
+
680
+ if depth in [18, 34]:
681
+ assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
682
+ assert not any(
683
+ deform_on_per_stage
684
+ ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
685
+ assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
686
+ assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"
687
+
688
+ stages = []
689
+
690
+ for idx, stage_idx in enumerate(range(2, 6)):
691
+ # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
692
+ dilation = res5_dilation if stage_idx == 5 else 1
693
+ first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
694
+ stage_kargs = {
695
+ "num_blocks": num_blocks_per_stage[idx],
696
+ "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
697
+ "in_channels": in_channels,
698
+ "out_channels": out_channels,
699
+ "norm": norm,
700
+ }
701
+ # Use BasicBlock for R18 and R34.
702
+ if depth in [18, 34]:
703
+ stage_kargs["block_class"] = BasicBlock
704
+ else:
705
+ stage_kargs["bottleneck_channels"] = bottleneck_channels
706
+ stage_kargs["stride_in_1x1"] = stride_in_1x1
707
+ stage_kargs["dilation"] = dilation
708
+ stage_kargs["num_groups"] = num_groups
709
+ if deform_on_per_stage[idx]:
710
+ stage_kargs["block_class"] = DeformBottleneckBlock
711
+ stage_kargs["deform_modulated"] = deform_modulated
712
+ stage_kargs["deform_num_groups"] = deform_num_groups
713
+ else:
714
+ stage_kargs["block_class"] = BottleneckBlock
715
+ blocks = ResNet.make_stage(**stage_kargs)
716
+ in_channels = out_channels
717
+ out_channels *= 2
718
+ bottleneck_channels *= 2
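+ # channel widths double from one residual stage to the next (res2 -> res5)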
719
+ stages.append(blocks)
720
+ backbone = ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
721
+
722
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
723
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
724
+ with PathManager.open(filename, "rb") as f:
725
+ ckpt = pickle.load(f, encoding="latin1")['model']
726
+ _convert_ndarray_to_tensor(ckpt)
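+ # the checkpoint also stores a fully-connected classification head; drop it since the detection backbone does not use it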
727
+ ckpt.pop('stem.fc.weight')
728
+ ckpt.pop('stem.fc.bias')
729
+ backbone.load_state_dict(ckpt)
730
+
731
+ return backbone
GLEE/glee/backbone/swin.py ADDED
@@ -0,0 +1,783 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15
+
16
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
17
+ from .registry import register_backbone
18
+
19
+
20
+ class Mlp(nn.Module):
21
+ """Multilayer perceptron."""
22
+
23
+ def __init__(
24
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
25
+ ):
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x):
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
41
+
42
+
43
+ def window_partition(x, window_size):
44
+ """
45
+ Args:
46
+ x: (B, H, W, C)
47
+ window_size (int): window size
48
+ Returns:
49
+ windows: (num_windows*B, window_size, window_size, C)
50
+ """
51
+ B, H, W, C = x.shape
52
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
53
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
54
+ return windows
55
+
56
+
57
+ def window_reverse(windows, window_size, H, W):
58
+ """
59
+ Args:
60
+ windows: (num_windows*B, window_size, window_size, C)
61
+ window_size (int): Window size
62
+ H (int): Height of image
63
+ W (int): Width of image
64
+ Returns:
65
+ x: (B, H, W, C)
66
+ """
67
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
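+ # batch size recovered from the window count: each image contributes (H // window_size) * (W // window_size) windows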
68
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
69
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
70
+ return x
71
+
72
+
73
+ class WindowAttention(nn.Module):
74
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
75
+ It supports both shifted and non-shifted windows.
76
+ Args:
77
+ dim (int): Number of input channels.
78
+ window_size (tuple[int]): The height and width of the window.
79
+ num_heads (int): Number of attention heads.
80
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
81
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
82
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
83
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ window_size,
90
+ num_heads,
91
+ qkv_bias=True,
92
+ qk_scale=None,
93
+ attn_drop=0.0,
94
+ proj_drop=0.0,
95
+ ):
96
+
97
+ super().__init__()
98
+ self.dim = dim
99
+ self.window_size = window_size # Wh, Ww
100
+ self.num_heads = num_heads
101
+ head_dim = dim // num_heads
102
+ self.scale = qk_scale or head_dim ** -0.5
103
+
104
+ # define a parameter table of relative position bias
105
+ self.relative_position_bias_table = nn.Parameter(
106
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
107
+ ) # 2*Wh-1 * 2*Ww-1, nH
108
+
109
+ # get pair-wise relative position index for each token inside the window
110
+ coords_h = torch.arange(self.window_size[0])
111
+ coords_w = torch.arange(self.window_size[1])
112
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
113
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
114
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
115
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
116
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
117
+ relative_coords[:, :, 1] += self.window_size[1] - 1
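+ # scale the row offsets by (2*Ww - 1) so that summing row and column offsets below yields a unique flat index for every relative position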
118
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
119
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
120
+ self.register_buffer("relative_position_index", relative_position_index)
121
+
122
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
123
+ self.attn_drop = nn.Dropout(attn_drop)
124
+ self.proj = nn.Linear(dim, dim)
125
+ self.proj_drop = nn.Dropout(proj_drop)
126
+
127
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
128
+ self.softmax = nn.Softmax(dim=-1)
129
+
130
+ def forward(self, x, mask=None):
131
+ """Forward function.
132
+ Args:
133
+ x: input features with shape of (num_windows*B, N, C)
134
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
135
+ """
136
+ B_, N, C = x.shape
137
+ qkv = (
138
+ self.qkv(x)
139
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
140
+ .permute(2, 0, 3, 1, 4)
141
+ )
142
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
143
+
144
+ q = q * self.scale
145
+ attn = q @ k.transpose(-2, -1)
146
+
147
+ relative_position_bias = self.relative_position_bias_table[
148
+ self.relative_position_index.view(-1)
149
+ ].view(
150
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
151
+ ) # Wh*Ww,Wh*Ww,nH
152
+ relative_position_bias = relative_position_bias.permute(
153
+ 2, 0, 1
154
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
155
+ attn = attn + relative_position_bias.unsqueeze(0)
156
+
157
+ if mask is not None:
158
+ nW = mask.shape[0]
159
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
160
+ attn = attn.view(-1, self.num_heads, N, N)
161
+ attn = self.softmax(attn)
162
+ else:
163
+ attn = self.softmax(attn)
164
+
165
+ attn = self.attn_drop(attn)
166
+
167
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
168
+ x = self.proj(x)
169
+ x = self.proj_drop(x)
170
+ return x
171
+
172
+
173
+ class SwinTransformerBlock(nn.Module):
174
+ """Swin Transformer Block.
175
+ Args:
176
+ dim (int): Number of input channels.
177
+ num_heads (int): Number of attention heads.
178
+ window_size (int): Window size.
179
+ shift_size (int): Shift size for SW-MSA.
180
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
181
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
182
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
183
+ drop (float, optional): Dropout rate. Default: 0.0
184
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
185
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
186
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
187
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
188
+ """
189
+
190
+ def __init__(
191
+ self,
192
+ dim,
193
+ num_heads,
194
+ window_size=7,
195
+ shift_size=0,
196
+ mlp_ratio=4.0,
197
+ qkv_bias=True,
198
+ qk_scale=None,
199
+ drop=0.0,
200
+ attn_drop=0.0,
201
+ drop_path=0.0,
202
+ act_layer=nn.GELU,
203
+ norm_layer=nn.LayerNorm,
204
+ ):
205
+ super().__init__()
206
+ self.dim = dim
207
+ self.num_heads = num_heads
208
+ self.window_size = window_size
209
+ self.shift_size = shift_size
210
+ self.mlp_ratio = mlp_ratio
211
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
212
+
213
+ self.norm1 = norm_layer(dim)
214
+ self.attn = WindowAttention(
215
+ dim,
216
+ window_size=to_2tuple(self.window_size),
217
+ num_heads=num_heads,
218
+ qkv_bias=qkv_bias,
219
+ qk_scale=qk_scale,
220
+ attn_drop=attn_drop,
221
+ proj_drop=drop,
222
+ )
223
+
224
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
225
+ self.norm2 = norm_layer(dim)
226
+ mlp_hidden_dim = int(dim * mlp_ratio)
227
+ self.mlp = Mlp(
228
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
229
+ )
230
+
231
+ self.H = None
232
+ self.W = None
233
+
234
+ def forward(self, x, mask_matrix):
235
+ """Forward function.
236
+ Args:
237
+ x: Input feature, tensor size (B, H*W, C).
238
+ H, W: Spatial resolution of the input feature.
239
+ mask_matrix: Attention mask for cyclic shift.
240
+ """
241
+ B, L, C = x.shape
242
+ H, W = self.H, self.W
243
+ assert L == H * W, "input feature has wrong size"
244
+
245
+ shortcut = x
246
+ x = self.norm1(x)
247
+ x = x.view(B, H, W, C)
248
+
249
+ # pad feature maps to multiples of window size
250
+ pad_l = pad_t = 0
251
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
252
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
253
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
254
+ _, Hp, Wp, _ = x.shape
255
+
256
+ # cyclic shift
257
+ if self.shift_size > 0:
258
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
259
+ attn_mask = mask_matrix
260
+ else:
261
+ shifted_x = x
262
+ attn_mask = None
263
+
264
+ # partition windows
265
+ x_windows = window_partition(
266
+ shifted_x, self.window_size
267
+ ) # nW*B, window_size, window_size, C
268
+ x_windows = x_windows.view(
269
+ -1, self.window_size * self.window_size, C
270
+ ) # nW*B, window_size*window_size, C
271
+
272
+ # W-MSA/SW-MSA
273
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
274
+
275
+ # merge windows
276
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
277
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
278
+
279
+ # reverse cyclic shift
280
+ if self.shift_size > 0:
281
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
282
+ else:
283
+ x = shifted_x
284
+
285
+ if pad_r > 0 or pad_b > 0:
286
+ x = x[:, :H, :W, :].contiguous()
287
+
288
+ x = x.view(B, H * W, C)
289
+
290
+ # FFN
291
+ x = shortcut + self.drop_path(x)
292
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
293
+
294
+ return x
295
+
296
+
297
+ class PatchMerging(nn.Module):
298
+ """Patch Merging Layer
299
+ Args:
300
+ dim (int): Number of input channels.
301
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
302
+ """
303
+
304
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
305
+ super().__init__()
306
+ self.dim = dim
307
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
308
+ self.norm = norm_layer(4 * dim)
309
+
310
+ def forward(self, x, H, W):
311
+ """Forward function.
312
+ Args:
313
+ x: Input feature, tensor size (B, H*W, C).
314
+ H, W: Spatial resolution of the input feature.
315
+ """
316
+ B, L, C = x.shape
317
+ assert L == H * W, "input feature has wrong size"
318
+
319
+ x = x.view(B, H, W, C)
320
+
321
+ # padding
322
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
323
+ if pad_input:
324
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
325
+
326
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
327
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
328
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
329
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
330
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
331
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
332
+
333
+ x = self.norm(x)
334
+ x = self.reduction(x)
335
+
336
+ return x
337
+
338
+
339
+ class BasicLayer(nn.Module):
340
+ """A basic Swin Transformer layer for one stage.
341
+ Args:
342
+ dim (int): Number of feature channels
343
+ depth (int): Depth (number of blocks) of this stage.
344
+ num_heads (int): Number of attention heads.
345
+ window_size (int): Local window size. Default: 7.
346
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
347
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
348
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
349
+ drop (float, optional): Dropout rate. Default: 0.0
350
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
351
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
352
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
353
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
354
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
355
+ """
356
+
357
+ def __init__(
358
+ self,
359
+ dim,
360
+ depth,
361
+ num_heads,
362
+ window_size=7,
363
+ mlp_ratio=4.0,
364
+ qkv_bias=True,
365
+ qk_scale=None,
366
+ drop=0.0,
367
+ attn_drop=0.0,
368
+ drop_path=0.0,
369
+ norm_layer=nn.LayerNorm,
370
+ downsample=None,
371
+ use_checkpoint=False,
372
+ ):
373
+ super().__init__()
374
+ self.window_size = window_size
375
+ self.shift_size = window_size // 2
376
+ self.depth = depth
377
+ self.use_checkpoint = use_checkpoint
378
+
379
+ # build blocks
380
+ self.blocks = nn.ModuleList(
381
+ [
382
+ SwinTransformerBlock(
383
+ dim=dim,
384
+ num_heads=num_heads,
385
+ window_size=window_size,
386
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
387
+ mlp_ratio=mlp_ratio,
388
+ qkv_bias=qkv_bias,
389
+ qk_scale=qk_scale,
390
+ drop=drop,
391
+ attn_drop=attn_drop,
392
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
393
+ norm_layer=norm_layer,
394
+ )
395
+ for i in range(depth)
396
+ ]
397
+ )
398
+
399
+ # patch merging layer
400
+ if downsample is not None:
401
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
402
+ else:
403
+ self.downsample = None
404
+
405
+ def forward(self, x, H, W):
406
+ """Forward function.
407
+ Args:
408
+ x: Input feature, tensor size (B, H*W, C).
409
+ H, W: Spatial resolution of the input feature.
410
+ """
411
+
412
+ # calculate attention mask for SW-MSA
413
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
414
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
415
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
416
+ h_slices = (
417
+ slice(0, -self.window_size),
418
+ slice(-self.window_size, -self.shift_size),
419
+ slice(-self.shift_size, None),
420
+ )
421
+ w_slices = (
422
+ slice(0, -self.window_size),
423
+ slice(-self.window_size, -self.shift_size),
424
+ slice(-self.shift_size, None),
425
+ )
426
+ cnt = 0
427
+ for h in h_slices:
428
+ for w in w_slices:
429
+ img_mask[:, h, w, :] = cnt
430
+ cnt += 1
431
+
432
+ mask_windows = window_partition(
433
+ img_mask, self.window_size
434
+ ) # nW, window_size, window_size, 1
435
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
436
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
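+ # token pairs that come from different shifted regions get a -100 bias, so softmax effectively blocks attention across the cyclic-shift boundary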
437
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
438
+ attn_mask == 0, float(0.0)
439
+ )
440
+
441
+ for blk in self.blocks:
442
+ blk.H, blk.W = H, W
443
+ if self.use_checkpoint:
444
+ x = checkpoint.checkpoint(blk, x, attn_mask)
445
+ else:
446
+ x = blk(x, attn_mask)
447
+ if self.downsample is not None:
448
+ x_down = self.downsample(x, H, W)
449
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
450
+ return x, H, W, x_down, Wh, Ww
451
+ else:
452
+ return x, H, W, x, H, W
453
+
454
+
455
+ class PatchEmbed(nn.Module):
456
+ """Image to Patch Embedding
457
+ Args:
458
+ patch_size (int): Patch token size. Default: 4.
459
+ in_chans (int): Number of input image channels. Default: 3.
460
+ embed_dim (int): Number of linear projection output channels. Default: 96.
461
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
462
+ """
463
+
464
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
465
+ super().__init__()
466
+ patch_size = to_2tuple(patch_size)
467
+ self.patch_size = patch_size
468
+
469
+ self.in_chans = in_chans
470
+ self.embed_dim = embed_dim
471
+
472
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
473
+ if norm_layer is not None:
474
+ self.norm = norm_layer(embed_dim)
475
+ else:
476
+ self.norm = None
477
+
478
+ def forward(self, x):
479
+ """Forward function."""
480
+ # padding
481
+ _, _, H, W = x.size()
482
+ if W % self.patch_size[1] != 0:
483
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
484
+ if H % self.patch_size[0] != 0:
485
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
486
+
487
+ x = self.proj(x) # B C Wh Ww
488
+ if self.norm is not None:
489
+ Wh, Ww = x.size(2), x.size(3)
490
+ x = x.flatten(2).transpose(1, 2)
491
+ x = self.norm(x)
492
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
493
+
494
+ return x
495
+
496
+
497
+ class SwinTransformer(nn.Module):
498
+ """Swin Transformer backbone.
499
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
500
+ https://arxiv.org/pdf/2103.14030
501
+ Args:
502
+ pretrain_img_size (int): Input image size for training the pretrained model,
503
+ used in absolute position embedding. Default 224.
504
+ patch_size (int | tuple(int)): Patch size. Default: 4.
505
+ in_chans (int): Number of input image channels. Default: 3.
506
+ embed_dim (int): Number of linear projection output channels. Default: 96.
507
+ depths (tuple[int]): Depths of each Swin Transformer stage.
508
+ num_heads (tuple[int]): Number of attention head of each stage.
509
+ window_size (int): Window size. Default: 7.
510
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
511
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
512
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
513
+ drop_rate (float): Dropout rate.
514
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
515
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
516
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
517
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
518
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
519
+ out_indices (Sequence[int]): Output from which stages.
520
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
521
+ -1 means not freezing any parameters.
522
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
523
+ """
524
+
525
+ def __init__(
526
+ self,
527
+ pretrain_img_size=224,
528
+ patch_size=4,
529
+ in_chans=3,
530
+ embed_dim=96,
531
+ depths=[2, 2, 6, 2],
532
+ num_heads=[3, 6, 12, 24],
533
+ window_size=7,
534
+ mlp_ratio=4.0,
535
+ qkv_bias=True,
536
+ qk_scale=None,
537
+ drop_rate=0.0,
538
+ attn_drop_rate=0.0,
539
+ drop_path_rate=0.2,
540
+ norm_layer=nn.LayerNorm,
541
+ ape=False,
542
+ patch_norm=True,
543
+ out_indices=(0, 1, 2, 3),
544
+ frozen_stages=-1,
545
+ use_checkpoint=False,
546
+ ):
547
+ super().__init__()
548
+
549
+ self.pretrain_img_size = pretrain_img_size
550
+ self.num_layers = len(depths)
551
+ self.embed_dim = embed_dim
552
+ self.ape = ape
553
+ self.patch_norm = patch_norm
554
+ self.out_indices = out_indices
555
+ self.frozen_stages = frozen_stages
556
+
557
+ # split image into non-overlapping patches
558
+ self.patch_embed = PatchEmbed(
559
+ patch_size=patch_size,
560
+ in_chans=in_chans,
561
+ embed_dim=embed_dim,
562
+ norm_layer=norm_layer if self.patch_norm else None,
563
+ )
564
+
565
+ # absolute position embedding
566
+ if self.ape:
567
+ pretrain_img_size = to_2tuple(pretrain_img_size)
568
+ patch_size = to_2tuple(patch_size)
569
+ patches_resolution = [
570
+ pretrain_img_size[0] // patch_size[0],
571
+ pretrain_img_size[1] // patch_size[1],
572
+ ]
573
+
574
+ self.absolute_pos_embed = nn.Parameter(
575
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
576
+ )
577
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
578
+
579
+ self.pos_drop = nn.Dropout(p=drop_rate)
580
+
581
+ # stochastic depth
582
+ dpr = [
583
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
584
+ ] # stochastic depth decay rule
585
+
586
+ # build layers
587
+ self.layers = nn.ModuleList()
588
+ for i_layer in range(self.num_layers):
589
+ layer = BasicLayer(
590
+ dim=int(embed_dim * 2 ** i_layer),
591
+ depth=depths[i_layer],
592
+ num_heads=num_heads[i_layer],
593
+ window_size=window_size,
594
+ mlp_ratio=mlp_ratio,
595
+ qkv_bias=qkv_bias,
596
+ qk_scale=qk_scale,
597
+ drop=drop_rate,
598
+ attn_drop=attn_drop_rate,
599
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
600
+ norm_layer=norm_layer,
601
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
602
+ use_checkpoint=use_checkpoint,
603
+ )
604
+ self.layers.append(layer)
605
+
606
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
607
+ self.num_features = num_features
608
+
609
+ # add a norm layer for each output
610
+ for i_layer in out_indices:
611
+ layer = norm_layer(num_features[i_layer])
612
+ layer_name = f"norm{i_layer}"
613
+ self.add_module(layer_name, layer)
614
+
615
+ self._freeze_stages()
616
+
617
+ def _freeze_stages(self):
618
+ if self.frozen_stages >= 0:
619
+ self.patch_embed.eval()
620
+ for param in self.patch_embed.parameters():
621
+ param.requires_grad = False
622
+
623
+ if self.frozen_stages >= 1 and self.ape:
624
+ self.absolute_pos_embed.requires_grad = False
625
+
626
+ if self.frozen_stages >= 2:
627
+ self.pos_drop.eval()
628
+ for i in range(0, self.frozen_stages - 1):
629
+ m = self.layers[i]
630
+ m.eval()
631
+ for param in m.parameters():
632
+ param.requires_grad = False
633
+
634
+ def init_weights(self, pretrained=None):
635
+ """Initialize the weights in backbone.
636
+ Args:
637
+ pretrained (str, optional): Path to pre-trained weights.
638
+ Defaults to None.
639
+ """
640
+
641
+ def _init_weights(m):
642
+ if isinstance(m, nn.Linear):
643
+ trunc_normal_(m.weight, std=0.02)
644
+ if isinstance(m, nn.Linear) and m.bias is not None:
645
+ nn.init.constant_(m.bias, 0)
646
+ elif isinstance(m, nn.LayerNorm):
647
+ nn.init.constant_(m.bias, 0)
648
+ nn.init.constant_(m.weight, 1.0)
649
+
650
+ if isinstance(pretrained, str):
651
+ self.apply(_init_weights)
652
+ checkpoint = torch.load(pretrained, map_location='cpu')
653
+ print(f'\nload pretrain weight from {pretrained} \n')
654
+ self.load_state_dict(checkpoint['model'], strict=False)
655
+ elif pretrained is None:
656
+ self.apply(_init_weights)
657
+ else:
658
+ raise TypeError('pretrained must be a str or None')
659
+
660
+ def forward(self, x):
661
+ """Forward function."""
662
+ x = self.patch_embed(x)
663
+
664
+ Wh, Ww = x.size(2), x.size(3)
665
+ if self.ape:
666
+ # interpolate the position embedding to the corresponding size
667
+ absolute_pos_embed = F.interpolate(
668
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
669
+ )
670
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
671
+ else:
672
+ x = x.flatten(2).transpose(1, 2)
673
+ x = self.pos_drop(x)
674
+
675
+ outs = {}
676
+ for i in range(self.num_layers):
677
+ layer = self.layers[i]
678
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
679
+
680
+ if i in self.out_indices:
681
+ norm_layer = getattr(self, f"norm{i}")
682
+ x_out = norm_layer(x_out)
683
+
684
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
685
+ outs["res{}".format(i + 2)] = out
686
+
687
+ return outs
688
+
689
+ def train(self, mode=True):
690
+ """Convert the model into training mode while keeping layers frozen."""
691
+ super(SwinTransformer, self).train(mode)
692
+ self._freeze_stages()
693
+
694
+
695
+ @BACKBONE_REGISTRY.register()
696
+ class D2SwinTransformer(SwinTransformer, Backbone):
697
+ def __init__(self, cfg, input_shape):
698
+
699
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
700
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
701
+ in_chans = 3
702
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
703
+ depths = cfg.MODEL.SWIN.DEPTHS
704
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
705
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
706
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
707
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
708
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
709
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
710
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
711
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
712
+ norm_layer = nn.LayerNorm
713
+ ape = cfg.MODEL.SWIN.APE
714
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
715
+ use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT
716
+ pretrained_weight = cfg.MODEL.SWIN.PRETRAINED_WEIGHT
717
+
718
+
719
+ super().__init__(
720
+ pretrain_img_size,
721
+ patch_size,
722
+ in_chans,
723
+ embed_dim,
724
+ depths,
725
+ num_heads,
726
+ window_size,
727
+ mlp_ratio,
728
+ qkv_bias,
729
+ qk_scale,
730
+ drop_rate,
731
+ attn_drop_rate,
732
+ drop_path_rate,
733
+ norm_layer,
734
+ ape,
735
+ patch_norm,
736
+ use_checkpoint=use_checkpoint,
737
+ )
738
+ self.init_weights(pretrained_weight)
739
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
740
+
741
+ self._out_feature_strides = {
742
+ "res2": 4,
743
+ "res3": 8,
744
+ "res4": 16,
745
+ "res5": 32,
746
+ }
747
+ self._out_feature_channels = {
748
+ "res2": self.num_features[0],
749
+ "res3": self.num_features[1],
750
+ "res4": self.num_features[2],
751
+ "res5": self.num_features[3],
752
+ }
753
+
754
+ def forward(self, x):
755
+ """
756
+ Args:
757
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
758
+ Returns:
759
+ dict[str->Tensor]: names and the corresponding features
760
+ """
761
+ assert (
762
+ x.dim() == 4
763
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
764
+ outputs = {}
765
+ y = super().forward(x)
766
+ for k in y.keys():
767
+ if k in self._out_features:
768
+ outputs[k] = y[k]
769
+ return outputs
770
+
771
+ def output_shape(self):
772
+ return {
773
+ name: ShapeSpec(
774
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
775
+ )
776
+ for name in self._out_features
777
+ }
778
+
779
+ @property
780
+ def size_divisibility(self):
781
+ return 32
782
+
783
+
GLEE/glee/backbone/vit.py ADDED
@@ -0,0 +1,472 @@
1
+ import logging
2
+ import math
3
+ import fvcore.nn.weight_init as weight_init
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from detectron2.layers import CNNBlockBase, Conv2d, get_norm
8
+ from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
9
+ import torch.nn.functional as F
10
+
11
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
12
+ from .utils import (
13
+ PatchEmbed,
14
+ add_decomposed_rel_pos,
15
+ get_abs_pos,
16
+ window_partition,
17
+ window_unpartition,
18
+ )
19
+ from functools import partial
20
+ import torch.utils.checkpoint as checkpoint
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ __all__ = ["ViT", "D2ViT"]
26
+
27
+
28
+ class Attention(nn.Module):
29
+ """Multi-head Attention block with relative position embeddings."""
30
+
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ num_heads=8,
35
+ qkv_bias=True,
36
+ use_rel_pos=False,
37
+ rel_pos_zero_init=True,
38
+ input_size=None,
39
+ ):
40
+ """
41
+ Args:
42
+ dim (int): Number of input channels.
43
+ num_heads (int): Number of attention heads.
44
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
45
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
46
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
47
+ input_size (int or None): Input resolution for calculating the relative positional
48
+ parameter size.
49
+ """
50
+ super().__init__()
51
+ self.num_heads = num_heads
52
+ head_dim = dim // num_heads
53
+ self.scale = head_dim**-0.5
54
+
55
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
56
+ self.proj = nn.Linear(dim, dim)
57
+
58
+ self.use_rel_pos = use_rel_pos
59
+ if self.use_rel_pos:
60
+ # initialize relative positional embeddings
61
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
62
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
63
+
64
+ if not rel_pos_zero_init:
65
+ nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
66
+ nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
67
+
68
+ def forward(self, x):
69
+ B, H, W, _ = x.shape
70
+ # qkv with shape (3, B, nHead, H * W, C)
71
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
72
+ # q, k, v with shape (B * nHead, H * W, C)
73
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
74
+
75
+ if self.use_rel_pos:
76
+ # the decomposed relative position bias must be added to the attention logits
+ # before softmax, so the explicit attention path is used here
+ attn = (q * self.scale) @ k.transpose(-2, -1)
77
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
78
+ attn = attn.softmax(dim=-1)
79
+ x = attn @ v
80
+ else:
81
+ # no extra bias term, so the fused scaled-dot-product attention kernel can be used
+ with torch.backends.cuda.sdp_kernel(
82
+ enable_flash=True, enable_math=False, enable_mem_efficient=True
83
+ ):
84
+ x = F.scaled_dot_product_attention(q, k, v)
85
+ x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
86
+ x = self.proj(x)
87
+
88
+ return x
89
+
90
+
91
+ class ResBottleneckBlock(CNNBlockBase):
92
+ """
93
+ The standard bottleneck residual block without the last activation layer.
94
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ in_channels,
100
+ out_channels,
101
+ bottleneck_channels,
102
+ norm="LN",
103
+ act_layer=nn.GELU,
104
+ ):
105
+ """
106
+ Args:
107
+ in_channels (int): Number of input channels.
108
+ out_channels (int): Number of output channels.
109
+ bottleneck_channels (int): number of output channels for the 3x3
110
+ "bottleneck" conv layers.
111
+ norm (str or callable): normalization for all conv layers.
112
+ See :func:`layers.get_norm` for supported format.
113
+ act_layer (callable): activation for all conv layers.
114
+ """
115
+ super().__init__(in_channels, out_channels, 1)
116
+
117
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
118
+ self.norm1 = get_norm(norm, bottleneck_channels)
119
+ self.act1 = act_layer()
120
+
121
+ self.conv2 = Conv2d(
122
+ bottleneck_channels,
123
+ bottleneck_channels,
124
+ 3,
125
+ padding=1,
126
+ bias=False,
127
+ )
128
+ self.norm2 = get_norm(norm, bottleneck_channels)
129
+ self.act2 = act_layer()
130
+
131
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
132
+ self.norm3 = get_norm(norm, out_channels)
133
+
134
+ for layer in [self.conv1, self.conv2, self.conv3]:
135
+ weight_init.c2_msra_fill(layer)
136
+ for layer in [self.norm1, self.norm2]:
137
+ layer.weight.data.fill_(1.0)
138
+ layer.bias.data.zero_()
139
+ # zero init last norm layer.
140
+ self.norm3.weight.data.zero_()
141
+ self.norm3.bias.data.zero_()
142
+
143
+ def forward(self, x):
144
+ out = x
145
+ for layer in self.children():
146
+ out = layer(out)
147
+
148
+ out = x + out
149
+ return out
150
+
151
+
152
+ class Block(nn.Module):
153
+ """Transformer blocks with support of window attention and residual propagation blocks"""
154
+
155
+ def __init__(
156
+ self,
157
+ dim,
158
+ num_heads,
159
+ mlp_ratio=4.0,
160
+ qkv_bias=True,
161
+ drop_path=0.0,
162
+ norm_layer=nn.LayerNorm,
163
+ act_layer=nn.GELU,
164
+ use_rel_pos=False,
165
+ rel_pos_zero_init=True,
166
+ window_size=0,
167
+ use_residual_block=False,
168
+ input_size=None,
169
+ ):
170
+ """
171
+ Args:
172
+ dim (int): Number of input channels.
173
+ num_heads (int): Number of attention heads in each ViT block.
174
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
175
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
176
+ drop_path (float): Stochastic depth rate.
177
+ norm_layer (nn.Module): Normalization layer.
178
+ act_layer (nn.Module): Activation layer.
179
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
180
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
181
+ window_size (int): Window size for window attention blocks. If it equals 0, window
182
+ attention is not used.
183
+ use_residual_block (bool): If True, use a residual block after the MLP block.
184
+ input_size (int or None): Input resolution for calculating the relative positional
185
+ parameter size.
186
+ """
187
+ super().__init__()
188
+ self.norm1 = norm_layer(dim)
189
+ self.attn = Attention(
190
+ dim,
191
+ num_heads=num_heads,
192
+ qkv_bias=qkv_bias,
193
+ use_rel_pos=use_rel_pos,
194
+ rel_pos_zero_init=rel_pos_zero_init,
195
+ input_size=input_size if window_size == 0 else (window_size, window_size),
196
+ )
197
+
198
+ from timm.models.layers import DropPath, Mlp
199
+
200
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
201
+ self.norm2 = norm_layer(dim)
202
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
203
+
204
+ self.window_size = window_size
205
+
206
+ self.use_residual_block = use_residual_block
207
+ if use_residual_block:
208
+ # Use a residual block with bottleneck channel as dim // 2
209
+ self.residual = ResBottleneckBlock(
210
+ in_channels=dim,
211
+ out_channels=dim,
212
+ bottleneck_channels=dim // 2,
213
+ norm="LN",
214
+ act_layer=act_layer,
215
+ )
216
+
217
+ def forward(self, x):
218
+ shortcut = x
219
+ x = self.norm1(x)
220
+ # Window partition
221
+ if self.window_size > 0:
222
+ H, W = x.shape[1], x.shape[2]
223
+ x, pad_hw = window_partition(x, self.window_size)
224
+ x = self.attn(x)
225
+ # Reverse window partition
226
+ if self.window_size > 0:
227
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
228
+
229
+ x = shortcut + self.drop_path(x)
230
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
231
+
232
+ if self.use_residual_block:
233
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
234
+
235
+ return x
236
+
237
+
238
+ class ViT(Backbone):
239
+ """
240
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
241
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
242
+ https://arxiv.org/abs/2203.16527
243
+ """
244
+
245
+ def __init__(
246
+ self,
247
+ img_size=1024,
248
+ patch_size=16,
249
+ in_chans=3,
250
+ embed_dim=768,
251
+ depth=12,
252
+ num_heads=12,
253
+ mlp_ratio=4.0,
254
+ qkv_bias=True,
255
+ drop_path_rate=0.0,
256
+ norm_layer=nn.LayerNorm,
257
+ act_layer=nn.GELU,
258
+ use_abs_pos=True,
259
+ use_rel_pos=False,
260
+ rel_pos_zero_init=True,
261
+ window_size=0,
262
+ window_block_indexes=(),
263
+ residual_block_indexes=(),
264
+ use_act_checkpoint=False,
265
+ pretrain_img_size=224,
266
+ pretrain_use_cls_token=True,
267
+ out_feature="last_feat",
268
+ ):
269
+ """
270
+ Args:
271
+ img_size (int): Input image size.
272
+ patch_size (int): Patch size.
273
+ in_chans (int): Number of input image channels.
274
+ embed_dim (int): Patch embedding dimension.
275
+ depth (int): Depth of ViT.
276
+ num_heads (int): Number of attention heads in each ViT block.
277
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
278
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
279
+ drop_path_rate (float): Stochastic depth rate.
280
+ norm_layer (nn.Module): Normalization layer.
281
+ act_layer (nn.Module): Activation layer.
282
+ use_abs_pos (bool): If True, use absolute positional embeddings.
283
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
284
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
285
+ window_size (int): Window size for window attention blocks.
286
+ window_block_indexes (list): Indexes for blocks using window attention.
287
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
288
+ use_act_checkpoint (bool): If True, use activation checkpointing.
289
+ pretrain_img_size (int): input image size for pretraining models.
290
+ pretrain_use_cls_token (bool): If True, pretraining models use class token.
291
+ out_feature (str): name of the feature from the last block.
292
+ """
293
+ super().__init__()
294
+ self.pretrain_use_cls_token = pretrain_use_cls_token
295
+
296
+ self.patch_embed = PatchEmbed(
297
+ kernel_size=(patch_size, patch_size),
298
+ stride=(patch_size, patch_size),
299
+ in_chans=in_chans,
300
+ embed_dim=embed_dim,
301
+ )
302
+
303
+ if use_abs_pos:
304
+ # Initialize absolute positional embedding with pretrain image size.
305
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
306
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
307
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
308
+ else:
309
+ self.pos_embed = None
310
+
311
+ # stochastic depth decay rule
312
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
313
+
314
+ self.blocks = nn.ModuleList()
315
+ for i in range(depth):
316
+ block = Block(
317
+ dim=embed_dim,
318
+ num_heads=num_heads,
319
+ mlp_ratio=mlp_ratio,
320
+ qkv_bias=qkv_bias,
321
+ drop_path=dpr[i],
322
+ norm_layer=norm_layer,
323
+ act_layer=act_layer,
324
+ use_rel_pos=use_rel_pos,
325
+ rel_pos_zero_init=rel_pos_zero_init,
326
+ window_size=window_size if i in window_block_indexes else 0,
327
+ use_residual_block=i in residual_block_indexes,
328
+ input_size=(img_size // patch_size, img_size // patch_size),
329
+ )
330
+ if use_act_checkpoint:
331
+ # TODO: use torch.utils.checkpoint
332
+ from fairscale.nn.checkpoint import checkpoint_wrapper
333
+
334
+ block = checkpoint_wrapper(block)
335
+ self.blocks.append(block)
336
+
337
+ self._out_feature_channels = {out_feature: embed_dim}
338
+ self._out_feature_strides = {out_feature: patch_size}
339
+ self._out_features = [out_feature]
340
+
341
+ if self.pos_embed is not None:
342
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
343
+
344
+ # In our method, we don't use the backbone feature with stride 4
345
+ self.fpn1 = nn.Sequential(
346
+ nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
347
+ )
348
+ self.fpn2 = nn.Identity()
349
+ self.fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
350
+
351
+ self.apply(self._init_weights)
352
+
353
+ def _init_weights(self, m):
354
+ if isinstance(m, nn.Linear):
355
+ nn.init.trunc_normal_(m.weight, std=0.02)
356
+ if isinstance(m, nn.Linear) and m.bias is not None:
357
+ nn.init.constant_(m.bias, 0)
358
+ elif isinstance(m, nn.LayerNorm):
359
+ nn.init.constant_(m.bias, 0)
360
+ nn.init.constant_(m.weight, 1.0)
361
+
362
+ def forward(self, x):
363
+ x = self.patch_embed(x)
364
+ if self.pos_embed is not None:
365
+ x = x + get_abs_pos(
366
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
367
+ )
368
+
369
+ for blk in self.blocks:
370
+ x = blk(x)
371
+ xp = x.permute(0, 3, 1, 2) # (b, h, w, c) --> (b, c, h, w)
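+ # build a small feature pyramid from the single stride-16 ViT map: fpn1 upsamples to stride 8 (res3), fpn2 keeps stride 16 (res4), fpn3 max-pools to stride 32 (res5)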
372
+
373
+ features = []
374
+ ops = [self.fpn1, self.fpn2, self.fpn3]
375
+ for i in range(len(ops)):
376
+ features.append(ops[i](xp))
377
+ rets = {"res{}".format(u + 3): v for (u,v) in enumerate(features)}
378
+
379
+ return rets
380
+
381
+
382
+
383
+ @BACKBONE_REGISTRY.register()
384
+ class D2ViT(ViT, Backbone):
385
+ def __init__(self, cfg, input_shape):
386
+ use_checkpoint = cfg.MODEL.VIT.USE_CHECKPOINT
387
+ if cfg.MODEL.VIT.NAME == "ViT-Base":
388
+ embed_dim=768
389
+ depth=12
390
+ drop_path_rate=0.1
391
+ num_heads=12
392
+ elif cfg.MODEL.VIT.NAME == "ViT-Large":
393
+ embed_dim=1024
394
+ depth=24
395
+ drop_path_rate=0.4
396
+ num_heads=16
397
+ elif cfg.MODEL.VIT.NAME == "ViT-huge":
398
+ embed_dim=1280
399
+ depth=32
400
+ drop_path_rate=0.5
401
+ num_heads=16
402
+ else:
403
+ raise ValueError("Unsupported ViT name")
404
+ super().__init__(
405
+ img_size=1024,
406
+ patch_size=16,
407
+ in_chans=input_shape.channels,
408
+ embed_dim=embed_dim,
409
+ depth=depth,
410
+ num_heads=num_heads,
411
+ drop_path_rate=drop_path_rate,
412
+ window_size=14,
413
+ mlp_ratio=4,
414
+ qkv_bias=True,
415
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
416
+ window_block_indexes=[
417
+ # 2, 5, 8 11 for global attention
418
+ 0,
419
+ 1,
420
+ 3,
421
+ 4,
422
+ 6,
423
+ 7,
424
+ 9,
425
+ 10,
426
+ ],
427
+ residual_block_indexes=[],
428
+ use_rel_pos=True,
429
+ out_feature="last_feat",
430
+ use_act_checkpoint=use_checkpoint)
431
+
432
+ self._out_features = cfg.MODEL.VIT.OUT_FEATURES
433
+
434
+ self._out_feature_strides = {
435
+ "res3": 8,
436
+ "res4": 16,
437
+ "res5": 32,
438
+ }
439
+ self._out_feature_channels = {
440
+ "res3": embed_dim // 2,
441
+ "res4": embed_dim,
442
+ "res5": embed_dim,
443
+ }
444
+
445
+ def forward(self, x):
446
+ """
447
+ Args:
448
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
449
+ Returns:
450
+ dict[str->Tensor]: names and the corresponding features
451
+ """
452
+ assert (
453
+ x.dim() == 4
454
+ ), f"ViT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
455
+ outputs = {}
456
+ y = super().forward(x)
457
+ for k in y.keys():
458
+ if k in self._out_features:
459
+ outputs[k] = y[k]
460
+ return outputs
461
+
462
+ def output_shape(self):
463
+ return {
464
+ name: ShapeSpec(
465
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
466
+ )
467
+ for name in self._out_features
468
+ }
469
+
470
+ @property
471
+ def size_divisibility(self):
472
+ return 32
GLEE/glee/backbone/vit_utils.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import math
3
+ import numpy as np
4
+ from scipy import interpolate
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ __all__ = [
10
+ "window_partition",
11
+ "window_unpartition",
12
+ "add_decomposed_rel_pos",
13
+ "get_abs_pos",
14
+ "PatchEmbed",
15
+ ]
16
+
17
+
18
+ def window_partition(x, window_size):
19
+ """
20
+ Partition into non-overlapping windows with padding if needed.
21
+ Args:
22
+ x (tensor): input tokens with [B, H, W, C].
23
+ window_size (int): window size.
24
+
25
+ Returns:
26
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
27
+ (Hp, Wp): padded height and width before partition
28
+ """
29
+ B, H, W, C = x.shape
30
+
31
+ pad_h = (window_size - H % window_size) % window_size
32
+ pad_w = (window_size - W % window_size) % window_size
33
+ if pad_h > 0 or pad_w > 0:
34
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
35
+ Hp, Wp = H + pad_h, W + pad_w
36
+
37
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
38
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
39
+ return windows, (Hp, Wp)
40
+
41
+
42
+ def window_unpartition(windows, window_size, pad_hw, hw):
43
+ """
44
+ Window unpartition into original sequences and removing padding.
45
+ Args:
46
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
47
+ window_size (int): window size.
48
+ pad_hw (Tuple): padded height and width (Hp, Wp).
49
+ hw (Tuple): original height and width (H, W) before padding.
50
+
51
+ Returns:
52
+ x: unpartitioned sequences with [B, H, W, C].
53
+ """
54
+ Hp, Wp = pad_hw
55
+ H, W = hw
56
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
57
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59
+
60
+ if Hp > H or Wp > W:
61
+ x = x[:, :H, :W, :].contiguous()
62
+ return x
63
+
64
+
65
+ def get_rel_pos(q_size, k_size, rel_pos, interp_type):
66
+ """
67
+ Get relative positional embeddings according to the relative positions of
68
+ query and key sizes.
69
+ Args:
70
+ q_size (int): size of query q.
71
+ k_size (int): size of key k.
72
+ rel_pos (Tensor): relative position embeddings (L, C).
73
+
74
+ Returns:
75
+ Extracted positional embeddings according to relative positions.
76
+ """
77
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
78
+ # Interpolate rel pos if needed.
79
+ if rel_pos.shape[0] != max_rel_dist:
80
+ if interp_type == "vitdet":
81
+ # the vitdet impl:
82
+ # https://github.com/facebookresearch/detectron2/blob/96c752ce821a3340e27edd51c28a00665dd32a30/detectron2/modeling/backbone/utils.py#L77.
83
+
84
+ rel_pos_resized = F.interpolate(
85
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
86
+ size=max_rel_dist,
87
+ mode="linear",
88
+ )
89
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
90
+ elif interp_type == "beit":
91
+ # steal from beit https://github.com/microsoft/unilm/tree/master/beit
92
+ # modified by Yuxin Fang
93
+
94
+ src_size = rel_pos.shape[0]
95
+ dst_size = max_rel_dist
96
+
97
+ q = 1.0903078
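+ # source coordinates follow a geometric progression with ratio q (denser near zero), so small relative distances are sampled more finely by the cubic interpolation below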
98
+ dis = []
99
+
100
+ cur = 1
101
+ for i in range(src_size // 2):
102
+ dis.append(cur)
103
+ cur += q ** (i + 1)
104
+
105
+ r_ids = [-_ for _ in reversed(dis)]
106
+ x = r_ids + [0] + dis
107
+ t = dst_size // 2.0
108
+ dx = np.arange(-t, t + 0.1, 1.0)
109
+
110
+ all_rel_pos_bias = []
111
+ for i in range(rel_pos.shape[1]):
112
+ # a hack from https://github.com/baaivision/EVA/issues/8,
113
+ # could also be used in fine-tuning but the performance haven't been tested.
114
+ z = rel_pos[:, i].view(src_size).cpu().float().detach().numpy()
115
+ f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
116
+ all_rel_pos_bias.append(
117
+ torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
118
+ rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
119
+ else:
120
+ raise NotImplementedError()
121
+ else:
122
+ rel_pos_resized = rel_pos
123
+
124
+ # Scale the coords with short length if shapes for q and k are different.
125
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
126
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
127
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
128
+
129
+ return rel_pos_resized[relative_coords.long()]
130
+
131
+
132
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size, interp_type):
133
+ """
134
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
135
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
136
+ Args:
137
+ attn (Tensor): attention map.
138
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
139
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
140
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
141
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
142
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+ interp_type (str): interpolation scheme passed through to get_rel_pos ("vitdet" or "beit").
143
+
144
+ Returns:
145
+ attn (Tensor): attention map with added relative positional embeddings.
146
+ """
147
+ q_h, q_w = q_size
148
+ k_h, k_w = k_size
149
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h, interp_type)
150
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w, interp_type)
151
+
152
+ B, _, dim = q.shape
153
+ r_q = q.reshape(B, q_h, q_w, dim)
154
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
155
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
156
+
157
+ attn = (
158
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
159
+ ).view(B, q_h * q_w, k_h * k_w)
160
+
161
+ return attn
162
+
163
+
164
+ def get_abs_pos(abs_pos, has_cls_token, hw):
165
+ """
166
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
167
+ dimension for the original embeddings.
168
+ Args:
169
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
170
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
171
+ hw (Tuple): size of input image tokens.
172
+
173
+ Returns:
174
+ Absolute positional embeddings after processing with shape (1, H, W, C)
175
+ """
176
+ h, w = hw
177
+ if has_cls_token:
178
+ abs_pos = abs_pos[:, 1:]
179
+ xy_num = abs_pos.shape[1]
180
+ size = int(math.sqrt(xy_num))
181
+ assert size * size == xy_num
182
+
183
+ if size != h or size != w:
184
+ new_abs_pos = F.interpolate(
185
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
186
+ size=(h, w),
187
+ mode="bicubic",
188
+ align_corners=False,
189
+ )
190
+
191
+ return new_abs_pos.permute(0, 2, 3, 1)
192
+ else:
193
+ return abs_pos.reshape(1, h, w, -1)
194
+
195
+
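A small shape check for get_abs_pos, assuming a ViT pretrained on 224x224 images with 16x16 patches (14 x 14 patch tokens plus one cls token):

import torch

pos = torch.randn(1, 1 + 14 * 14, 768)                  # (1, num_position, C) including the cls token
out = get_abs_pos(pos, has_cls_token=True, hw=(16, 16))
assert out.shape == (1, 16, 16, 768)                     # cls token dropped, grid bicubically resized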
196
+ class PatchEmbed(nn.Module):
197
+ """
198
+ Image to Patch Embedding.
199
+ """
200
+
201
+ def __init__(
202
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
203
+ ):
204
+ """
205
+ Args:
206
+ kernel_size (Tuple): kernel size of the projection layer.
207
+ stride (Tuple): stride of the projection layer.
208
+ padding (Tuple): padding size of the projection layer.
209
+ in_chans (int): Number of input image channels.
210
+ embed_dim (int): Patch embedding dimension.
211
+ """
212
+ super().__init__()
213
+
214
+ self.proj = nn.Conv2d(
215
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
216
+ )
217
+
218
+ def forward(self, x):
219
+ x = self.proj(x)
220
+ # B C H W -> B H W C
221
+ x = x.permute(0, 2, 3, 1)
222
+ return x
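A brief usage sketch for PatchEmbed with the default 16x16 patches, assuming only torch:

import torch

patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
img = torch.randn(1, 3, 224, 224)
tokens = patch_embed(img)                                # conv projection, then B C H W -> B H W C
assert tokens.shape == (1, 224 // 16, 224 // 16, 768)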
GLEE/glee/config.py ADDED
@@ -0,0 +1,387 @@
1
+ # -*- coding: utf-8 -*-
2
+ from detectron2.config import CfgNode as CN
3
+
4
+
5
+ def add_glee_config(cfg):
6
+ """
7
+ Add config for GLEE.
8
+ """
9
+
10
+ cfg.FIND_UNUSED_PARAMETERS = True
11
+ cfg.MODEL.MAX_CATEGORY_LEN = 100
12
+ cfg.MODEL.PSEUDO_VIDEO = False
13
+ cfg.MODEL.FREEZE_WHOLE = False
14
+ cfg.MODEL.CONTRAS_MEAN = False
15
+ cfg.MODEL.CROSS_TRACK = False
16
+ cfg.MODEL.TRACK_VERSION = 'v3'
17
+
18
+ cfg.INPUT.SAMPLING_FRAME_NUM = 1
19
+ cfg.INPUT.SAMPLING_FRAME_RANGE = 10
20
+ cfg.INPUT.SAMPLING_INTERVAL = 1
21
+ cfg.INPUT.SAMPLING_FRAME_SHUFFLE = False
22
+ cfg.INPUT.AUGMENTATIONS = [] # "brightness", "contrast", "saturation", "rotation"
23
+ cfg.INPUT.DATASET_MAPPER_NAME = None
24
+
25
+ cfg.DATALOADER.DATASET_RATIO = [1, 1]
26
+ cfg.DATALOADER.USE_DIFF_BS_SIZE = True
27
+ cfg.DATALOADER.DATASET_BS = [2, 2]
28
+ cfg.DATALOADER.DATASET_FILTERS = [True, True]
29
+ cfg.DATALOADER.USE_RFS = [False, False]
30
+ cfg.DATALOADER.MULTI_DATASET_GROUPING = True
31
+ cfg.DATALOADER.DATASET_ANN = ['image']
32
+
33
+
34
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
35
+
36
+ cfg.DATALOADER.DATASET_RATIO = [1, 1]
37
+ cfg.DATALOADER.USE_DIFF_BS_SIZE = True
38
+ cfg.DATALOADER.DATASET_BS = [2, 2]
39
+ cfg.DATALOADER.USE_RFS = [False, False]
40
+ cfg.DATALOADER.MULTI_DATASET_GROUPING = True
41
+ cfg.DATALOADER.DATASET_ANN = ['box', 'box']
42
+
43
+ # Allow different datasets to use different input resolutions
44
+ cfg.INPUT.MIN_SIZE_TRAIN_MULTI = [(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)]
45
+ cfg.INPUT.MAX_SIZE_TRAIN_MULTI = [1333, 768]
46
+
47
+
48
+ # MaskDINO model config
49
+ cfg.MODEL.MaskDINO = CN()
50
+ cfg.MODEL.MaskDINO.LEARN_TGT = False
51
+
52
+ # loss
53
+ cfg.MODEL.MaskDINO.PANO_BOX_LOSS = False
54
+ cfg.MODEL.MaskDINO.SEMANTIC_CE_LOSS = False
55
+ cfg.MODEL.MaskDINO.DEEP_SUPERVISION = True
56
+ cfg.MODEL.MaskDINO.NO_OBJECT_WEIGHT = 0.1
57
+ cfg.MODEL.MaskDINO.CLASS_WEIGHT = 4.0
58
+ cfg.MODEL.MaskDINO.DICE_WEIGHT = 5.0
59
+ cfg.MODEL.MaskDINO.MASK_WEIGHT = 5.0
60
+ cfg.MODEL.MaskDINO.BOX_WEIGHT = 5.
61
+ cfg.MODEL.MaskDINO.GIOU_WEIGHT = 2.
62
+
63
+ # cost weight
64
+ cfg.MODEL.MaskDINO.COST_CLASS_WEIGHT = 4.0
65
+ cfg.MODEL.MaskDINO.COST_DICE_WEIGHT = 5.0
66
+ cfg.MODEL.MaskDINO.COST_MASK_WEIGHT = 5.0
67
+ cfg.MODEL.MaskDINO.COST_BOX_WEIGHT = 5.
68
+ cfg.MODEL.MaskDINO.COST_GIOU_WEIGHT = 2.
69
+
70
+ # transformer config
71
+ cfg.MODEL.MaskDINO.NHEADS = 8
72
+ cfg.MODEL.MaskDINO.DROPOUT = 0.1
73
+ cfg.MODEL.MaskDINO.DIM_FEEDFORWARD = 2048
74
+ cfg.MODEL.MaskDINO.ENC_LAYERS = 0
75
+ cfg.MODEL.MaskDINO.DEC_LAYERS = 6
76
+ cfg.MODEL.MaskDINO.INITIAL_PRED = True
77
+ cfg.MODEL.MaskDINO.PRE_NORM = False
78
+ cfg.MODEL.MaskDINO.BOX_LOSS = True
79
+ cfg.MODEL.MaskDINO.HIDDEN_DIM = 256
80
+ cfg.MODEL.MaskDINO.NUM_OBJECT_QUERIES = 100
81
+
82
+ cfg.MODEL.MaskDINO.ENFORCE_INPUT_PROJ = False
83
+ cfg.MODEL.MaskDINO.TWO_STAGE = True
84
+ cfg.MODEL.MaskDINO.INITIALIZE_BOX_TYPE = 'no' # ['no', 'bitmask', 'mask2box']
85
+ cfg.MODEL.MaskDINO.DN="seg"
86
+ cfg.MODEL.MaskDINO.DN_NOISE_SCALE=0.4
87
+ cfg.MODEL.MaskDINO.DN_NUM=100
88
+ cfg.MODEL.MaskDINO.PRED_CONV=False
89
+
90
+ cfg.MODEL.MaskDINO.EVAL_FLAG = 1
91
+
92
+ # MSDeformAttn encoder configs
93
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
94
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
95
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
96
+ cfg.MODEL.SEM_SEG_HEAD.DIM_FEEDFORWARD = 2048
97
+ cfg.MODEL.SEM_SEG_HEAD.NUM_FEATURE_LEVELS = 3
98
+ cfg.MODEL.SEM_SEG_HEAD.TOTAL_NUM_FEATURE_LEVELS = 4
99
+ cfg.MODEL.SEM_SEG_HEAD.FEATURE_ORDER = 'high2low' # ['low2high', 'high2low'] high2low: from high level to low level
100
+
101
+ #####################
102
+
103
+ # MaskDINO inference config
104
+ cfg.MODEL.MaskDINO.TEST = CN()
105
+ cfg.MODEL.MaskDINO.TEST.TEST_FOUCUS_ON_BOX = False
106
+ cfg.MODEL.MaskDINO.TEST.SEMANTIC_ON = True
107
+ cfg.MODEL.MaskDINO.TEST.INSTANCE_ON = False
108
+ cfg.MODEL.MaskDINO.TEST.PANOPTIC_ON = False
109
+ cfg.MODEL.MaskDINO.TEST.OBJECT_MASK_THRESHOLD = 0.0
110
+ cfg.MODEL.MaskDINO.TEST.OVERLAP_THRESHOLD = 0.0
111
+ cfg.MODEL.MaskDINO.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
112
+ cfg.MODEL.MaskDINO.TEST.PANO_TRANSFORM_EVAL = True
113
+ cfg.MODEL.MaskDINO.TEST.PANO_TEMPERATURE = 0.06
114
+ # cfg.MODEL.MaskDINO.TEST.EVAL_FLAG = 1
115
+
116
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
117
+ # you can use this config to override it
118
+ cfg.MODEL.MaskDINO.SIZE_DIVISIBILITY = 32
119
+
120
+ # pixel decoder config
121
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
122
+ # adding transformer in pixel decoder
123
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
124
+ # pixel decoder
125
+ cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "MaskDINOEncoder"
126
+
127
+ # transformer module
128
+ cfg.MODEL.MaskDINO.TRANSFORMER_DECODER_NAME = "MaskDINODecoder"
129
+
130
+ # LSJ aug
131
+ cfg.INPUT.IMAGE_SIZE = 1024
132
+ cfg.INPUT.MIN_SCALE = 0.1
133
+ cfg.INPUT.MAX_SCALE = 2.0
134
+
135
+ # point loss configs
136
+ # Number of points sampled during training for a mask point head.
137
+ cfg.MODEL.MaskDINO.TRAIN_NUM_POINTS = 112 * 112
138
+ # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
139
+ # original paper.
140
+ cfg.MODEL.MaskDINO.OVERSAMPLE_RATIO = 3.0
141
+ # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
142
+ # the original paper.
143
+ cfg.MODEL.MaskDINO.IMPORTANCE_SAMPLE_RATIO = 0.75
144
+
145
+
146
+
147
+
148
+ cfg.MODEL.DIM_PROJ = 256
149
+ cfg.MODEL.VISUAL_PROMPT = False
150
+ cfg.MODEL.TEXT = CN()
151
+ cfg.MODEL.TEXT.ARCH = 'vlpencoder'
152
+ cfg.MODEL.TEXT.NAME= 'transformer'
153
+ cfg.MODEL.TEXT.TOKENIZER= 'clip'
154
+ cfg.MODEL.TEXT.CONTEXT_LENGTH= 77 # 77
155
+ cfg.MODEL.TEXT.WIDTH= 512
156
+ cfg.MODEL.TEXT.HEADS= 8
157
+ cfg.MODEL.TEXT.LAYERS= 12 # 6
158
+ cfg.MODEL.TEXT.AUTOGRESSIVE= True
159
+
160
+
161
+
162
+ cfg.MODEL.LANGUAGE_BACKBONE = CN()
163
+ cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT = False
164
+ cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE = "bert-base-uncased"
165
+ cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE = "bert-base-uncased"
166
+ cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM = 768
167
+ cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN = 77 # max length of the tokenized captions.
168
+ cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS = 1
169
+ # cfg.MODEL.LANGUAGE_BACKBONE.UNUSED_TOKEN = 106
170
+ # cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL = False
171
+ cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX = True
172
+
173
+
174
+
175
+
176
+
177
+ cfg.MODEL.ENCODER = CN()
178
+ cfg.MODEL.ENCODER.NAME= 'transformer_encoder_fpn'
179
+ cfg.MODEL.ENCODER.IGNORE_VALUE= 255
180
+ cfg.MODEL.ENCODER.NUM_CLASSES= 133
181
+ cfg.MODEL.ENCODER.LOSS_WEIGHT= 1.0
182
+ cfg.MODEL.ENCODER.CONVS_DIM= 512
183
+ cfg.MODEL.ENCODER.MASK_DIM= 512
184
+ cfg.MODEL.ENCODER.NORM= "GN"
185
+ cfg.MODEL.ENCODER.IN_FEATURES= ["res2", "res3", "res4", "res5"]
186
+ cfg.MODEL.ENCODER.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES= ["res3", "res4", "res5"]
187
+ cfg.MODEL.ENCODER.COMMON_STRIDE= 4
188
+ cfg.MODEL.ENCODER.TRANSFORMER_ENC_LAYERS= 6
189
+
190
+ cfg.MODEL.DECODER = CN()
191
+ cfg.MODEL.DECODER.TRANSFORMER_IN_FEATURE= "multi_scale_pixel_decoder"
192
+ cfg.MODEL.DECODER.MASK = True
193
+ # DETECTION= False
194
+ # SPATIAL=
195
+ # ENABLED= True
196
+ # GROUNDING=
197
+ # ENABLED= False
198
+ # MAX_LEN= 5
199
+ # TEXT_WEIGHT= 2.0
200
+ # CLASS_WEIGHT= 0.5
201
+ # VISUAL=
202
+ # ENABLED= False
203
+ # AUDIO=
204
+ # ENABLED= False
205
+ # OPENIMAGE=
206
+ # ENABLED= False
207
+ # NEGATIVE_SAMPLES= 5
208
+ # GROUNDING=
209
+ # ENABLED= False
210
+ # MAX_LEN= 5
211
+ # CAPTION=
212
+ # ENABLED= False
213
+ # PHRASE_PROB= 0.5
214
+ # SIM_THRES= 0.95
215
+ cfg.MODEL.DECODER.HIDDEN_DIM= 512
216
+ cfg.MODEL.DECODER.NUM_OBJECT_QUERIES= 101
217
+ cfg.MODEL.DECODER.NHEADS= 8
218
+ cfg.MODEL.DECODER.DROPOUT= 0.0
219
+ cfg.MODEL.DECODER.DIM_FEEDFORWARD= 2048
220
+ cfg.MODEL.DECODER.MAX_SPATIAL_LEN= [512, 512, 512, 512]
221
+ cfg.MODEL.DECODER.PRE_NORM= False
222
+ cfg.MODEL.DECODER.ENFORCE_INPUT_PROJ= False
223
+ cfg.MODEL.DECODER.SIZE_DIVISIBILITY= 32
224
+ cfg.MODEL.DECODER.TRAIN_NUM_POINTS= 12544
225
+ cfg.MODEL.DECODER.OVERSAMPLE_RATIO= 3.0
226
+ cfg.MODEL.DECODER.IMPORTANCE_SAMPLE_RATIO= 0.75
227
+ cfg.MODEL.DECODER.DEC_LAYERS= 10 # 9 decoder layers, add one for the loss on learnable query
228
+ cfg.MODEL.DECODER.TOP_GROUNDING_LAYERS= 10
229
+ cfg.MODEL.DECODER.TOP_CAPTION_LAYERS= 10
230
+ cfg.MODEL.DECODER.TOP_SPATIAL_LAYERS= 10
231
+ cfg.MODEL.DECODER.TOP_OPENIMAGE_LAYERS= 10
232
+ # TEST=
233
+ # SEMANTIC_ON= True
234
+ # INSTANCE_ON= True
235
+ # PANOPTIC_ON= True
236
+ # OVERLAP_THRESHOLD= 0.8
237
+ # OBJECT_MASK_THRESHOLD= 0.4
238
+ # SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE= false
239
+ # DETECTIONS_PER_IMAGE= 100
240
+
241
+ cfg.ATTENTION_ARCH = CN()
242
+ # cfg.ATTENTION_ARCH.VARIABLE={
243
+ # 'queries': ['object'],
244
+ # 'tokens': ['grounding', 'spatial', 'visual', 'audio']}
245
+
246
+ # SELF_ATTENTION:
247
+ # queries:
248
+ # object: ['queries_object', 'tokens_grounding', 'tokens_spatial', 'tokens_visual', 'tokens_audio']
249
+ # tokens:
250
+ # grounding: ['queries_object', 'tokens_grounding']
251
+ # spatial: ['tokens_spatial']
252
+ # visual: ['tokens_visual']
253
+ # audio: ['queries_object', 'tokens_audio']
254
+ # CROSS_ATTENTION:
255
+ # queries:
256
+ # object: True
257
+ # tokens:
258
+ # grounding: False
259
+ # spatial: False
260
+ # visual: False
261
+ # audio: False
262
+ # MASKING: ['tokens_spatial', 'tokens_grounding', 'tokens_visual', 'tokens_audio']
263
+ # DUPLICATION:
264
+ # queries:
265
+ # grounding: 'queries_object'
266
+ # spatial: 'queries_object'
267
+ # SPATIAL_MEMORIES: 32
268
+
269
+
270
+
271
+
272
+
273
+
274
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
275
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
276
+ cfg.SOLVER.TEXTENCODER_MULTIPLIER = 1.0
277
+ cfg.SOLVER.LR_DECAY_RATE = None
278
+ cfg.SOLVER.LR_DECAY_RATE_NUM_LAYERS = None
279
+
280
+
281
+ ## support Swin backbone
282
+ cfg.MODEL.SWIN = CN()
283
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
284
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
285
+ cfg.MODEL.SWIN.EMBED_DIM = 96
286
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
287
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
288
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
289
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
290
+ cfg.MODEL.SWIN.QKV_BIAS = True
291
+ cfg.MODEL.SWIN.QK_SCALE = None
292
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
293
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
294
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
295
+ cfg.MODEL.SWIN.APE = False
296
+ cfg.MODEL.SWIN.PATCH_NORM = True
297
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
298
+ cfg.MODEL.SWIN.USE_CHECKPOINT = False
299
+ cfg.MODEL.SWIN.PRETRAINED_WEIGHT = None
300
+
301
+
302
+ # support InterImage backbone
303
+ cfg.MODEL.INTERNIMAGE = CN() # large as base
304
+
305
+ #### large
306
+ cfg.MODEL.INTERNIMAGE.PRETRAINED_WEIGHT = None
307
+ cfg.MODEL.INTERNIMAGE.CORE_OP = "DCNv3"
308
+ cfg.MODEL.INTERNIMAGE.CHANNELS = 160
309
+ cfg.MODEL.INTERNIMAGE.DEPTHS = [5, 5, 22, 5]
310
+ cfg.MODEL.INTERNIMAGE.GROUPS =[10, 20, 40, 80]
311
+ cfg.MODEL.INTERNIMAGE.MLP_RATIO =4.
312
+ cfg.MODEL.INTERNIMAGE.DROP_PATH_RATE =0.0
313
+ cfg.MODEL.INTERNIMAGE.NORM_LAYER = "LN"
314
+ cfg.MODEL.INTERNIMAGE.LAYER_SCALE = 1.0
315
+ cfg.MODEL.INTERNIMAGE.OFFSET_SCALE = 2.0
316
+ cfg.MODEL.INTERNIMAGE.POST_NORM = True
317
+ cfg.MODEL.INTERNIMAGE.WITH_CP = False
318
+ cfg.MODEL.INTERNIMAGE.OUT_IINDICES = (0, 1, 2, 3)
319
+ cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE = None
320
+ cfg.MODEL.INTERNIMAGE.RES_POST_NORM = False
321
+ cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM = False
322
+ cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS = None
323
+ cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE = False
324
+
325
+ ### huge
326
+ # cfg.MODEL.INTERNIMAGE.PRETRAINED_WEIGHT = None
327
+ # cfg.MODEL.INTERNIMAGE.CORE_OP = "DCNv3"
328
+ # cfg.MODEL.INTERNIMAGE.CHANNELS = 320
329
+ # cfg.MODEL.INTERNIMAGE.DEPTHS = [6, 6, 32, 6]
330
+ # cfg.MODEL.INTERNIMAGE.GROUPS = [10, 20, 40, 80]
331
+ # cfg.MODEL.INTERNIMAGE.MLP_RATIO =4.
332
+ # cfg.MODEL.INTERNIMAGE.DROP_PATH_RATE = 0.5
333
+ # cfg.MODEL.INTERNIMAGE.NORM_LAYER = "LN"
334
+ # cfg.MODEL.INTERNIMAGE.LAYER_SCALE = None
335
+ # cfg.MODEL.INTERNIMAGE.OFFSET_SCALE = 1.0
336
+ # cfg.MODEL.INTERNIMAGE.POST_NORM = False
337
+ # cfg.MODEL.INTERNIMAGE.WITH_CP = False
338
+ # cfg.MODEL.INTERNIMAGE.OUT_IINDICES = (0, 1, 2, 3)
339
+
340
+ # cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE = 5
341
+ # cfg.MODEL.INTERNIMAGE.RES_POST_NORM = True
342
+ # cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM = True
343
+ # cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS = [5, 11, 17, 23, 29]
344
+ # cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE = True
345
+
346
+
347
+ # support EVA02 backbone
348
+ cfg.MODEL.EVA02 = CN() # large as base
349
+
350
+ #### large
351
+ cfg.MODEL.EVA02.PRETRAINED_WEIGHT = None
352
+ cfg.MODEL.EVA02.IMAGE_SIZE = 1536
353
+ cfg.MODEL.EVA02.PATCH_SIZE = 16
354
+ cfg.MODEL.EVA02.WINDOW_SIZE = 16
355
+ cfg.MODEL.EVA02.DMBED_DIM =1024
356
+ cfg.MODEL.EVA02.DEPTH = 24
357
+ cfg.MODEL.EVA02.NUM_HEADS = 16
358
+ cfg.MODEL.EVA02.MLP_RATIO = 4*2/3
359
+ cfg.MODEL.EVA02.DROP_PATH_RATE = 0.3
360
+ cfg.MODEL.EVA02.CHECKPOINT = True
361
+ cfg.MODEL.EVA02.WINDOW_BLOCK_INDEXES = [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22]
362
+
363
+
364
+
365
+ # support EVA01 backbone
366
+ cfg.MODEL.EVA01 = CN() # large as base
367
+
368
+ #### large
369
+ cfg.MODEL.EVA01.PRETRAINED_WEIGHT = None
370
+
371
+ cfg.MODEL.EVA01.BEIT_LIKE_QKV_BIAS = True
372
+ cfg.MODEL.EVA01.BEIT_LIKE_GAMMA = False
373
+ cfg.MODEL.EVA01.FREEZE_PATH_EMBED = True
374
+
375
+ cfg.MODEL.EVA01.IMAGE_SIZE = 1280 # only for correct dim in pos embed
376
+ cfg.MODEL.EVA01.PATCH_SIZE = 16
377
+ cfg.MODEL.EVA01.WINDOW_SIZE = 16
378
+ cfg.MODEL.EVA01.DMBED_DIM = 1408
379
+ cfg.MODEL.EVA01.DEPTH = 40
380
+ cfg.MODEL.EVA01.NUM_HEADS = 16
381
+ cfg.MODEL.EVA01.MLP_RATIO = 6144 / 1408
382
+ cfg.MODEL.EVA01.DROP_PATH_RATE = 0.6
383
+ cfg.MODEL.EVA01.WINDOW_BLOCK_INDEXES = [0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 32, 33, 34, 36, 37, 38]
384
+
385
+
386
+
387
+
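A minimal sketch of how add_glee_config is typically consumed, assuming detectron2 is installed and the helper above is in scope; the YAML path is illustrative only:

from detectron2.config import get_cfg

cfg = get_cfg()                                   # detectron2 defaults
add_glee_config(cfg)                              # register the GLEE / MaskDINO keys defined above
# cfg.merge_from_file("GLEE/configs/R50.yaml")    # then merge a project config if desired
print(cfg.MODEL.MaskDINO.NUM_OBJECT_QUERIES)      # -> 100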
GLEE/glee/config_deeplab.py ADDED
@@ -0,0 +1,28 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+
5
+ def add_deeplab_config(cfg):
6
+ """
7
+ Add config for DeepLab.
8
+ """
9
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
10
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
11
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
12
+ # Used for `poly` learning rate schedule.
13
+ cfg.SOLVER.POLY_LR_POWER = 0.9
14
+ cfg.SOLVER.POLY_LR_CONSTANT_ENDING = 0.0
15
+ # Loss type, choose from `cross_entropy`, `hard_pixel_mining`.
16
+ cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE = "hard_pixel_mining"
17
+ # DeepLab settings
18
+ cfg.MODEL.SEM_SEG_HEAD.PROJECT_FEATURES = ["res2"]
19
+ cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS = [48]
20
+ cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS = 256
21
+ cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS = [6, 12, 18]
22
+ cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT = 0.1
23
+ cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV = False
24
+ # Backbone new configs
25
+ cfg.MODEL.RESNETS.RES4_DILATION = 1
26
+ cfg.MODEL.RESNETS.RES5_MULTI_GRID = [1, 2, 4]
27
+ # ResNet stem type from: `basic`, `deeplab`
28
+ cfg.MODEL.RESNETS.STEM_TYPE = "deeplab"
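When both helpers are needed, a reasonable composition order (an assumption based on the keys each one touches, not something this diff prescribes) is the DeepLab defaults first, then the GLEE additions:

from detectron2.config import get_cfg

cfg = get_cfg()
add_deeplab_config(cfg)   # DeepLab / ResNet stem keys
add_glee_config(cfg)      # GLEE, MaskDINO, text-encoder and backbone keys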
GLEE/glee/models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
GLEE/glee/models/glee_model.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ """
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ # from ..backbone import build_backbone, Backbone
8
+ # from ..body.encoder import build_encoder
9
+ # from ..body.decoder import build_decoder
10
+
11
+ from detectron2.modeling import build_backbone
12
+
13
+ from .pixel_decoder.maskdino_encoder import build_pixel_decoder
14
+ from .transformer_decoder.maskdino_decoder import build_transformer_decoder
15
+
16
+ import random
17
+ from transformers import AutoTokenizer
18
+ from collections import OrderedDict
19
+ from ..modules.point_features import point_sample
20
+ from timm.models.layers import trunc_normal_
21
+ from transformers import CLIPTokenizer,CLIPTextModel
22
+ from .vos_utils import masks_to_boxes, FeatureFuser
23
+ import numpy as np
24
+ import math
25
+
26
+
27
+ def rand_sample(x, max_len):
28
+ if x.shape[1] <= max_len:
29
+ return x
30
+ else:
31
+ rand_idx = torch.randperm(x.shape[1])[:max_len]
32
+ return x[:,rand_idx]
33
+
34
+
35
+ def agg_lang_feat(features, mask, pool_type="average"):
36
+ """average pooling of language features"""
37
+ # feat: (bs, seq_len, C)
38
+ # mask: (bs, seq_len)
39
+ if pool_type == "average":
40
+ embedded = features * mask.unsqueeze(-1).float() # use mask to zero out invalid token features
41
+ aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())
42
+ elif pool_type == "max":
43
+ out = []
44
+ for i in range(len(features)):
45
+ pool_feat, _ = torch.max(features[i][mask[i]], 0) # (L, C) -> (C, )
46
+ out.append(pool_feat)
47
+ aggregate = torch.stack(out, dim=0) # (bs, C)
48
+ else:
49
+ raise ValueError("pool_type should be average or max")
50
+ return aggregate
51
+
52
+ class GLEE_Model(nn.Module):
53
+ """
54
+ Main class for mask classification semantic segmentation architectures.
55
+ """
56
+ def __init__(self, cfg, matcher, device, video_info, contras_mean):
57
+ super().__init__()
58
+ self.cfg = cfg
59
+ self.matcher = matcher
60
+ self.backbone = build_backbone(cfg)
61
+ output_channels = [v for k,v in self.backbone._out_feature_channels.items()]
62
+ self.sot_fuser = FeatureFuser(output_channels[-3:], 256)
63
+
64
+
65
+ self.tokenizer = CLIPTokenizer.from_pretrained('GLEE/clip_vit_base_patch32')
66
+ self.tokenizer.add_special_tokens({'cls_token': self.tokenizer.eos_token})
67
+ self.text_encoder = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
68
+ # self.text_encoder_teacher = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
69
+ self.lang_encoder = None
70
+ # for p in self.text_encoder_teacher.parameters():
71
+ # p.requires_grad = False
72
+ self.lang_projection = nn.Parameter(torch.rand(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, cfg.MODEL.DIM_PROJ))
73
+ self.text_encode_type = 'clip_teacher'
74
+
75
+ # self.lang_encoder = None
76
+ self.pixel_decoder = build_pixel_decoder(cfg, self.backbone.output_shape())
77
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
78
+ self.predictor = build_transformer_decoder(cfg, transformer_predictor_in_channels, lang_encoder = self.lang_encoder, mask_classification=True,)
79
+ self.to(device)
80
+
81
+ self.video_info = video_info
82
+ self.contras_mean = contras_mean
83
+
84
+ self.track_loss_version = cfg.MODEL.TRACK_VERSION
85
+
86
+ self.no_mask_tasks = ['obj365', 'obj365_clip','openimage', 'openimage_clip', 'vg', 'grit', 'bdd_det', 'bdd_track_box']
87
+
88
+
89
+ # for visual prompt
90
+ hidden_dim = 256
91
+ self.max_spatial_len = [512,512,512,512]
92
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(4)])
93
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
94
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
95
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
96
+ trunc_normal_(self.mask_sptial_embed[3], std=.02)
97
+ # learnable positive negative indicator
98
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
99
+
100
+ @property
101
+ def device(self):
102
+ return self.pixel_mean.device
103
+
104
+ def forward(self, images, prompts, task, targets=None, batch_name_list=None, is_train = True, visual_prompt_type='scribble'):
105
+ extra = {}
106
+ # dist_loss = None
107
+ early_semantic = None
108
+
109
+ if self.text_encode_type == "clip_teacher":
110
+ if task not in ['grounding','rvos']:
111
+ assert batch_name_list
112
+ calsses_name_list = batch_name_list
113
+ tokenized = self.tokenizer.batch_encode_plus(calsses_name_list,
114
+ max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, # 256
115
+ padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest", # max_length
116
+ return_special_tokens_mask=True,
117
+ return_tensors='pt',
118
+ truncation=True).to(images.device)
119
+ texts = (tokenized['input_ids'], tokenized['attention_mask'])
120
+ token_x = self.text_encoder(*texts)['last_hidden_state']
121
+
122
+ valid_mask = tokenized['attention_mask'].bool()
123
+ # token_x_teacher = self.text_encoder_teacher(*texts)['last_hidden_state']
124
+ # if is_train:
125
+ # dist_loss = F.mse_loss(token_x[valid_mask], token_x_teacher[valid_mask] )
126
+ # F.l2_loss(token_x[valid_mask], token_x_teacher[valid_mask] )
127
+ token_x = token_x @ self.lang_projection
128
+ lang_feat_pool = agg_lang_feat(token_x, tokenized['attention_mask'], pool_type="average") # (bs, 768)
129
+ extra['class_embeddings'] = lang_feat_pool
130
+ if True: # early_fusion
131
+ gather_all_classtoken = token_x.flatten(0,1)[tokenized['attention_mask'].flatten(0,1)>0]
132
+ gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1) #[bs,L,C]
133
+ gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0 #[bs,L]
134
+ early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}
135
+
136
+
137
+ if 'grounding' in prompts:
138
+
139
+ if self.text_encode_type == 'clip_frozen' or self.text_encode_type == 'clip_teacher':
140
+
141
+ tokens = self.tokenizer(
142
+ prompts['grounding'], padding='max_length', truncation=True, max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, return_tensors='pt'
143
+ )
144
+ tokens = {key: value.to(images.device) for key, value in tokens.items()}
145
+
146
+ texts = (tokens['input_ids'], tokens['attention_mask'])
147
+ x = self.text_encoder(*texts)
148
+ token_x = x['last_hidden_state']
149
+ token_x = token_x @ self.lang_projection
150
+
151
+ extra['grounding_tokens'] = token_x.permute(1,0,2) #[len,bz,C]
152
+
153
+ non_zero_query_mask = tokens['attention_mask']
154
+ lang_feat_pool = agg_lang_feat(token_x, non_zero_query_mask, pool_type="average").unsqueeze(1) # (bs, 1, 768)
155
+
156
+ dist_loss = (lang_feat_pool*0).sum()
157
+
158
+ extra['grounding_nonzero_mask'] = ~non_zero_query_mask.bool() # [bz,len]
159
+ extra['grounding_class'] = lang_feat_pool.squeeze(1) #[bz,C
160
+ # gather_all_classtoken = token_x.flatten(0,1)[tokenized['attention_mask'].flatten(0,1)>0]
161
+ # gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1) #[bs,L,C]
162
+ # gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0 #[bs,L]
163
+ # early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}
164
+ early_semantic = {"hidden":token_x.float(),"masks":tokens['attention_mask']>0}
165
+
166
+
167
+ if isinstance(images,torch.Tensor):
168
+ features = self.backbone(images)
169
+ else:
170
+ features = self.backbone(images.tensor)
171
+
172
+
173
+
174
+
175
+ if 'spatial' in prompts:
176
+ ## setp 1,2,3
177
+ key_images = [ images ] #bz*[1,3,H,W]
178
+ key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W]
179
+
180
+ prompt_mode = visual_prompt_type
181
+ ref_feats, ref_masks = self.get_template(key_images, key_promptmasks, prompt_mode)
182
+ early_fusion = {"hidden":ref_feats,"masks":ref_masks}
183
+ if early_semantic is None:
184
+ early_semantic = early_fusion
185
+ else:
186
+ early_semantic["hidden"] = torch.cat([early_semantic["hidden"],early_fusion["hidden"]],dim=1)
187
+ early_semantic["masks"] = torch.cat([early_semantic["masks"],early_fusion["masks"]],dim=1)
188
+
189
+
190
+ # bz = len(images)//2
191
+ mask_features, _, multi_scale_features, zero_loss = self.pixel_decoder.forward_features(features, masks=None, early_fusion = early_semantic)
192
+ if 'spatial' in prompts:
193
+ pos_masks = prompts['spatial']
194
+ # neg_masks = [~p for p in prompts['spatial']]
195
+ neg_masks = [p&False for p in prompts['spatial']]
196
+
197
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
198
+
199
+
200
+ _,h,w = extra['spatial_query_pos_mask'][0].shape
201
+ divisor = torch.tensor([h,w], device=mask_features.device)[None,]
202
+ # Get mean pos spatial query
203
+ non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
204
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
205
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
206
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True) #[(N, C, P)
207
+ spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num() # [1,bz,C]
208
+ # Get mean neg spatial query
209
+ non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
210
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
211
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
212
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
213
+ spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
214
+
215
+ # Get layerwise spatial query
216
+ src_spatial_queries = []
217
+ src_spatial_maskings = []
218
+ for i in range(len(multi_scale_features)):
219
+ bs,dc,h,w = multi_scale_features[i].shape
220
+ # src_mask_features = multi_scale_features[i].view(h,w,bs,dc)
221
+ src_mask_features = multi_scale_features[i].permute(2,3,0,1)
222
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
223
+
224
+ non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
225
+ non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
226
+ non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
227
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
228
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
229
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
230
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
231
+ non_zero_query_point[non_zero_query_mask] = 0
232
+
233
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
234
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
235
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
236
+
237
+ src_spatial_queries += [spatial_tokens]
238
+ src_spatial_maskings += [non_zero_query_mask]
239
+
240
+ extra['visual_prompt_tokens'] = src_spatial_queries #[len,bz,C]
241
+ extra['visual_prompt_nonzero_mask'] = src_spatial_maskings # [bz,len]
242
+
243
+
244
+ outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
245
+ return outputs
246
+
247
+
248
+
249
+
250
+
251
+
252
+ def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
253
+ """img: (N, 3, H, W), mask: (N, 1, H, W), bbox: (1, 4)"""
254
+ """get 4-channel template"""
255
+
256
+ croped_img_with_mask = []
257
+
258
+ for image_i, mask_i in zip( imgs, pad_masks):
259
+
260
+ if prompt_mode in ['scribble','point']:
261
+ image_with_mask = image_i + mask_i.to(image_i)
262
+ else:
263
+ image_with_mask = image_i
264
+
265
+ # image_with_mask = torch.cat([image_i,mask_i.to(image_i)],dim=1) #[1,3,H,W]
266
+ box_i = masks_to_boxes(mask_i[0]) #[xyxy]
267
+ box_i[:, 2:] = box_i[:, 2:] - box_i[:, :2] #xywh
268
+
269
+
270
+ x, y, w, h = box_i[0].long().tolist()
271
+
272
+ self.search_area_factor=2
273
+
274
+ crop_sz = math.ceil(math.sqrt(w * h) * self.search_area_factor)
275
+ x1 = max(0,round(x + 0.5 * w - crop_sz * 0.5))
276
+ x2 = x1 + crop_sz
277
+ y1 = max(0,round(y + 0.5 * h - crop_sz * 0.5))
278
+ y2 = y1 + crop_sz
279
+
280
+ im_crop = image_with_mask[:, :, y1:y2, x1:x2]
281
+ # resize
282
+ if im_crop.shape[-1] ==0 or im_crop.shape[-2] ==0 :
283
+ im_crop = image_with_mask
284
+ im_crop = F.interpolate(im_crop, (256,256), mode='bilinear', align_corners=False)
285
+ croped_img_with_mask.append(im_crop)
286
+ croped_img_with_mask = torch.cat(croped_img_with_mask,dim=0) #[bz,3,256,256]
287
+ with torch.no_grad():
288
+ ref_srcs = self.backbone(croped_img_with_mask.contiguous())
289
+ ref_srcs = [v for k,v in ref_srcs.items()]
290
+ ref_feats = self.sot_fuser(ref_srcs[1:]).float() #[bz,256,32,32]
291
+
292
+ ref_feats = ref_feats.flatten(-2).permute(0, 2, 1) # (bs, L, C)
293
+ ref_masks = torch.ones_like(ref_feats[:,:,0])>0 #[bs,L]
294
+
295
+ return ref_feats, ref_masks
296
+
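A standalone check of the two helpers at the top of this file, assuming only torch:

import torch

feats = torch.randn(2, 5, 8)                                # (bs, seq_len, C)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
avg = agg_lang_feat(feats, mask, pool_type="average")       # masked mean over valid tokens
mx = agg_lang_feat(feats, mask.bool(), pool_type="max")     # per-sample max over valid tokens
assert avg.shape == mx.shape == (2, 8)

pts = torch.randn(1, 700, 2)
assert rand_sample(pts, max_len=512).shape == (1, 512, 2)   # subsamples along dim 1 when too long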
GLEE/glee/models/pixel_decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) IDEA, Inc. and its affiliates.
GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (184 Bytes). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-38.pyc ADDED
Binary file (6.36 kB). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/early_fusion.cpython-39.pyc ADDED
Binary file (6.3 kB). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-38.pyc ADDED
Binary file (15.4 kB). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/maskdino_encoder.cpython-39.pyc ADDED
Binary file (15.3 kB). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-38.pyc ADDED
Binary file (2.64 kB). View file
 
GLEE/glee/models/pixel_decoder/__pycache__/position_encoding.cpython-39.pyc ADDED
Binary file (2.6 kB). View file
 
GLEE/glee/models/pixel_decoder/early_fusion.py ADDED
@@ -0,0 +1,230 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from timm.models.layers import DropPath
5
+
6
+
7
+
8
+
9
+ class VLFuse(torch.nn.Module):
10
+ """
11
+ Early Fusion Module
12
+ """
13
+
14
+ def __init__(self, ):
15
+ super(VLFuse, self).__init__()
16
+ self.init_configs()
17
+
18
+ # early fusion module
19
+ # bi-direction (text->image, image->text)
20
+ self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.img_dim, # 256
21
+ l_dim=self.lang_dim, # 256 (language features are already projected to DIM_PROJ)
22
+ embed_dim=self.embed_dim, # 2048
23
+ num_heads=self.n_head, # 8
24
+ dropout=0.1,
25
+ drop_path=.0,
26
+ init_values=1.0 / 6,
27
+ )
28
+ def init_configs(self, ):
29
+ # common params
30
+ self.img_dim = 256
31
+
32
+ self.max_query_len = 256
33
+ self.n_layers =1
34
+
35
+ # mha params
36
+ self.n_head = 8
37
+ self.embed_dim = 2048 # 2048 by default
38
+
39
+ self.lang_dim = 256
40
+
41
+ def forward(self, x, task=None):
42
+ visual_features = x["visual"]
43
+ language_dict_features = x["lang"]
44
+
45
+ fused_visual_features, language_features = self.b_attn(
46
+ visual_features, language_dict_features['hidden'], language_dict_features['masks'], task)
47
+
48
+ language_dict_features['hidden'] = language_features
49
+ fused_language_dict_features = language_dict_features
50
+
51
+ features_dict = {"visual": fused_visual_features,
52
+ "lang": fused_language_dict_features}
53
+
54
+ return features_dict
55
+
56
+
57
+ class BiMultiHeadAttention(nn.Module):
58
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1):
59
+ super(BiMultiHeadAttention, self).__init__()
60
+
61
+ self.embed_dim = embed_dim
62
+ self.num_heads = num_heads
63
+ self.head_dim = embed_dim // num_heads
64
+ self.v_dim = v_dim
65
+ self.l_dim = l_dim
66
+
67
+ assert (
68
+ self.head_dim * self.num_heads == self.embed_dim
69
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
70
+ self.scale = self.head_dim ** (-0.5)
71
+ self.dropout = dropout
72
+
73
+ self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
74
+ self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
75
+ self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
76
+ self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
77
+
78
+ self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
79
+ self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
80
+
81
+ self.stable_softmax_2d = False
82
+ self.clamp_min_for_underflow = True
83
+ self.clamp_max_for_overflow = True
84
+
85
+ self._reset_parameters()
86
+
87
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
88
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
89
+
90
+ def _reset_parameters(self):
91
+ nn.init.xavier_uniform_(self.v_proj.weight)
92
+ self.v_proj.bias.data.fill_(0)
93
+ nn.init.xavier_uniform_(self.l_proj.weight)
94
+ self.l_proj.bias.data.fill_(0)
95
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
96
+ self.values_v_proj.bias.data.fill_(0)
97
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
98
+ self.values_l_proj.bias.data.fill_(0)
99
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
100
+ self.out_v_proj.bias.data.fill_(0)
101
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
102
+ self.out_l_proj.bias.data.fill_(0)
103
+
104
+ def forward(self, v, l, attention_mask_l=None):
105
+ bsz, tgt_len, embed_dim = v.size()
106
+
107
+ query_states = self.v_proj(v) * self.scale
108
+ key_states = self._shape(self.l_proj(l), -1, bsz)
109
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
110
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
111
+
112
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim) # (bs * 8, -1, embed_dim//8)
113
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) # (bs * 8, seq_len_img, embed_dim//8)
114
+ key_states = key_states.view(*proj_shape) # (bs * 8, seq_len_text, embed_dim//8)
115
+ value_v_states = value_v_states.view(*proj_shape)
116
+ value_l_states = value_l_states.view(*proj_shape)
117
+
118
+ src_len = key_states.size(1)
119
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # (bs * 8, seq_len_img, seq_len_text)
120
+
121
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
122
+ raise ValueError(
123
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
124
+ )
125
+
126
+ # attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1)
127
+
128
+ if self.stable_softmax_2d:
129
+ attn_weights = attn_weights - attn_weights.max()
130
+
131
+ if self.clamp_min_for_underflow:
132
+ attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
133
+ if self.clamp_max_for_overflow:
134
+ attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
135
+
136
+ attn_weights_T = attn_weights.transpose(1, 2)
137
+ attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[
138
+ 0])
139
+ if self.clamp_min_for_underflow:
140
+ attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
141
+ if self.clamp_max_for_overflow:
142
+ attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range
143
+
144
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
145
+ # assert attention_mask_l.dtype == torch.int64
146
+ if attention_mask_l is not None:
147
+ assert (attention_mask_l.dim() == 2) # (bs, seq_len)
148
+ attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) # (bs, 1, 1, seq_len)
149
+ attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
150
+ attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
151
+
152
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
153
+ raise ValueError(
154
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
155
+ )
156
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
157
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
158
+
159
+ attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)
160
+
161
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
162
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
163
+
164
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
165
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
166
+
167
+
168
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
169
+ raise ValueError(
170
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
171
+ )
172
+
173
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
174
+ raise ValueError(
175
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
176
+ )
177
+
178
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
179
+ attn_output_v = attn_output_v.transpose(1, 2)
180
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
181
+
182
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
183
+ attn_output_l = attn_output_l.transpose(1, 2)
184
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
185
+
186
+ attn_output_v = self.out_v_proj(attn_output_v)
187
+ attn_output_l = self.out_l_proj(attn_output_l)
188
+
189
+ return attn_output_v, attn_output_l
190
+
191
+
192
+ class BiAttentionBlockForCheckpoint(nn.Module):
193
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
194
+ drop_path=.0, init_values=1e-4, ):
195
+ """
196
+ Inputs:
197
+ embed_dim - Dimensionality of input and attention feature vectors
198
+ num_heads - Number of heads to use in the Multi-Head Attention block
199
+ dropout - Amount of dropout to apply inside the bi-directional attention
200
+ """
201
+ super(BiAttentionBlockForCheckpoint, self).__init__()
202
+
203
+ # pre layer norm
204
+ self.layer_norm_v = nn.LayerNorm(v_dim)
205
+ self.layer_norm_l = nn.LayerNorm(l_dim)
206
+ self.attn = BiMultiHeadAttention(v_dim=v_dim,
207
+ l_dim=l_dim,
208
+ embed_dim=embed_dim,
209
+ num_heads=num_heads,
210
+ dropout=dropout,
211
+ )
212
+
213
+ # add layer scale for training stability
214
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
215
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
216
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
217
+
218
+
219
+ def forward(self, v, l, attention_mask_l=None, task=None):
220
+ # v: visual features, (bs, sigma(HW), 256)
221
+ # l: language features, (bs, seq_len, 768)
222
+ v = self.layer_norm_v(v)
223
+ l = self.layer_norm_l(l)
224
+ delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)
225
+ # v, l = v + delta_v, l + delta_l
226
+ v = v + self.drop_path(self.gamma_v * delta_v)
227
+ l = l + self.drop_path(self.gamma_l * delta_l)
228
+ return v, l
229
+
230
+
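A shape-level sketch of the fusion block above on random tensors, using the 256-d visual and 256-d projected language dimensions that VLFuse passes in; the all-ones mask here stands for a batch with no padded tokens:

import torch

block = BiAttentionBlockForCheckpoint(v_dim=256, l_dim=256, embed_dim=2048, num_heads=8)
v = torch.randn(2, 1024, 256)                  # (bs, sum of H*W over feature levels, v_dim)
l = torch.randn(2, 77, 256)                    # (bs, seq_len, l_dim)
l_mask = torch.ones(2, 77)                     # 1.0 marks valid language tokens
v_out, l_out = block(v, l, attention_mask_l=l_mask)
assert v_out.shape == v.shape and l_out.shape == l.shape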
GLEE/glee/models/pixel_decoder/maskdino_encoder.py ADDED
@@ -0,0 +1,463 @@
1
+ # ------------------------------------------------------------------------
2
+ # DINO
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified by Feng Li and Hao Zhang.
7
+ import logging
8
+ import numpy as np
9
+ from typing import Callable, Dict, List, Optional, Tuple, Union
10
+ import fvcore.nn.weight_init as weight_init
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
16
+ from torch.cuda.amp import autocast
17
+
18
+ from detectron2.config import configurable
19
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
20
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
21
+
22
+ from .position_encoding import PositionEmbeddingSine
23
+ from ...utils.utils import _get_clones, _get_clones_advanced, _get_activation_fn
24
+ from .ops.modules import MSDeformAttn
25
+ from .early_fusion import VLFuse
26
+
27
+ def build_pixel_decoder(cfg, input_shape):
28
+ """
29
+ Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.
30
+ """
31
+ name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
32
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
33
+ forward_features = getattr(model, "forward_features", None)
34
+ if not callable(forward_features):
35
+ raise ValueError(
36
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
37
+ f"Please implement forward_features for {name} to only return mask features."
38
+ )
39
+ return model
40
+
41
+
42
+ # MSDeformAttn Transformer encoder in deformable detr
43
+ class MSDeformAttnTransformerEncoderOnly(nn.Module):
44
+ def __init__(self, d_model=256, nhead=8,
45
+ num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
46
+ activation="relu",
47
+ num_feature_levels=4, enc_n_points=4,):
48
+ super().__init__()
49
+
50
+ self.d_model = d_model
51
+ self.nhead = nhead
52
+
53
+ vl_fusion_layer = VLFuse()
54
+
55
+ encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
56
+ dropout, activation,
57
+ num_feature_levels, nhead, enc_n_points)
58
+ self.encoder = MSDeformAttnTransformerEncoder(vl_fusion_layer, encoder_layer, num_encoder_layers)
59
+
60
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
61
+
62
+ self._reset_parameters()
63
+
64
+ def _reset_parameters(self):
65
+ for p in self.parameters():
66
+ if p.dim() > 1:
67
+ nn.init.xavier_uniform_(p)
68
+ for m in self.modules():
69
+ if isinstance(m, MSDeformAttn):
70
+ m._reset_parameters()
71
+ normal_(self.level_embed)
72
+
73
+ def get_valid_ratio(self, mask):
74
+ _, H, W = mask.shape
75
+ valid_H = torch.sum(~mask[:, :, 0], 1)
76
+ valid_W = torch.sum(~mask[:, 0, :], 1)
77
+ valid_ratio_h = valid_H.float() / H
78
+ valid_ratio_w = valid_W.float() / W
79
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
80
+ return valid_ratio
81
+
82
+ def forward(self, srcs, masks, pos_embeds, early_fusion=None):
83
+
84
+ enable_mask=0
85
+ if masks is not None:
86
+ for src in srcs:
87
+ if src.size(2)%32 or src.size(3)%32:
88
+ enable_mask = 1
89
+ if enable_mask==0:
90
+ masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
91
+ # prepare input for encoder
92
+ src_flatten = []
93
+ mask_flatten = []
94
+ lvl_pos_embed_flatten = []
95
+ spatial_shapes = []
96
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
97
+ bs, c, h, w = src.shape
98
+ spatial_shape = (h, w)
99
+ spatial_shapes.append(spatial_shape)
100
+ src = src.flatten(2).transpose(1, 2)
101
+ mask = mask.flatten(1)
102
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
103
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
104
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
105
+ src_flatten.append(src)
106
+ mask_flatten.append(mask)
107
+ src_flatten = torch.cat(src_flatten, 1)
108
+ mask_flatten = torch.cat(mask_flatten, 1)
109
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
110
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
111
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
112
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
113
+ # encoder
114
+ memory, zero_loss = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, early_fusion)
115
+
116
+ return memory, spatial_shapes, level_start_index, zero_loss
117
+
118
+
119
+ class MSDeformAttnTransformerEncoderLayer(nn.Module):
120
+ def __init__(self,
121
+ d_model=256, d_ffn=1024,
122
+ dropout=0.1, activation="relu",
123
+ n_levels=4, n_heads=8, n_points=4):
124
+ super().__init__()
125
+
126
+ # self attention
127
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
128
+ self.dropout1 = nn.Dropout(dropout)
129
+ self.norm1 = nn.LayerNorm(d_model)
130
+
131
+ # ffn
132
+ self.linear1 = nn.Linear(d_model, d_ffn)
133
+ self.activation = _get_activation_fn(activation)
134
+ self.dropout2 = nn.Dropout(dropout)
135
+ self.linear2 = nn.Linear(d_ffn, d_model)
136
+ self.dropout3 = nn.Dropout(dropout)
137
+ self.norm2 = nn.LayerNorm(d_model)
138
+
139
+ @staticmethod
140
+ def with_pos_embed(tensor, pos):
141
+ return tensor if pos is None else tensor + pos
142
+
143
+ def forward_ffn(self, src):
144
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
145
+ src = src + self.dropout3(src2)
146
+ src = self.norm2(src)
147
+ return src
148
+
149
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
150
+ # self attention
151
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
152
+ src = src + self.dropout1(src2)
153
+ src = self.norm1(src)
154
+
155
+ # ffn
156
+ src = self.forward_ffn(src)
157
+
158
+ return src
159
+
160
+
161
+ class MSDeformAttnTransformerEncoder(nn.Module):
162
+ def __init__(self, vl_fusion_layer, encoder_layer, num_layers):
163
+ super().__init__()
164
+ self.layers = _get_clones(encoder_layer, num_layers)
165
+ self.num_layers = num_layers
166
+
167
+ self.vl_layers = _get_clones_advanced(vl_fusion_layer, num_layers, 1)
168
+
169
+
170
+ @staticmethod
171
+ def get_reference_points(spatial_shapes, valid_ratios, device):
172
+ reference_points_list = []
173
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
174
+
175
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
176
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
177
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
178
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
179
+ ref = torch.stack((ref_x, ref_y), -1)
180
+ reference_points_list.append(ref)
181
+ reference_points = torch.cat(reference_points_list, 1)
182
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
183
+ return reference_points
184
+
185
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, early_fusion=None):
186
+
187
+ if early_fusion:
188
+ output = {"visual": src, "lang": early_fusion}
189
+ else:
190
+ output = src
191
+
192
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
193
+ for _, (layer,vl_layer) in enumerate(zip(self.layers, self.vl_layers)):
194
+ if early_fusion:
195
+ output = vl_layer(output)
196
+ output["visual"] = layer(output["visual"], pos, reference_points, spatial_shapes, level_start_index, padding_mask)
197
+ else:
198
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
199
+
200
+
201
+ if early_fusion:
202
+ return output["visual"] , (output['lang']['hidden']*0).sum()
203
+ else:
204
+ return output, None
205
+
206
+
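A quick sanity check of the normalized reference grid built by the encoder above; get_reference_points is a staticmethod, so this sketch runs without building the deformable-attention CUDA op:

import torch

spatial_shapes = torch.as_tensor([(32, 32), (16, 16)], dtype=torch.long)
valid_ratios = torch.ones(2, 2, 2)                  # (bs, n_levels, 2): fully valid feature maps
ref = MSDeformAttnTransformerEncoder.get_reference_points(
    spatial_shapes, valid_ratios, device="cpu")
assert ref.shape == (2, 32 * 32 + 16 * 16, 2, 2)    # (bs, sum(H*W), n_levels, normalized xy)
assert ref.min() > 0 and ref.max() < 1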
207
+ @SEM_SEG_HEADS_REGISTRY.register()
208
+ class MaskDINOEncoder(nn.Module):
209
+ """
210
+ The multi-scale encoder used in detection models, also known as the pixel decoder in segmentation models.
211
+ """
212
+ @configurable
213
+ def __init__(
214
+ self,
215
+ input_shape: Dict[str, ShapeSpec],
216
+ *,
217
+ transformer_dropout: float,
218
+ transformer_nheads: int,
219
+ transformer_dim_feedforward: int,
220
+ transformer_enc_layers: int,
221
+ conv_dim: int,
222
+ mask_dim: int,
223
+ norm: Optional[Union[str, Callable]] = None,
224
+ # deformable transformer encoder args
225
+ transformer_in_features: List[str],
226
+ common_stride: int,
227
+ num_feature_levels: int,
228
+ total_num_feature_levels: int,
229
+ feature_order: str,
230
+ ViTBackbone: bool,
231
+ ):
232
+ """
233
+ NOTE: this interface is experimental.
234
+ Args:
235
+ input_shape: shapes (channels and stride) of the input features
236
+ transformer_dropout: dropout probability in transformer
237
+ transformer_nheads: number of heads in transformer
238
+ transformer_dim_feedforward: dimension of feedforward network
239
+ transformer_enc_layers: number of transformer encoder layers
240
+ conv_dim: number of output channels for the intermediate conv layers.
241
+ mask_dim: number of output channels for the final conv layer.
242
+ norm (str or callable): normalization for all conv layers
243
+ num_feature_levels: feature scales used
244
+ total_num_feature_levels: total feature scales used (including the downsampled features)
245
+ feature_order: 'low2high' or 'high2low', i.e., 'low2high' means low-resolution features come first.
246
+ """
247
+ super().__init__()
248
+ transformer_input_shape = {
249
+ k: v for k, v in input_shape.items() if k in transformer_in_features
250
+ }
251
+ # this is the input shape of pixel decoder
252
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
253
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
254
+ self.feature_strides = [v.stride for k, v in input_shape]
255
+ self.feature_channels = [v.channels for k, v in input_shape]
256
+ self.feature_order = feature_order
257
+
258
+ if feature_order == "low2high":
259
+ transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: -x[1].stride)
260
+ else:
261
+ transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
262
+ self.transformer_in_features = [k for k, v in transformer_input_shape] # e.g. ["res3", "res4", "res5"]; order depends on feature_order
263
+ transformer_in_channels = [v.channels for k, v in transformer_input_shape]
264
+ self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
265
+ self.maskdino_num_feature_levels = num_feature_levels # always use 3 scales
266
+ self.total_num_feature_levels = total_num_feature_levels
267
+ self.common_stride = common_stride
268
+
269
+ self.transformer_num_feature_levels = len(self.transformer_in_features)
270
+ self.low_resolution_index = transformer_in_channels.index(max(transformer_in_channels))
271
+ self.high_resolution_index = 0 if self.feature_order == 'low2high' else -1
272
+
273
+ self.isViTBackbone = ViTBackbone
274
+ if not ViTBackbone:
275
+ if self.transformer_num_feature_levels > 1:
276
+ input_proj_list = []
277
+ for in_channels in transformer_in_channels[::-1]:
278
+ input_proj_list.append(nn.Sequential(
279
+ nn.Conv2d(in_channels, conv_dim, kernel_size=1),
280
+ nn.GroupNorm(32, conv_dim),
281
+ ))
282
+ # input projections for the extra downsampled levels
283
+ in_channels = max(transformer_in_channels)
284
+ for _ in range(self.total_num_feature_levels - self.transformer_num_feature_levels): # exclude the res2
285
+ input_proj_list.append(nn.Sequential(
286
+ nn.Conv2d(in_channels, conv_dim, kernel_size=3, stride=2, padding=1),
287
+ nn.GroupNorm(32, conv_dim),
288
+ ))
289
+ in_channels = conv_dim
290
+ self.input_proj = nn.ModuleList(input_proj_list)
291
+ else:
292
+ self.input_proj = nn.ModuleList([
293
+ nn.Sequential(
294
+ nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
295
+ nn.GroupNorm(32, conv_dim),
296
+ )])
297
+
298
+ for proj in self.input_proj:
299
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
300
+ nn.init.constant_(proj[0].bias, 0)
301
+
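+ # Worked example (illustrative, non-ViT backbones): with transformer_in_features =
+ # ['res3', 'res4', 'res5'] and total_num_feature_levels = 4, `input_proj` holds three
+ # 1x1 conv + GroupNorm projections (one per backbone level) plus one stride-2 3x3 conv
+ # that derives an extra, further-downsampled level from the coarsest (res5) feature map.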
302
+ self.transformer = MSDeformAttnTransformerEncoderOnly(
303
+ d_model=conv_dim,
304
+ dropout=transformer_dropout,
305
+ nhead=transformer_nheads,
306
+ dim_feedforward=transformer_dim_feedforward,
307
+ num_encoder_layers=transformer_enc_layers,
308
+ num_feature_levels=self.total_num_feature_levels,
309
+ )
310
+ N_steps = conv_dim // 2
311
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
312
+
313
+ self.mask_dim = mask_dim
314
+ # use 1x1 conv instead
315
+ self.mask_features = Conv2d(
316
+ conv_dim,
317
+ mask_dim,
318
+ kernel_size=1,
319
+ stride=1,
320
+ padding=0,
321
+ )
322
+ weight_init.c2_xavier_fill(self.mask_features)
323
+ # extra fpn levels
324
+ stride = min(self.transformer_feature_strides)
325
+ self.num_fpn_levels = max(int(np.log2(stride) - np.log2(self.common_stride)), 1)
326
+
327
+ lateral_convs = []
328
+ output_convs = []
329
+
330
+ use_bias = norm == ""
331
+ for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
332
+ lateral_norm = get_norm(norm, conv_dim)
333
+ output_norm = get_norm(norm, conv_dim)
334
+
335
+ lateral_conv = Conv2d(
336
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
337
+ )
338
+ output_conv = Conv2d(
339
+ conv_dim,
340
+ conv_dim,
341
+ kernel_size=3,
342
+ stride=1,
343
+ padding=1,
344
+ bias=use_bias,
345
+ norm=output_norm,
346
+ activation=F.relu,
347
+ )
348
+ weight_init.c2_xavier_fill(lateral_conv)
349
+ weight_init.c2_xavier_fill(output_conv)
350
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
351
+ self.add_module("layer_{}".format(idx + 1), output_conv)
352
+
353
+ lateral_convs.append(lateral_conv)
354
+ output_convs.append(output_conv)
355
+ # Place convs into top-down order (from low to high resolution)
356
+ # to make the top-down computation in forward clearer.
357
+ self.lateral_convs = lateral_convs[::-1]
358
+ self.output_convs = output_convs[::-1]
359
+
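+ # Sizing note (illustrative): `num_fpn_levels` counts the lateral/output conv pairs needed
+ # to reach `common_stride`. E.g. with res3 (stride 8) as the finest transformer feature and
+ # a common stride of 4, num_fpn_levels = log2(8) - log2(4) = 1, so a single adapter/layer
+ # pair is built on top of res2 to produce the 1/4-resolution mask features.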
360
+ @classmethod
361
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
362
+ ret = {}
363
+ ret["input_shape"] = {
364
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
365
+ }
366
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
367
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
368
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
369
+ ret["transformer_dropout"] = cfg.MODEL.MaskDINO.DROPOUT
370
+ ret["transformer_nheads"] = cfg.MODEL.MaskDINO.NHEADS
371
+ ret["transformer_dim_feedforward"] = cfg.MODEL.SEM_SEG_HEAD.DIM_FEEDFORWARD # deformable transformer encoder
372
+ ret[
373
+ "transformer_enc_layers"
374
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
375
+ ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES # ['res3', 'res4', 'res5']
376
+ ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
377
+ ret["total_num_feature_levels"] = cfg.MODEL.SEM_SEG_HEAD.TOTAL_NUM_FEATURE_LEVELS
378
+ ret["num_feature_levels"] = cfg.MODEL.SEM_SEG_HEAD.NUM_FEATURE_LEVELS
379
+ ret["feature_order"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_ORDER
380
+ ret["ViTBackbone"] = cfg.MODEL.BACKBONE.NAME in ['D2_EVA02', 'D2_EVA01' , 'D2_ViT']
381
+ return ret
382
+
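+ # The corresponding config section looks roughly like the sketch below; the values shown
+ # are illustrative defaults, not taken from this repository's configs:
+ #   MODEL:
+ #     SEM_SEG_HEAD:
+ #       CONVS_DIM: 256
+ #       MASK_DIM: 256
+ #       NORM: "GN"
+ #       DIM_FEEDFORWARD: 1024
+ #       TRANSFORMER_ENC_LAYERS: 6
+ #       DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ #       COMMON_STRIDE: 4
+ #       TOTAL_NUM_FEATURE_LEVELS: 4
+ #       NUM_FEATURE_LEVELS: 3
+ #       FEATURE_ORDER: "low2high"
+ #     MaskDINO:
+ #       DROPOUT: 0.0
+ #       NHEADS: 8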
383
+ @autocast(enabled=False)
384
+ def forward_features(self, features, masks, early_fusion=None):
385
+ """
386
+ :param features: multi-scale features from the backbone
387
+ :param masks: image mask
388
+ :return: enhanced multi-scale features and the mask feature (1/4 resolution) used by the decoder to produce binary masks
389
+ """
390
+ # backbone features
391
+ srcs = []
392
+ pos = []
393
+ # additional downsampled features
394
+ srcsl = []
395
+ posl = []
396
+
397
+ if self.isViTBackbone:
398
+ for idx, f in enumerate(self.transformer_in_features[::-1]):
399
+ x = features[f].float() # deformable detr does not support half precision
400
+ srcs.append(x)
401
+ pos.append(self.pe_layer(x))
402
+ if self.feature_order != 'low2high':
403
+ srcs = srcs[::-1]
404
+ pos = pos[::-1]
405
+ else:
406
+ if self.total_num_feature_levels > self.transformer_num_feature_levels:
407
+ smallest_feat = features[self.transformer_in_features[self.low_resolution_index]].float()
408
+ _len_srcs = self.transformer_num_feature_levels
409
+ for l in range(_len_srcs, self.total_num_feature_levels):
410
+ if l == _len_srcs:
411
+ src = self.input_proj[l](smallest_feat)
412
+ else:
413
+ src = self.input_proj[l](srcsl[-1])
414
+ srcsl.append(src)
415
+ posl.append(self.pe_layer(src))
416
+ srcsl = srcsl[::-1]
417
+ # Reverse feature maps
418
+
419
+
420
+ for idx, f in enumerate(self.transformer_in_features[::-1]):
421
+ x = features[f].float() # deformable detr does not support half precision
422
+ srcs.append(self.input_proj[idx](x))
423
+ pos.append(self.pe_layer(x))
424
+ srcs.extend(srcsl) if self.feature_order == 'low2high' else srcsl.extend(srcs)
425
+ pos.extend(posl) if self.feature_order == 'low2high' else posl.extend(pos)
426
+ if self.feature_order != 'low2high':
427
+ srcs = srcsl
428
+ pos = posl
429
+
430
+ y, spatial_shapes, level_start_index, zero_loss = self.transformer(srcs, masks, pos, early_fusion)
431
+ bs = y.shape[0]
432
+
433
+ split_size_or_sections = [None] * self.total_num_feature_levels
434
+ for i in range(self.total_num_feature_levels):
435
+ if i < self.total_num_feature_levels - 1:
436
+ split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
437
+ else:
438
+ split_size_or_sections[i] = y.shape[1] - level_start_index[i]
439
+ y = torch.split(y, split_size_or_sections, dim=1)
440
+
441
+ out = []
442
+ multi_scale_features = []
443
+ num_cur_levels = 0
444
+ for i, z in enumerate(y):
445
+ out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
446
+
447
+ # append `out` with extra FPN levels
448
+ # Reverse feature maps into top-down order (from low to high resolution)
449
+ for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
450
+ x = features[f].float()
451
+ lateral_conv = self.lateral_convs[idx]
452
+ output_conv = self.output_convs[idx]
453
+ cur_fpn = lateral_conv(x)
454
+ # As in FPN, upsample the coarser map and add the lateral connection (bilinear interpolation here)
455
+ y = cur_fpn + F.interpolate(out[self.high_resolution_index], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
456
+ y = output_conv(y)
457
+ out.append(y)
458
+ for o in out:
459
+ if num_cur_levels < self.total_num_feature_levels:
460
+ multi_scale_features.append(o)
461
+ num_cur_levels += 1
462
+ return self.mask_features(out[-1]), out[0], multi_scale_features, zero_loss
463
+
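+ # Usage sketch (illustrative; variable names below are not from this file): given backbone
+ # features such as {"res2": ..., "res3": ..., "res4": ..., "res5": ...}, a caller would run
+ #   mask_features, encoder_features, multi_scale_features, zero_loss = \
+ #       pixel_decoder.forward_features(features, masks, early_fusion=lang_feats_or_None)
+ # where `mask_features` is the 1/4-resolution map consumed by the mask head and
+ # `multi_scale_features` are the `total_num_feature_levels` maps fed to the decoder.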
GLEE/glee/models/pixel_decoder/ops/functions/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn_func import MSDeformAttnFunction
13
+
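+ # (Re-exporting MSDeformAttnFunction lets the deformable-attention module import it from
+ # this package, e.g. `from ..functions import MSDeformAttnFunction`.)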
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (262 Bytes). View file
 
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (232 Bytes). View file
 
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc ADDED
Binary file (2.7 kB). View file
 
GLEE/glee/models/pixel_decoder/ops/functions/__pycache__/ms_deform_attn_func.cpython-39.pyc ADDED
Binary file (2.64 kB). View file
 
GLEE/glee/models/pixel_decoder/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,72 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.autograd import Function
19
+ from torch.autograd.function import once_differentiable
20
+
21
+ try:
22
+ import MultiScaleDeformableAttention as MSDA
23
+ except ModuleNotFoundError:
24
+ info_string = (
25
+ "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26
+ "\t`cd GLEE/glee/models/pixel_decoder/ops`\n"
27
+ "\t`sh make.sh`\n"
28
+ )
29
+ # raise ModuleNotFoundError(info_string)
30
+
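+ # Note (added for clarity): when the compiled extension is missing, the pure-PyTorch
+ # `ms_deform_attn_core_pytorch` defined at the bottom of this file can act as a slow
+ # fallback for debugging; `MSDeformAttnFunction` itself requires the CUDA op.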
31
+
32
+ class MSDeformAttnFunction(Function):
33
+ @staticmethod
34
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35
+ ctx.im2col_step = im2col_step
36
+ output = MSDA.ms_deform_attn_forward(
37
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39
+ return output
40
+
41
+ @staticmethod
42
+ @once_differentiable
43
+ def backward(ctx, grad_output):
44
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45
+ grad_value, grad_sampling_loc, grad_attn_weight = \
46
+ MSDA.ms_deform_attn_backward(
47
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48
+
49
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50
+
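+ # Expected tensor shapes (following the Deformable DETR conventions used above):
+ #   value:                (N, Len_in, n_heads, head_dim)
+ #   value_spatial_shapes: (n_levels, 2);  value_level_start_index: (n_levels,)
+ #   sampling_locations:   (N, Len_q, n_heads, n_levels, n_points, 2)
+ #   attention_weights:    (N, Len_q, n_heads, n_levels, n_points)
+ # The op returns a tensor of shape (N, Len_q, n_heads * head_dim).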
51
+
52
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53
+ # for debug and test only,
54
+ # need to use cuda version instead
55
+ N_, S_, M_, D_ = value.shape
56
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58
+ sampling_grids = 2 * sampling_locations - 1
59
+ sampling_value_list = []
60
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65
+ # N_*M_, D_, Lq_, P_
66
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67
+ mode='bilinear', padding_mode='zeros', align_corners=False)
68
+ sampling_value_list.append(sampling_value_l_)
69
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72
+ return output.transpose(1, 2).contiguous()
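+ # Minimal consistency check (illustrative): the PyTorch routine above mirrors the CUDA
+ # kernel and can be used to verify it on small inputs, e.g.
+ #   out_ref = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
+ #   out_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index,
+ #                                         sampling_locations, attention_weights, im2col_step)
+ #   assert torch.allclose(out_ref, out_cuda.float(), atol=1e-4)  # should hold approximately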