tuandunghcmut committed on
Commit 5c6ebf3 · verified · 1 Parent(s): 8c4ebba

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +71 -0
  3. configuration_solider.py +80 -0
  4. model.safetensors +3 -0
  5. modeling_solider.py +1840 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
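A minimal loading sketch, assuming the checkpoint is pulled from this repository on the Hub (the repo id below is a placeholder). The `auto_map` entries in `config.json` route `AutoConfig`/`AutoModel` to the bundled `configuration_solider.py` and `modeling_solider.py`, so `trust_remote_code=True` is required; the dummy input follows the 224×224 `img_size` in `config.json`, and the exact forward signature is defined in `modeling_solider.py`.

```python
import torch
from transformers import AutoConfig, AutoModel

repo_id = "<namespace>/<model-repo>"  # placeholder Hub repo id

# auto_map in config.json points AutoConfig/AutoModel at the custom
# SOLIDERConfig / SOLIDERModel classes shipped in this repo, so remote
# code must be trusted explicitly.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Dummy ImageNet-style batch; 224x224 matches img_size in config.json.
# (Assumes the forward pass accepts a plain image tensor.)
pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)
```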
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "act_cfg": {
3
+ "type": "GELU"
4
+ },
5
+ "architectures": [
6
+ "SOLIDERModel"
7
+ ],
8
+ "attn_drop_rate": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_solider.SOLIDERConfig",
11
+ "AutoModel": "modeling_solider.SOLIDERModel"
12
+ },
13
+ "convert_weights": false,
14
+ "depths": [
15
+ 2,
16
+ 2,
17
+ 18,
18
+ 2
19
+ ],
20
+ "drop_path_rate": 0.0,
21
+ "drop_rate": 0.0,
22
+ "embed_dims": 128,
23
+ "frozen_stages": -1,
24
+ "hidden_size": 128,
25
+ "img_size": [
26
+ 224,
27
+ 224
28
+ ],
29
+ "in_channels": 3,
30
+ "init_cfg": null,
31
+ "mlp_ratio": 4,
32
+ "model_type": "swin_transformer",
33
+ "name": "solider_base",
34
+ "norm_cfg": {
35
+ "type": "LN"
36
+ },
37
+ "num_heads": [
38
+ 4,
39
+ 8,
40
+ 16,
41
+ 32
42
+ ],
43
+ "out_indices": [
44
+ 0,
45
+ 1,
46
+ 2,
47
+ 3
48
+ ],
49
+ "patch_norm": true,
50
+ "patch_size": 4,
51
+ "pretrain_img_size": [
52
+ 224,
53
+ 224
54
+ ],
55
+ "pretrained": null,
56
+ "qk_scale": null,
57
+ "qkv_bias": true,
58
+ "semantic_weight": 0.2,
59
+ "strides": [
60
+ 4,
61
+ 2,
62
+ 2,
63
+ 2
64
+ ],
65
+ "torch_dtype": "float32",
66
+ "transformers_version": "4.44.2",
67
+ "use_abs_pos_embed": false,
68
+ "vision_width": 1024,
69
+ "window_size": 7,
70
+ "with_cp": false
71
+ }
configuration_solider.py ADDED
@@ -0,0 +1,80 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+ BACKBONE_NAME2WIDTH = {
4
+ "swin_tiny_patch4_window7_224": 768,
5
+ "swin_small_patch4_window7_224": 768,
6
+ "swin_base_patch4_window7_224": 1024,
7
+ "solider_tiny": 768,
8
+ "solider_small": 768,
9
+ "solider_base": 1024,
10
+ }
11
+
12
+
13
+ class SOLIDERConfig(PretrainedConfig):
14
+ model_type = "swin_transformer"
15
+
16
+ def __init__(
17
+ self,
18
+ pretrain_img_size=224,
19
+ in_channels=3,
20
+ embed_dims=96,
21
+ patch_size=4,
22
+ window_size=7,
23
+ mlp_ratio=4,
24
+ depths=(2, 2, 6, 2),
25
+ num_heads=(3, 6, 12, 24),
26
+ strides=(4, 2, 2, 2),
27
+ out_indices=(0, 1, 2, 3),
28
+ qkv_bias=True,
29
+ qk_scale=None,
30
+ patch_norm=True,
31
+ drop_rate=0.0,
32
+ attn_drop_rate=0.0,
33
+ drop_path_rate=0.0, # NOTE: I modified this from the implementation of SOLIDER
34
+ use_abs_pos_embed=False,
35
+ act_cfg=dict(type="GELU"),
36
+ norm_cfg=dict(type="LN"),
37
+ with_cp=False,
38
+ pretrained=None,
39
+ convert_weights=False,
40
+ frozen_stages=-1,
41
+ init_cfg=None,
42
+ semantic_weight=0.2, # NOTE: I modified this from the implementation of SOLIDER
43
+ name="solider_small",
44
+ **kwargs,
45
+ ):
46
+ self.pretrain_img_size = pretrain_img_size
47
+ self.in_channels = in_channels
48
+ self.embed_dims = embed_dims
49
+ self.patch_size = patch_size
50
+ self.window_size = window_size
51
+ self.mlp_ratio = mlp_ratio
52
+ self.depths = depths
53
+ self.num_heads = num_heads
54
+ self.strides = strides
55
+ self.out_indices = out_indices
56
+ self.qkv_bias = qkv_bias
57
+ self.qk_scale = qk_scale
58
+ self.patch_norm = patch_norm
59
+ self.drop_rate = drop_rate
60
+ self.attn_drop_rate = attn_drop_rate
61
+ self.drop_path_rate = drop_path_rate
62
+ self.use_abs_pos_embed = use_abs_pos_embed
63
+ self.act_cfg = act_cfg
64
+ self.norm_cfg = norm_cfg
65
+ self.with_cp = with_cp
66
+ self.pretrained = pretrained
67
+ self.convert_weights = convert_weights
68
+ self.frozen_stages = frozen_stages
69
+ self.init_cfg = init_cfg
70
+ self.semantic_weight = semantic_weight
71
+
72
+ # NOTE: The attributes below are for information only!
73
+ # They have no effect on model building!
74
+ self.img_size = pretrain_img_size
75
+ assert name in BACKBONE_NAME2WIDTH
76
+ self.name = name
77
+ self.vision_width = BACKBONE_NAME2WIDTH[self.name]
78
+ self.hidden_size = self.embed_dims
79
+
80
+ super().__init__(**kwargs)
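A small usage sketch for the configuration class above, assuming `configuration_solider.py` is importable from the working directory; the values mirror the uploaded `config.json` ("solider_base" variant).

```python
from configuration_solider import SOLIDERConfig

# Mirror the key fields of the uploaded config.json (solider_base).
config = SOLIDERConfig(
    name="solider_base",
    embed_dims=128,
    depths=(2, 2, 18, 2),
    num_heads=(4, 8, 16, 32),
    semantic_weight=0.2,
)

print(config.vision_width)  # 1024, looked up in BACKBONE_NAME2WIDTH
print(config.hidden_size)   # 128, mirrors embed_dims
```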
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f69b0364dd102a368bcf73928e645edfe6e723acb1195f18b6ab55dec4f918d
3
+ size 347551320
modeling_solider.py ADDED
@@ -0,0 +1,1840 @@
1
+ import os
2
+ import warnings
3
+ from collections import OrderedDict
4
+ from copy import deepcopy
5
+ import logging
6
+ import math
7
+ from typing import Sequence
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as cp
12
+ import numpy as np
13
+ import cv2
14
+ from dataclasses import dataclass
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+
17
+
18
+ # from .lavis_base_model import BaseEncoder
19
+ # from lavis.common.registry import registry
20
+
21
+ from torch.nn import Module as BaseModule
22
+ from torch.nn import ModuleList
23
+ from torch.nn import Sequential
24
+ from torch.nn import Linear
25
+ from torch import Tensor
26
+ from itertools import repeat
27
+ import collections.abc
28
+
29
+ from .configuration_solider import SOLIDERConfig, BACKBONE_NAME2WIDTH
30
+ def _ntuple(n):
31
+ def parse(x):
32
+ if isinstance(x, collections.abc.Iterable):
33
+ return x
34
+ return tuple(repeat(x, n))
35
+
36
+ return parse
37
+
38
+
39
+ to_2tuple = _ntuple(2)
40
+
41
+
42
+ def trunc_normal_init(
43
+ module: nn.Module,
44
+ mean: float = 0,
45
+ std: float = 1,
46
+ a: float = -2,
47
+ b: float = 2,
48
+ bias: float = 0,
49
+ ) -> None:
50
+ if hasattr(module, "weight") and module.weight is not None:
51
+ # trunc_normal_(module.weight, mean, std, a, b) # type: ignore
52
+ _no_grad_trunc_normal_(module.weight, mean, std, a, b) # type: ignore
53
+ if hasattr(module, "bias") and module.bias is not None:
54
+ nn.init.constant_(module.bias, bias) # type: ignore
55
+
56
+
57
+ def _no_grad_trunc_normal_(
58
+ tensor: Tensor, mean: float, std: float, a: float, b: float
59
+ ) -> Tensor:
60
+ # Method based on
61
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
62
+ # Modified from
63
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
64
+ def norm_cdf(x):
65
+ # Computes standard normal cumulative distribution function
66
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
67
+
68
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
69
+ warnings.warn(
70
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
71
+ "The distribution of values may be incorrect.",
72
+ stacklevel=2,
73
+ )
74
+
75
+ with torch.no_grad():
76
+ # Values are generated by using a truncated uniform distribution and
77
+ # then using the inverse CDF for the normal distribution.
78
+ # Get upper and lower cdf values
79
+ lower = norm_cdf((a - mean) / std)
80
+ upper = norm_cdf((b - mean) / std)
81
+
82
+ # Uniformly fill tensor with values from [lower, upper], then translate
83
+ # to [2lower-1, 2upper-1].
84
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
85
+
86
+ # Use inverse cdf transform for normal distribution to get truncated
87
+ # standard normal
88
+ tensor.erfinv_()
89
+
90
+ # Transform to proper mean, std
91
+ tensor.mul_(std * math.sqrt(2.0))
92
+ tensor.add_(mean)
93
+
94
+ # Clamp to ensure it's in the proper range
95
+ tensor.clamp_(min=a, max=b)
96
+ return tensor
97
+
98
+
99
+ def trunc_normal_(
100
+ tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
101
+ ) -> Tensor:
102
+ r"""Fills the input Tensor with values drawn from a truncated
103
+ normal distribution. The values are effectively drawn from the
104
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
105
+ with values outside :math:`[a, b]` redrawn until they are within
106
+ the bounds. The method used for generating the random values works
107
+ best when :math:`a \leq \text{mean} \leq b`.
108
+
109
+ Modified from
110
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
111
+
112
+ Args:
113
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
114
+ mean (float): the mean of the normal distribution.
115
+ std (float): the standard deviation of the normal distribution.
116
+ a (float): the minimum cutoff value.
117
+ b (float): the maximum cutoff value.
118
+ """
119
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
120
+
121
+
122
+ def constant_init(module, val, bias=0):
123
+ if hasattr(module, "weight") and module.weight is not None:
124
+ nn.init.constant_(module.weight, val)
125
+ if hasattr(module, "bias") and module.bias is not None:
126
+ nn.init.constant_(module.bias, bias)
127
+
128
+
129
+ def build_norm_layer(norm_cfg, embed_dims):
130
+ assert norm_cfg["type"] == "LN"
131
+ norm_layer = nn.LayerNorm(embed_dims)
132
+ return norm_cfg["type"], norm_layer
133
+
134
+
135
+ class GELU(nn.Module):
136
+ r"""Applies the Gaussian Error Linear Units function:
137
+
138
+ .. math::
139
+ \text{GELU}(x) = x * \Phi(x)
140
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for
141
+ Gaussian Distribution.
142
+
143
+ Shape:
144
+ - Input: :math:`(N, *)` where `*` means, any number of additional
145
+ dimensions
146
+ - Output: :math:`(N, *)`, same shape as the input
147
+
148
+ .. image:: scripts/activation_images/GELU.png
149
+
150
+ Examples::
151
+
152
+ >>> m = nn.GELU()
153
+ >>> input = torch.randn(2)
154
+ >>> output = m(input)
155
+ """
156
+
157
+ def forward(self, input):
158
+ return F.gelu(input)
159
+
160
+
161
+ def build_activation_layer(act_cfg):
162
+ if act_cfg["type"] == "ReLU":
163
+ act_layer = nn.ReLU(inplace=act_cfg["inplace"])
164
+ elif act_cfg["type"] == "GELU":
165
+ act_layer = GELU()
166
+ return act_layer
167
+
168
+
169
+ def build_conv_layer(
170
+ conv_cfg, in_channels, out_channels, kernel_size, stride, padding, dilation, bias
171
+ ):
172
+ conv_layer = nn.Conv2d(
173
+ in_channels=in_channels,
174
+ out_channels=out_channels,
175
+ kernel_size=kernel_size,
176
+ stride=stride,
177
+ padding=padding,
178
+ dilation=dilation,
179
+ bias=bias,
180
+ )
181
+ return conv_layer
182
+
183
+
184
+ def drop_path(x, drop_prob=0.0, training=False):
185
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
186
+ residual blocks).
187
+
188
+ We follow the implementation
189
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
190
+ """
191
+ if drop_prob == 0.0 or not training:
192
+ return x
193
+ keep_prob = 1 - drop_prob
194
+ # handle tensors with different dimensions, not just 4D tensors.
195
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
196
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
197
+ output = x.div(keep_prob) * random_tensor.floor()
198
+ return output
199
+
200
+
201
+ class DropPath(nn.Module):
202
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
203
+ residual blocks).
204
+
205
+ We follow the implementation
206
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
207
+
208
+ Args:
209
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
210
+ """
211
+
212
+ def __init__(self, drop_prob=0.1):
213
+ super(DropPath, self).__init__()
214
+ self.drop_prob = drop_prob
215
+
216
+ def forward(self, x):
217
+ return drop_path(x, self.drop_prob, self.training)
218
+
219
+
220
+ def build_dropout(drop_cfg):
221
+ drop_layer = DropPath(drop_cfg["drop_prob"])
222
+ return drop_layer
223
+
224
+
225
+ class FFN(BaseModule):
226
+ def __init__(
227
+ self,
228
+ embed_dims=256,
229
+ feedforward_channels=1024,
230
+ num_fcs=2,
231
+ act_cfg=dict(type="ReLU", inplace=True),
232
+ ffn_drop=0.0,
233
+ dropout_layer=None,
234
+ add_identity=True,
235
+ init_cfg=None,
236
+ **kwargs,
237
+ ):
238
+ super(FFN, self).__init__()
239
+ assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}."
240
+ self.embed_dims = embed_dims
241
+ self.feedforward_channels = feedforward_channels
242
+ self.num_fcs = num_fcs
243
+ self.act_cfg = act_cfg
244
+ self.activate = build_activation_layer(act_cfg)
245
+
246
+ layers = []
247
+ in_channels = embed_dims
248
+ for _ in range(num_fcs - 1):
249
+ layers.append(
250
+ Sequential(
251
+ Linear(in_channels, feedforward_channels),
252
+ self.activate,
253
+ nn.Dropout(ffn_drop),
254
+ )
255
+ )
256
+ in_channels = feedforward_channels
257
+ layers.append(Linear(feedforward_channels, embed_dims))
258
+ layers.append(nn.Dropout(ffn_drop))
259
+ self.layers = Sequential(*layers)
260
+ self.dropout_layer = (
261
+ build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity()
262
+ )
263
+ self.add_identity = add_identity
264
+
265
+ def forward(self, x, identity=None):
266
+ """Forward function for `FFN`.
267
+
268
+ The function adds x to the output tensor if `identity` is None.
269
+ """
270
+ out = self.layers(x)
271
+ if not self.add_identity:
272
+ return self.dropout_layer(out)
273
+ if identity is None:
274
+ identity = x
275
+ return identity + self.dropout_layer(out)
276
+
277
+
278
+ def swin_converter(ckpt):
279
+ new_ckpt = OrderedDict()
280
+
281
+ def correct_unfold_reduction_order(x):
282
+ out_channel, in_channel = x.shape
283
+ x = x.reshape(out_channel, 4, in_channel // 4)
284
+ x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
285
+ return x
286
+
287
+ def correct_unfold_norm_order(x):
288
+ in_channel = x.shape[0]
289
+ x = x.reshape(4, in_channel // 4)
290
+ x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
291
+ return x
292
+
293
+ for k, v in ckpt.items():
294
+ if k.startswith("head"):
295
+ continue
296
+ elif k.startswith("layers"):
297
+ new_v = v
298
+ if "attn." in k:
299
+ new_k = k.replace("attn.", "attn.w_msa.")
300
+ elif "mlp." in k:
301
+ if "mlp.fc1." in k:
302
+ new_k = k.replace("mlp.fc1.", "ffn.layers.0.0.")
303
+ elif "mlp.fc2." in k:
304
+ new_k = k.replace("mlp.fc2.", "ffn.layers.1.")
305
+ else:
306
+ new_k = k.replace("mlp.", "ffn.")
307
+ elif "downsample" in k:
308
+ new_k = k
309
+ if "reduction." in k:
310
+ new_v = correct_unfold_reduction_order(v)
311
+ elif "norm." in k:
312
+ new_v = correct_unfold_norm_order(v)
313
+ else:
314
+ new_k = k
315
+ new_k = new_k.replace("layers", "stages", 1)
316
+ elif k.startswith("patch_embed"):
317
+ new_v = v
318
+ if "proj" in k:
319
+ new_k = k.replace("proj", "projection")
320
+ else:
321
+ new_k = k
322
+ else:
323
+ new_v = v
324
+ new_k = k
325
+
326
+ new_ckpt["backbone." + new_k] = new_v
327
+
328
+ return new_ckpt
329
+
330
+
331
+ class AdaptivePadding(nn.Module):
332
+ """Applies padding to input (if needed) so that input can get fully covered
333
+ by the filter you specified. It supports two modes, "same" and "corner". The
334
+ "same" mode is the same as the "SAME" padding mode in TensorFlow, padding zeros
335
+ around the input. The "corner" mode pads zeros to the bottom right.
336
+ Args:
337
+ kernel_size (int | tuple): Size of the kernel:
338
+ stride (int | tuple): Stride of the filter. Default: 1:
339
+ dilation (int | tuple): Spacing between kernel elements.
340
+ Default: 1
341
+ padding (str): Support "same" and "corner", "corner" mode
342
+ would pad zero to bottom right, and "same" mode would
343
+ pad zero around input. Default: "corner".
344
+ Example:
345
+ >>> kernel_size = 16
346
+ >>> stride = 16
347
+ >>> dilation = 1
348
+ >>> input = torch.rand(1, 1, 15, 17)
349
+ >>> adap_pad = AdaptivePadding(
350
+ >>> kernel_size=kernel_size,
351
+ >>> stride=stride,
352
+ >>> dilation=dilation,
353
+ >>> padding="corner")
354
+ >>> out = adap_pad(input)
355
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
356
+ >>> input = torch.rand(1, 1, 16, 17)
357
+ >>> out = adap_pad(input)
358
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
359
+ """
360
+
361
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
362
+ super(AdaptivePadding, self).__init__()
363
+
364
+ assert padding in ("same", "corner")
365
+
366
+ kernel_size = to_2tuple(kernel_size)
367
+ stride = to_2tuple(stride)
368
+ padding = to_2tuple(padding)
369
+ dilation = to_2tuple(dilation)
370
+
371
+ self.padding = padding
372
+ self.kernel_size = kernel_size
373
+ self.stride = stride
374
+ self.dilation = dilation
375
+
376
+ def get_pad_shape(self, input_shape):
377
+ input_h, input_w = input_shape
378
+ kernel_h, kernel_w = self.kernel_size
379
+ stride_h, stride_w = self.stride
380
+ output_h = math.ceil(input_h / stride_h)
381
+ output_w = math.ceil(input_w / stride_w)
382
+ pad_h = max(
383
+ (output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h,
384
+ 0,
385
+ )
386
+ pad_w = max(
387
+ (output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w,
388
+ 0,
389
+ )
390
+ return pad_h, pad_w
391
+
392
+ def forward(self, x):
393
+ B, C, h, w = x.shape
394
+
395
+ pad_h, pad_w = self.get_pad_shape((h, w))
396
+
397
+ if pad_h > 0 or pad_w > 0:
398
+ if self.padding == "corner":
399
+ return F.pad(x, [0, pad_w, 0, pad_h]).view(
400
+ B, C, h + pad_h, w + pad_w
401
+ ), (
402
+ h + pad_h,
403
+ w + pad_w,
404
+ )
405
+ elif self.padding == "same":
406
+ return F.pad(
407
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
408
+ ).view(B, C, h + pad_h, w + pad_w), (
409
+ h + pad_h,
410
+ w + pad_w,
411
+ )
412
+ return x, (h, w)
413
+
414
+
415
+ class PatchEmbed(BaseModule):
416
+ """Image to Patch Embedding.
417
+ We use a conv layer to implement PatchEmbed.
418
+ Args:
419
+ in_channels (int): The num of input channels. Default: 3
420
+ embed_dims (int): The dimensions of embedding. Default: 768
421
+ conv_type (str): The config dict for embedding
422
+ conv layer type selection. Default: "Conv2d".
423
+ kernel_size (int): The kernel_size of embedding conv. Default: 16.
424
+ stride (int): The slide stride of embedding conv.
425
+ Default: None (Would be set as `kernel_size`).
426
+ padding (int | tuple | string ): The padding length of
427
+ embedding conv. When it is a string, it means the mode
428
+ of adaptive padding, support "same" and "corner" now.
429
+ Default: "corner".
430
+ dilation (int): The dilation rate of embedding conv. Default: 1.
431
+ bias (bool): Bias of embed conv. Default: True.
432
+ norm_cfg (dict, optional): Config dict for normalization layer.
433
+ Default: None.
434
+ input_size (int | tuple | None): The size of input, which will be
435
+ used to calculate the out size. Only work when `dynamic_size`
436
+ is False. Default: None.
437
+ init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
438
+ Default: None.
439
+ """
440
+
441
+ def __init__(
442
+ self,
443
+ in_channels=3,
444
+ embed_dims=768,
445
+ conv_type="Conv2d",
446
+ kernel_size=16,
447
+ stride=16,
448
+ padding="corner",
449
+ dilation=1,
450
+ bias=True,
451
+ norm_cfg=None,
452
+ input_size=None,
453
+ init_cfg=None,
454
+ ):
455
+ super(PatchEmbed, self).__init__()
456
+
457
+ self.embed_dims = embed_dims
458
+ if stride is None:
459
+ stride = kernel_size
460
+
461
+ kernel_size = to_2tuple(kernel_size)
462
+ stride = to_2tuple(stride)
463
+ dilation = to_2tuple(dilation)
464
+
465
+ if isinstance(padding, str):
466
+ self.adap_padding = AdaptivePadding(
467
+ kernel_size=kernel_size,
468
+ stride=stride,
469
+ dilation=dilation,
470
+ padding=padding,
471
+ )
472
+ # disable the padding of conv
473
+ padding = 0
474
+ else:
475
+ self.adap_padding = None
476
+ padding = to_2tuple(padding)
477
+
478
+ self.projection = build_conv_layer(
479
+ dict(type=conv_type),
480
+ in_channels=in_channels,
481
+ out_channels=embed_dims,
482
+ kernel_size=kernel_size,
483
+ stride=stride,
484
+ padding=padding,
485
+ dilation=dilation,
486
+ bias=bias,
487
+ )
488
+
489
+ if norm_cfg is not None:
490
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
491
+ else:
492
+ self.norm = None
493
+
494
+ if input_size:
495
+ input_size = to_2tuple(input_size)
496
+ # `init_out_size` would be used outside to
497
+ # calculate the num_patches
498
+ # when `use_abs_pos_embed` outside
499
+ self.init_input_size = input_size
500
+ if self.adap_padding:
501
+ pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
502
+ input_h, input_w = input_size
503
+ input_h = input_h + pad_h
504
+ input_w = input_w + pad_w
505
+ input_size = (input_h, input_w)
506
+
507
+ # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
508
+ h_out = (
509
+ input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
510
+ ) // stride[0] + 1
511
+ w_out = (
512
+ input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
513
+ ) // stride[1] + 1
514
+ self.init_out_size = (h_out, w_out)
515
+ else:
516
+ self.init_input_size = None
517
+ self.init_out_size = None
518
+
519
+ def forward(self, x):
520
+ """
521
+ Args:
522
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
523
+ Returns:
524
+ tuple: Contains merged results and its spatial shape.
525
+ - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
526
+ - out_size (tuple[int]): Spatial shape of x, arrange as
527
+ (out_h, out_w).
528
+ """
529
+
530
+ if self.adap_padding:
531
+ x, _ = self.adap_padding(x)
532
+
533
+ x = self.projection(x)
534
+
535
+ B, C, out_h, out_w = x.shape
536
+
537
+ x = x.view(B, C, out_h * out_w).transpose(1, 2)
538
+
539
+ if self.norm is not None:
540
+ x = self.norm(x)
541
+ return x, (out_h, out_w)
542
+
543
+
544
+ class PatchMerging(BaseModule):
545
+ """Merge patch feature map.
546
+ This layer groups the feature map by kernel_size and applies norm and linear
547
+ layers to the grouped feature map. Our implementation uses `nn.Unfold` to
548
+ merge patches, which is about 25% faster than the original implementation;
549
+ as a trade-off, pretrained models need to be modified for compatibility.
550
+ Args:
551
+ in_channels (int): The num of input channels.
552
+ to gets fully covered by filter and stride you specified..
553
+ Default: True.
554
+ out_channels (int): The num of output channels.
555
+ kernel_size (int | tuple, optional): the kernel size in the unfold
556
+ layer. Defaults to 2.
557
+ stride (int | tuple, optional): the stride of the sliding blocks in the
558
+ unfold layer. Default: None. (Would be set as `kernel_size`)
559
+ padding (int | tuple | string ): The padding length of
560
+ embedding conv. When it is a string, it means the mode
561
+ of adaptive padding, support "same" and "corner" now.
562
+ Default: "corner".
563
+ dilation (int | tuple, optional): dilation parameter in the unfold
564
+ layer. Default: 1.
565
+ bias (bool, optional): Whether to add bias in linear layer or not.
566
+ Defaults: False.
567
+ norm_cfg (dict, optional): Config dict for normalization layer.
568
+ Default: dict(type='LN').
569
+ init_cfg (dict, optional): The extra config for initialization.
570
+ Default: None.
571
+ """
572
+
573
+ def __init__(
574
+ self,
575
+ in_channels,
576
+ out_channels,
577
+ kernel_size=2,
578
+ stride=None,
579
+ padding="corner",
580
+ dilation=1,
581
+ bias=False,
582
+ norm_cfg=dict(type="LN"),
583
+ init_cfg=None,
584
+ ):
585
+ super().__init__()
586
+ self.in_channels = in_channels
587
+ self.out_channels = out_channels
588
+ if stride:
589
+ stride = stride
590
+ else:
591
+ stride = kernel_size
592
+
593
+ kernel_size = to_2tuple(kernel_size)
594
+ stride = to_2tuple(stride)
595
+ dilation = to_2tuple(dilation)
596
+
597
+ if isinstance(padding, str):
598
+ self.adap_padding = AdaptivePadding(
599
+ kernel_size=kernel_size,
600
+ stride=stride,
601
+ dilation=dilation,
602
+ padding=padding,
603
+ )
604
+ # disable the padding of unfold
605
+ padding = 0
606
+ else:
607
+ self.adap_padding = None
608
+
609
+ padding = to_2tuple(padding)
610
+ self.sampler = nn.Unfold(
611
+ kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride
612
+ )
613
+
614
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
615
+
616
+ if norm_cfg is not None:
617
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
618
+ else:
619
+ self.norm = None
620
+
621
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
622
+
623
+ def forward(self, x, input_size):
624
+ """
625
+ Args:
626
+ x (Tensor): Has shape (B, H*W, C_in).
627
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
628
+ Default: None.
629
+ Returns:
630
+ tuple: Contains merged results and its spatial shape.
631
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
632
+ - out_size (tuple[int]): Spatial shape of x, arrange as
633
+ (Merged_H, Merged_W).
634
+ """
635
+ B, L, C = x.shape
636
+ assert isinstance(input_size, Sequence), (
637
+ f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}"
638
+ )
639
+
640
+ H, W = input_size
641
+ assert L == H * W, "input feature has wrong size"
642
+
643
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
644
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
645
+ # but need to modify pretrained model for compatibility
646
+
647
+ if self.adap_padding:
648
+ x, (H, W) = self.adap_padding(x)
649
+
650
+ x = self.sampler(x)
651
+ # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
652
+
653
+ out_h = (
654
+ H
655
+ + 2 * self.sampler.padding[0]
656
+ - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1)
657
+ - 1
658
+ ) // self.sampler.stride[0] + 1
659
+ out_w = (
660
+ W
661
+ + 2 * self.sampler.padding[1]
662
+ - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1)
663
+ - 1
664
+ ) // self.sampler.stride[1] + 1
665
+
666
+ x = x.view(B, C * H * W // (out_h * out_w), out_h * out_w)
667
+
668
+ output_size = (out_h, out_w)
669
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
670
+ x = self.norm(x) if self.norm else x
671
+ x = self.reduction(x)
672
+ return x, output_size
673
+
674
+
675
+ class WindowMSA(BaseModule):
676
+ """Window based multi-head self-attention (W-MSA) module with relative
677
+ position bias.
678
+ Args:
679
+ embed_dims (int): Number of input channels.
680
+ num_heads (int): Number of attention heads.
681
+ window_size (tuple[int]): The height and width of the window.
682
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
683
+ Default: True.
684
+ qk_scale (float | None, optional): Override default qk scale of
685
+ head_dim ** -0.5 if set. Default: None.
686
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
687
+ Default: 0.0
688
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
689
+ init_cfg (dict | None, optional): The Config for initialization.
690
+ Default: None.
691
+ """
692
+
693
+ def __init__(
694
+ self,
695
+ embed_dims,
696
+ num_heads,
697
+ window_size,
698
+ qkv_bias=True,
699
+ qk_scale=None,
700
+ attn_drop_rate=0.0,
701
+ proj_drop_rate=0.0,
702
+ init_cfg=None,
703
+ ):
704
+ super().__init__()
705
+ self.embed_dims = embed_dims
706
+ self.window_size = window_size # Wh, Ww
707
+ self.num_heads = num_heads
708
+ head_embed_dims = embed_dims // num_heads
709
+ self.scale = qk_scale or head_embed_dims**-0.5
710
+ self.init_cfg = init_cfg
711
+
712
+ # define a parameter table of relative position bias
713
+ self.relative_position_bias_table = nn.Parameter(
714
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
715
+ ) # 2*Wh-1 * 2*Ww-1, nH
716
+
717
+ # About 2x faster than original impl
718
+ Wh, Ww = self.window_size
719
+ rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
720
+ rel_position_index = rel_index_coords + rel_index_coords.T
721
+ rel_position_index = rel_position_index.flip(1).contiguous()
722
+ self.register_buffer("relative_position_index", rel_position_index)
723
+
724
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
725
+ self.attn_drop = nn.Dropout(attn_drop_rate)
726
+ self.proj = nn.Linear(embed_dims, embed_dims)
727
+ self.proj_drop = nn.Dropout(proj_drop_rate)
728
+
729
+ self.softmax = nn.Softmax(dim=-1)
730
+
731
+ def init_weights(self):
732
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
733
+
734
+ def forward(self, x, mask, N, C, nW):
735
+ """
736
+ Args:
737
+ x (tensor): input features with shape of (nW*B, N, C)
738
+ mask (tensor | None, Optional): mask with shape of (nW,
739
+ Wh*Ww, Wh*Ww), value should be between (-inf, 0].
740
+ """
741
+ nWB = x.shape[0]
742
+
743
+ qkv = (
744
+ self.qkv(x)
745
+ .reshape(x.shape[0], N, 3, self.num_heads, C // self.num_heads)
746
+ .permute(2, 0, 3, 1, 4)
747
+ )
748
+ # make torchscript happy (cannot use tensor as tuple)
749
+ q, k, v = qkv[0], qkv[1], qkv[2]
750
+
751
+ q = q * self.scale
752
+ attn = q @ k.transpose(-2, -1)
753
+
754
+ relative_position_bias = self.relative_position_bias_table[
755
+ self.relative_position_index.view(
756
+ (
757
+ self.window_size[0]
758
+ * self.window_size[1]
759
+ * self.window_size[0]
760
+ * self.window_size[1],
761
+ )
762
+ )
763
+ ].view(
764
+ self.window_size[0] * self.window_size[1],
765
+ self.window_size[0] * self.window_size[1],
766
+ self.num_heads,
767
+ ) # Wh*Ww,Wh*Ww,nH
768
+
769
+ relative_position_bias = relative_position_bias.permute(
770
+ 2, 0, 1
771
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
772
+ attn = attn + relative_position_bias.unsqueeze(0)
773
+
774
+ if mask is not None:
775
+ nW = mask.shape[0]
776
+ attn = attn.view(nWB // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
777
+ 1
778
+ ).unsqueeze(0)
779
+ attn = attn.view(nWB, self.num_heads, N, N)
780
+ attn = self.softmax(attn)
781
+
782
+ attn = self.attn_drop(attn)
783
+
784
+ x = (attn @ v).transpose(1, 2).reshape(nWB, N, C)
785
+ x = self.proj(x)
786
+ x = self.proj_drop(x)
787
+ return x
788
+
789
+ @staticmethod
790
+ def double_step_seq(step1, len1, step2, len2):
791
+ seq1 = torch.arange(0, step1 * len1, step1)
792
+ seq2 = torch.arange(0, step2 * len2, step2)
793
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
794
+
795
+
796
+ class ShiftWindowMSA(BaseModule):
797
+ """Shifted Window Multihead Self-Attention Module.
798
+ Args:
799
+ embed_dims (int): Number of input channels.
800
+ num_heads (int): Number of attention heads.
801
+ window_size (int): The height and width of the window.
802
+ shift_size (int, optional): The shift step of each window towards
803
+ right-bottom. If zero, act as regular window-msa. Defaults to 0.
804
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
805
+ Default: True
806
+ qk_scale (float | None, optional): Override default qk scale of
807
+ head_dim ** -0.5 if set. Defaults: None.
808
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
809
+ Defaults: 0.
810
+ proj_drop_rate (float, optional): Dropout ratio of output.
811
+ Defaults: 0.
812
+ dropout_layer (dict, optional): The dropout_layer used before output.
813
+ Defaults: dict(type='DropPath', drop_prob=0.).
814
+ init_cfg (dict, optional): The extra config for initialization.
815
+ Default: None.
816
+ """
817
+
818
+ def __init__(
819
+ self,
820
+ embed_dims,
821
+ num_heads,
822
+ window_size,
823
+ shift_size=0,
824
+ qkv_bias=True,
825
+ qk_scale=None,
826
+ attn_drop_rate=0,
827
+ proj_drop_rate=0,
828
+ dropout_layer=dict(type="DropPath", drop_prob=0.0),
829
+ init_cfg=None,
830
+ ):
831
+ super().__init__()
832
+
833
+ self.window_size = window_size
834
+ self.shift_size = shift_size
835
+
836
+ self.h_slices = (
837
+ slice(0, -self.window_size),
838
+ slice(-self.window_size, -self.shift_size),
839
+ slice(-self.shift_size, None),
840
+ )
841
+ self.w_slices = (
842
+ slice(0, -self.window_size),
843
+ slice(-self.window_size, -self.shift_size),
844
+ slice(-self.shift_size, None),
845
+ )
846
+
847
+ assert 0 <= self.shift_size < self.window_size
848
+
849
+ self.w_msa = WindowMSA(
850
+ embed_dims=embed_dims,
851
+ num_heads=num_heads,
852
+ window_size=to_2tuple(window_size),
853
+ qkv_bias=qkv_bias,
854
+ qk_scale=qk_scale,
855
+ attn_drop_rate=attn_drop_rate,
856
+ proj_drop_rate=proj_drop_rate,
857
+ init_cfg=None,
858
+ )
859
+
860
+ self.drop = build_dropout(dropout_layer)
861
+
862
+ def forward(self, query, hw_shape):
863
+ B, L, C = query.shape
864
+ H, W = hw_shape
865
+ assert L == H * W, "input feature has wrong size"
866
+ query = query.view(-1, H, W, C)
867
+
868
+ # pad feature maps to multiples of window size
869
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
870
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
871
+
872
+ query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
873
+
874
+ H_pad = H + pad_b
875
+ W_pad = W + pad_r
876
+
877
+ N = self.window_size**2
878
+ nW = H_pad * W_pad // N
879
+
880
+ # cyclic shift
881
+ if self.shift_size > 0:
882
+ shifted_query = torch.roll(
883
+ query, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
884
+ )
885
+
886
+ # calculate attention mask for SW-MSA
887
+ img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
888
+ cnt = 0
889
+ for h in self.h_slices:
890
+ for w in self.w_slices:
891
+ img_mask[:, h, w, :] = cnt
892
+ cnt += 1
893
+
894
+ # nW, window_size, window_size, 1
895
+ mask_windows = self.window_partition(img_mask, H_pad, W_pad, 1, nW)
896
+ mask_windows = mask_windows.view(nW, N)
897
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
898
+ attn_mask = attn_mask.masked_fill(
899
+ attn_mask != 0, float(-100.0)
900
+ ).masked_fill(attn_mask == 0, float(0.0))
901
+ else:
902
+ shifted_query = query
903
+ attn_mask = None
904
+
905
+ # nW*B, window_size, window_size, C
906
+ query_windows = self.window_partition(shifted_query, H_pad, W_pad, C, nW)
907
+
908
+ # nW*B, window_size*window_size, C
909
+ query_windows = query_windows.view(-1, N, C)
910
+
911
+ # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
912
+ attn_windows = self.w_msa(query_windows, attn_mask, N, C, nW)
913
+
914
+ # merge windows
915
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
916
+
917
+ # B H' W' C
918
+ shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, C, nW)
919
+ # reverse cyclic shift
920
+ if self.shift_size > 0:
921
+ x = torch.roll(
922
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
923
+ )
924
+ else:
925
+ x = shifted_x
926
+
927
+ if pad_r > 0 or pad_b:
928
+ x = x[:, :H, :W, :].contiguous()
929
+
930
+ x = x.view(-1, H * W, C)
931
+
932
+ x = self.drop(x)
933
+ return x
934
+
935
+ def window_reverse(self, windows, H, W, C, nW):
936
+ """
937
+ Args:
938
+ windows: (nW*B, window_size, window_size, C)
939
+ H (int): Height of image
940
+ W (int): Width of image
941
+ Returns:
942
+ x: (B, H, W, C)
943
+ """
944
+ window_size = self.window_size
945
+ x = windows.view(
946
+ -1, H // window_size, W // window_size, window_size, window_size, C
947
+ )
948
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
949
+ return x
950
+
951
+ def window_partition(self, x, H, W, C, nW):
952
+ """
953
+ Args:
954
+ x: (B, H, W, C)
955
+ Returns:
956
+ windows: (nW*B, window_size, window_size, C)
957
+ """
958
+ window_size = self.window_size
959
+ x = x.view(
960
+ -1,
961
+ H // window_size,
962
+ window_size,
963
+ W // window_size,
964
+ window_size,
965
+ C,
966
+ )
967
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
968
+ windows = windows.view(-1, window_size, window_size, C)
969
+ return windows
970
+
971
+
972
+ class SwinBlock(BaseModule):
973
+ """Swin Transformer block.
974
+ Args:
975
+ embed_dims (int): The feature dimension.
976
+ num_heads (int): Parallel attention heads.
977
+ feedforward_channels (int): The hidden dimension for FFNs.
978
+ window_size (int, optional): The local window scale. Default: 7.
979
+ shift (bool, optional): whether to shift window or not. Default False.
980
+ qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
981
+ qk_scale (float | None, optional): Override default qk scale of
982
+ head_dim ** -0.5 if set. Default: None.
983
+ drop_rate (float, optional): Dropout rate. Default: 0.
984
+ attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
985
+ drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
986
+ act_cfg (dict, optional): The config dict of activation function.
987
+ Default: dict(type='GELU').
988
+ norm_cfg (dict, optional): The config dict of normalization.
989
+ Default: dict(type='LN').
990
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
991
+ will save some memory while slowing down the training speed.
992
+ Default: False.
993
+ init_cfg (dict | list | None, optional): The init config.
994
+ Default: None.
995
+ """
996
+
997
+ def __init__(
998
+ self,
999
+ embed_dims,
1000
+ num_heads,
1001
+ feedforward_channels,
1002
+ window_size=7,
1003
+ shift=False,
1004
+ qkv_bias=True,
1005
+ qk_scale=None,
1006
+ drop_rate=0.0,
1007
+ attn_drop_rate=0.0,
1008
+ drop_path_rate=0.0,
1009
+ act_cfg=dict(type="GELU"),
1010
+ norm_cfg=dict(type="LN"),
1011
+ with_cp=False,
1012
+ init_cfg=None,
1013
+ ):
1014
+ super(SwinBlock, self).__init__()
1015
+
1016
+ self.init_cfg = init_cfg
1017
+ self.with_cp = with_cp
1018
+
1019
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
1020
+ self.attn = ShiftWindowMSA(
1021
+ embed_dims=embed_dims,
1022
+ num_heads=num_heads,
1023
+ window_size=window_size,
1024
+ shift_size=window_size // 2 if shift else 0,
1025
+ qkv_bias=qkv_bias,
1026
+ qk_scale=qk_scale,
1027
+ attn_drop_rate=attn_drop_rate,
1028
+ proj_drop_rate=drop_rate,
1029
+ dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
1030
+ init_cfg=None,
1031
+ )
1032
+
1033
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
1034
+ self.ffn = FFN(
1035
+ embed_dims=embed_dims,
1036
+ feedforward_channels=feedforward_channels,
1037
+ num_fcs=2,
1038
+ ffn_drop=drop_rate,
1039
+ dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
1040
+ act_cfg=act_cfg,
1041
+ add_identity=True,
1042
+ init_cfg=None,
1043
+ )
1044
+
1045
+ def forward(self, x, hw_shape):
1046
+ def _inner_forward(x):
1047
+ identity = x
1048
+ x = self.norm1(x)
1049
+ x = self.attn(x, hw_shape)
1050
+
1051
+ x = x + identity
1052
+
1053
+ identity = x
1054
+ x = self.norm2(x)
1055
+ x = self.ffn(x, identity=identity)
1056
+
1057
+ return x
1058
+
1059
+ if self.with_cp and x.requires_grad:
1060
+ x = cp.checkpoint(_inner_forward, x)
1061
+ else:
1062
+ x = _inner_forward(x)
1063
+
1064
+ return x
1065
+
1066
+
1067
+ class SwinBlockSequence(BaseModule):
1068
+ """Implements one stage in Swin Transformer.
1069
+ Args:
1070
+ embed_dims (int): The feature dimension.
1071
+ num_heads (int): Parallel attention heads.
1072
+ feedforward_channels (int): The hidden dimension for FFNs.
1073
+ depth (int): The number of blocks in this stage.
1074
+ window_size (int, optional): The local window scale. Default: 7.
1075
+ qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
1076
+ qk_scale (float | None, optional): Override default qk scale of
1077
+ head_dim ** -0.5 if set. Default: None.
1078
+ drop_rate (float, optional): Dropout rate. Default: 0.
1079
+ attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
1080
+ drop_path_rate (float | list[float], optional): Stochastic depth
1081
+ rate. Default: 0.
1082
+ downsample (BaseModule | None, optional): The downsample operation
1083
+ module. Default: None.
1084
+ act_cfg (dict, optional): The config dict of activation function.
1085
+ Default: dict(type='GELU').
1086
+ norm_cfg (dict, optional): The config dict of normalization.
1087
+ Default: dict(type='LN').
1088
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
1089
+ will save some memory while slowing down the training speed.
1090
+ Default: False.
1091
+ init_cfg (dict | list | None, optional): The init config.
1092
+ Default: None.
1093
+ """
1094
+
1095
+ def __init__(
1096
+ self,
1097
+ embed_dims,
1098
+ num_heads,
1099
+ feedforward_channels,
1100
+ depth,
1101
+ window_size=7,
1102
+ qkv_bias=True,
1103
+ qk_scale=None,
1104
+ drop_rate=0.0,
1105
+ attn_drop_rate=0.0,
1106
+ drop_path_rate=0.0,
1107
+ downsample=None,
1108
+ act_cfg=dict(type="GELU"),
1109
+ norm_cfg=dict(type="LN"),
1110
+ with_cp=False,
1111
+ init_cfg=None,
1112
+ ):
1113
+ super().__init__()
1114
+
1115
+ if isinstance(drop_path_rate, list):
1116
+ drop_path_rates = drop_path_rate
1117
+ assert len(drop_path_rates) == depth
1118
+ else:
1119
+ drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
1120
+
1121
+ self.blocks = ModuleList()
1122
+ for i in range(depth):
1123
+ block = SwinBlock(
1124
+ embed_dims=embed_dims,
1125
+ num_heads=num_heads,
1126
+ feedforward_channels=feedforward_channels,
1127
+ window_size=window_size,
1128
+ shift=False if i % 2 == 0 else True,
1129
+ qkv_bias=qkv_bias,
1130
+ qk_scale=qk_scale,
1131
+ drop_rate=drop_rate,
1132
+ attn_drop_rate=attn_drop_rate,
1133
+ drop_path_rate=drop_path_rates[i],
1134
+ act_cfg=act_cfg,
1135
+ norm_cfg=norm_cfg,
1136
+ with_cp=with_cp,
1137
+ init_cfg=None,
1138
+ )
1139
+ self.blocks.append(block)
1140
+
1141
+ self.downsample = downsample
1142
+
1143
+ def forward(self, x, hw_shape):
1144
+ for block in self.blocks:
1145
+ x = block(x, hw_shape)
1146
+
1147
+ if self.downsample:
1148
+ x_down, down_hw_shape = self.downsample(x, hw_shape)
1149
+ return x_down, down_hw_shape, x, hw_shape
1150
+ else:
1151
+ return x, hw_shape, x, hw_shape
1152
+
1153
+
1154
+ class SwinTransformer(BaseModule):
1155
+ """Swin Transformer
1156
+ A PyTorch implementation of: `Swin Transformer:
1157
+ Hierarchical Vision Transformer using Shifted Windows` -
1158
+ https://arxiv.org/abs/2103.14030
1159
+ Inspiration from
1160
+ https://github.com/microsoft/Swin-Transformer
1161
+ Args:
1162
+ pretrain_img_size (int | tuple[int]): The size of input image when
1163
+ pretrain. Defaults: 224.
1164
+ in_channels (int): The num of input channels.
1165
+ Defaults: 3.
1166
+ embed_dims (int): The feature dimension. Default: 96.
1167
+ patch_size (int | tuple[int]): Patch size. Default: 4.
1168
+ window_size (int): Window size. Default: 7.
1169
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
1170
+ Default: 4.
1171
+ depths (tuple[int]): Depths of each Swin Transformer stage.
1172
+ Default: (2, 2, 6, 2).
1173
+ num_heads (tuple[int]): Parallel attention heads of each Swin
1174
+ Transformer stage. Default: (3, 6, 12, 24).
1175
+ strides (tuple[int]): The patch merging or patch embedding stride of
1176
+ each Swin Transformer stage. (In swin, we set kernel size equal to
1177
+ stride.) Default: (4, 2, 2, 2).
1178
+ out_indices (tuple[int]): Output from which stages.
1179
+ Default: (0, 1, 2, 3).
1180
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key,
1181
+ value. Default: True
1182
+ qk_scale (float | None, optional): Override default qk scale of
1183
+ head_dim ** -0.5 if set. Default: None.
1184
+ patch_norm (bool): If add a norm layer for patch embed and patch
1185
+ merging. Default: True.
1186
+ drop_rate (float): Dropout rate. Defaults: 0.
1187
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
1188
+ drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
1189
+ use_abs_pos_embed (bool): If True, add absolute position embedding to
1190
+ the patch embedding. Defaults: False.
1191
+ act_cfg (dict): Config dict for activation layer.
1192
+ Default: dict(type='GELU').
1193
+ norm_cfg (dict): Config dict for normalization layer at
1194
+ output of backbone. Defaults: dict(type='LN').
1195
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
1196
+ will save some memory while slowing down the training speed.
1197
+ Default: False.
1198
+ pretrained (str, optional): model pretrained path. Default: None.
1199
+ convert_weights (bool): The flag indicates whether the
1200
+ pre-trained model is from the original repo. We may need
1201
+ to convert some keys to make it compatible.
1202
+ Default: False.
1203
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
1204
+ -1 means not freezing any parameters.
1205
+ init_cfg (dict, optional): The Config for initialization.
1206
+ Defaults to None.
1207
+ """
1208
+
1209
+ def __init__(
1210
+ self,
1211
+ pretrain_img_size=224,
1212
+ in_channels=3,
1213
+ embed_dims=96,
1214
+ patch_size=4,
1215
+ window_size=7,
1216
+ mlp_ratio=4,
1217
+ depths=(2, 2, 6, 2),
1218
+ num_heads=(3, 6, 12, 24),
1219
+ strides=(4, 2, 2, 2),
1220
+ out_indices=(0, 1, 2, 3),
1221
+ qkv_bias=True,
1222
+ qk_scale=None,
1223
+ patch_norm=True,
1224
+ drop_rate=0.0,
1225
+ attn_drop_rate=0.0,
1226
+ drop_path_rate=0.1,
1227
+ use_abs_pos_embed=False,
1228
+ act_cfg=dict(type="GELU"),
1229
+ norm_cfg=dict(type="LN"),
1230
+ with_cp=False,
1231
+ pretrained=None,
1232
+ convert_weights=False,
1233
+ frozen_stages=-1,
1234
+ init_cfg=None,
1235
+ semantic_weight=0.0,
1236
+ ):
1237
+ self.convert_weights = convert_weights
1238
+ self.frozen_stages = frozen_stages
1239
+ if isinstance(pretrain_img_size, int):
1240
+ pretrain_img_size = to_2tuple(pretrain_img_size)
1241
+ elif isinstance(pretrain_img_size, tuple):
1242
+ if len(pretrain_img_size) == 1:
1243
+ pretrain_img_size = to_2tuple(pretrain_img_size[0])
1244
+ assert len(pretrain_img_size) == 2, (
1245
+ f"The size of image should have length 1 or 2, "
1246
+ f"but got {len(pretrain_img_size)}"
1247
+ )
1248
+
1249
+ assert not (
1250
+ init_cfg and pretrained
1251
+ ), "init_cfg and pretrained cannot be specified at the same time"
1252
+ if isinstance(pretrained, str):
1253
+ warnings.warn(
1254
+ "DeprecationWarning: pretrained is deprecated, "
1255
+ 'please use "init_cfg" instead'
1256
+ )
1257
+ self.init_cfg = dict(type="Pretrained", checkpoint=pretrained)
1258
+ elif pretrained is None:
1259
+ self.init_cfg = init_cfg
1260
+ else:
1261
+ raise TypeError("pretrained must be a str or None")
1262
+
1263
+ super(SwinTransformer, self).__init__()
1264
+
1265
+ num_layers = len(depths)
1266
+ self.out_indices = out_indices
1267
+ self.use_abs_pos_embed = use_abs_pos_embed
1268
+
1269
+ assert strides[0] == patch_size, "Use non-overlapping patch embed."
1270
+
1271
+ self.patch_embed = PatchEmbed(
1272
+ in_channels=in_channels,
1273
+ embed_dims=embed_dims,
1274
+ conv_type="Conv2d",
1275
+ kernel_size=patch_size,
1276
+ stride=strides[0],
1277
+ norm_cfg=norm_cfg if patch_norm else None,
1278
+ init_cfg=None,
1279
+ )
1280
+
1281
+ if self.use_abs_pos_embed:
1282
+ patch_row = pretrain_img_size[0] // patch_size
1283
+ patch_col = pretrain_img_size[1] // patch_size
1284
+ num_patches = patch_row * patch_col
1285
+ self.absolute_pos_embed = nn.Parameter(
1286
+ torch.zeros((1, num_patches, embed_dims))
1287
+ )
1288
+
1289
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
1290
+
1291
+ # set stochastic depth decay rule
1292
+ total_depth = sum(depths)
1293
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]
1294
+
1295
+ self.stages = ModuleList()
1296
+ in_channels = embed_dims
1297
+ for i in range(num_layers):
1298
+ if i < num_layers - 1:
1299
+ downsample = PatchMerging(
1300
+ in_channels=in_channels,
1301
+ out_channels=2 * in_channels,
1302
+ stride=strides[i + 1],
1303
+ norm_cfg=norm_cfg if patch_norm else None,
1304
+ init_cfg=None,
1305
+ )
1306
+ else:
1307
+ downsample = None
1308
+
1309
+ stage = SwinBlockSequence(
1310
+ embed_dims=in_channels,
1311
+ num_heads=num_heads[i],
1312
+ feedforward_channels=mlp_ratio * in_channels,
1313
+ depth=depths[i],
1314
+ window_size=window_size,
1315
+ qkv_bias=qkv_bias,
1316
+ qk_scale=qk_scale,
1317
+ drop_rate=drop_rate,
1318
+ attn_drop_rate=attn_drop_rate,
1319
+ drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
1320
+ downsample=downsample,
1321
+ act_cfg=act_cfg,
1322
+ norm_cfg=norm_cfg,
1323
+ with_cp=with_cp,
1324
+ init_cfg=None,
1325
+ )
1326
+ self.stages.append(stage)
1327
+ if downsample:
1328
+ in_channels = downsample.out_channels
1329
+
1330
+ self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
1331
+ # Add a norm layer for each output
1332
+ for i in out_indices:
1333
+ layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
1334
+ layer_name = f"norm{i}"
1335
+ self.add_module(layer_name, layer)
1336
+
1337
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
1338
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
1339
+
1340
+ # semantic embedding
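+ # Note: the 2-d semantic weight vector [w, 1 - w] is projected by a per-stage
+ # linear layer into a channel-wise scale (passed through softplus) and a shift,
+ # which forward() applies to each stage's tokens (a FiLM-style modulation).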
1341
+ self.semantic_weight = semantic_weight
1342
+ if self.semantic_weight >= 0:
1343
+ self.semantic_embed_w = ModuleList()
1344
+ self.semantic_embed_b = ModuleList()
1345
+ for i in range(len(depths)):
1346
+ if i >= len(depths) - 1:
1347
+ i = len(depths) - 2
1348
+ semantic_embed_w = nn.Linear(2, self.num_features[i + 1])
1349
+ semantic_embed_b = nn.Linear(2, self.num_features[i + 1])
1350
+ # TODO: Test with semantic embed unfreeze
1351
+ for param in semantic_embed_w.parameters():
1352
+ param.requires_grad = False
1353
+ for param in semantic_embed_b.parameters():
1354
+ param.requires_grad = False
1355
+ trunc_normal_init(semantic_embed_w, std=0.02, bias=0.0)
1356
+ trunc_normal_init(semantic_embed_b, std=0.02, bias=0.0)
1357
+ self.semantic_embed_w.append(semantic_embed_w)
1358
+ self.semantic_embed_b.append(semantic_embed_b)
1359
+ self.softplus = nn.Softplus()
1360
+
1361
+ def train(self, mode=True):
1362
+ """Convert the model into training mode while keep layers freezed."""
1363
+ super(SwinTransformer, self).train(mode)
1364
+ self._freeze_stages()
1365
+
1366
+ def _freeze_stages(self):
1367
+ if self.frozen_stages >= 0:
1368
+ self.patch_embed.eval()
1369
+ for param in self.patch_embed.parameters():
1370
+ param.requires_grad = False
1371
+ if self.use_abs_pos_embed:
1372
+ self.absolute_pos_embed.requires_grad = False
1373
+ self.drop_after_pos.eval()
1374
+
1375
+ for i in range(1, self.frozen_stages + 1):
1376
+ if (i - 1) in self.out_indices:
1377
+ norm_layer = getattr(self, f"norm{i-1}")
1378
+ norm_layer.eval()
1379
+ for param in norm_layer.parameters():
1380
+ param.requires_grad = False
1381
+
1382
+ m = self.stages[i - 1]
1383
+ m.eval()
1384
+ for param in m.parameters():
1385
+ param.requires_grad = False
1386
+
1387
+ def init_weights(self, pretrained=None):
1388
+ logger = logging.getLogger("loading parameters.")
1389
+ if pretrained is None:
1390
+ logger.warning(
1391
+ f"No pre-trained weights for "
1392
+ f"{self.__class__.__name__}, "
1393
+ f"training start from scratch"
1394
+ )
1395
+ if self.use_abs_pos_embed:
1396
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1397
+ for m in self.modules():
1398
+ if isinstance(m, nn.Linear):
1399
+ trunc_normal_init(m, std=0.02, bias=0.0)
1400
+ elif isinstance(m, nn.LayerNorm):
1401
+ constant_init(m.bias, 0)
1402
+ constant_init(m.weight, 1.0)
1403
+ else:
1404
+ ckpt = torch.load(pretrained, map_location="cpu")
1405
+ if "teacher" in ckpt:
1406
+ ckpt = ckpt["teacher"]
1407
+
1408
+ if "state_dict" in ckpt:
1409
+ _state_dict = ckpt["state_dict"]
1410
+ elif "model" in ckpt:
1411
+ _state_dict = ckpt["model"]
1412
+ else:
1413
+ _state_dict = ckpt
1414
+ if self.convert_weights:
1415
+ # support loading weights from the original repo,
1416
+ _state_dict = swin_converter(_state_dict)
1417
+
1418
+ state_dict = OrderedDict()
1419
+ for k, v in _state_dict.items():
1420
+ if k.startswith("backbone."):
1421
+ state_dict[k[9:]] = v
+ else:
+ state_dict[k] = v
1422
+
1423
+ # strip prefix of state_dict
1424
+ if list(state_dict.keys())[0].startswith("module."):
1425
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
1426
+
1427
+ # reshape absolute position embedding
1428
+ if state_dict.get("absolute_pos_embed") is not None:
1429
+ absolute_pos_embed = state_dict["absolute_pos_embed"]
1430
+ N1, L, C1 = absolute_pos_embed.size()
1431
+ N2, C2, H, W = self.absolute_pos_embed.size()
1432
+ if N1 != N2 or C1 != C2 or L != H * W:
1433
+ logger.warning("Error in loading absolute_pos_embed, pass")
1434
+ else:
1435
+ state_dict["absolute_pos_embed"] = (
1436
+ absolute_pos_embed.view(N2, H, W, C2)
1437
+ .permute(0, 3, 1, 2)
1438
+ .contiguous()
1439
+ )
1440
+
1441
+ # interpolate position bias table if needed
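+ # Note: each table stores (2*Wh - 1) * (2*Ww - 1) biases per head; if the
+ # pretrained window size differs, reshape the table to 2-D and resize it
+ # bicubically so every relative offset in the new window maps to a valid entry.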
1442
+ relative_position_bias_table_keys = [
1443
+ k for k in state_dict.keys() if "relative_position_bias_table" in k
1444
+ ]
1445
+ for table_key in relative_position_bias_table_keys:
1446
+ table_pretrained = state_dict[table_key]
1447
+ table_current = self.state_dict()[table_key]
1448
+ L1, nH1 = table_pretrained.size()
1449
+ L2, nH2 = table_current.size()
1450
+ if nH1 != nH2:
1451
+ logger.warning(f"Error in loading {table_key}, pass")
1452
+ elif L1 != L2:
1453
+ S1 = int(L1**0.5)
1454
+ S2 = int(L2**0.5)
1455
+ table_pretrained_resized = F.interpolate(
1456
+ table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
1457
+ size=(S2, S2),
1458
+ mode="bicubic",
1459
+ )
1460
+ state_dict[table_key] = (
1461
+ table_pretrained_resized.view(nH2, L2)
1462
+ .permute(1, 0)
1463
+ .contiguous()
1464
+ )
1465
+
1466
+ res = self.load_state_dict(state_dict, strict=False)
1467
+ print("unloaded parameters:", res)
1468
+
1469
+ def forward(self, x, semantic_weight=None):
1470
+ if self.semantic_weight >= 0 and semantic_weight is None:
1471
+ w = torch.ones(x.shape[0], 1) * self.semantic_weight
1472
+ w = torch.cat([w, 1 - w], dim=-1)
1473
+ semantic_weight = w.to(x.device)
1474
+
1475
+ x, hw_shape = self.patch_embed(x)
1476
+
1477
+ if self.use_abs_pos_embed:
1478
+ x = x + self.absolute_pos_embed
1479
+ x = self.drop_after_pos(x)
1480
+
1481
+ outs = []
1482
+ for i, stage in enumerate(self.stages):
1483
+ x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
1484
+ if self.semantic_weight >= 0:
1485
+ sw = self.semantic_embed_w[i](semantic_weight).unsqueeze(1)
1486
+ sb = self.semantic_embed_b[i](semantic_weight).unsqueeze(1)
1487
+ x = x * self.softplus(sw) + sb
1488
+ if i in self.out_indices:
1489
+ norm_layer = getattr(self, f"norm{i}")
1490
+ out = norm_layer(out)
1491
+ # out = (
1492
+ # out.view(-1, out_hw_shape[0], out_hw_shape[1], self.num_features[i])
1493
+ # .permute(0, 3, 1, 2)
1494
+ # .contiguous()
1495
+ # )
1496
+ outs.append(out)
1497
+
1498
+ x = outs[-1]
1499
+
1500
+ x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
1501
+
1502
+ x = torch.cat([x_cls.transpose(1, 2), x], dim=1)
1503
+
1504
+ return x
1505
+
1506
+
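+ # Usage sketch (illustrative only): the backbone returns (B, 1 + H/32 * W/32, C_last)
+ # tokens, where the first token is the average-pooled global feature and
+ # C_last = embed_dims * 8 (e.g. 96 -> 768, 128 -> 1024). For example:
+ #
+ #   model = swin_tiny_patch4_window7_224(img_size=224)
+ #   feats = model(torch.randn(2, 3, 224, 224))   # -> (2, 50, 768)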
1507
+ def swin_base_patch4_window7_224(
1508
+ img_size=224, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, **kwargs
1509
+ ):
1510
+ model = SwinTransformer(
1511
+ pretrain_img_size=img_size,
1512
+ patch_size=4,
1513
+ window_size=7,
1514
+ embed_dims=128,
1515
+ depths=(2, 2, 18, 2),
1516
+ num_heads=(4, 8, 16, 32),
1517
+ drop_path_rate=drop_path_rate,
1518
+ drop_rate=drop_rate,
1519
+ attn_drop_rate=attn_drop_rate,
1520
+ **kwargs,
1521
+ )
1522
+ return model
1523
+
1524
+
1525
+ def swin_small_patch4_window7_224(
1526
+ img_size=224, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, **kwargs
1527
+ ):
1528
+ model = SwinTransformer(
1529
+ pretrain_img_size=img_size,
1530
+ patch_size=4,
1531
+ window_size=7,
1532
+ embed_dims=96,
1533
+ depths=(2, 2, 18, 2),
1534
+ num_heads=(3, 6, 12, 24),
1535
+ drop_path_rate=drop_path_rate,
1536
+ drop_rate=drop_rate,
1537
+ attn_drop_rate=attn_drop_rate,
1538
+ **kwargs,
1539
+ )
1540
+ return model
1541
+
1542
+
1543
+ def swin_tiny_patch4_window7_224(
1544
+ img_size=224, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, **kwargs
1545
+ ):
1546
+ model = SwinTransformer(
1547
+ pretrain_img_size=img_size,
1548
+ patch_size=4,
1549
+ window_size=7,
1550
+ embed_dims=96,
1551
+ depths=(2, 2, 6, 2),
1552
+ num_heads=(3, 6, 12, 24),
1553
+ drop_path_rate=drop_path_rate,
1554
+ drop_rate=drop_rate,
1555
+ attn_drop_rate=attn_drop_rate,
1556
+ **kwargs,
1557
+ )
1558
+ return model
1559
+
1560
+
1561
+ def build_solider(cfg: dict) -> SwinTransformer:
1562
+ name = cfg["name"]
1563
+ img_size = cfg["img_size"]
1564
+ # drop_path_rate = cfg["drop_path_rate"]\
1565
+ # TODO: Test with drop_path_rate = 0.0
1566
+ drop_path_rate = 0.1
1567
+ # drop_rate = cfg["drop_rate"]
1568
+ drop_rate = 0.0
1569
+ # attn_drop_rate = cfg["attn_drop_rate"]
1570
+ attn_drop_rate = 0.0
1571
+ pretrained = cfg["pretrained"]
1572
+ # convert_weights = cfg["convert_weights"]
1573
+ convert_weights = False
1574
+ semantic_weight = cfg["semantic_weight"]
1575
+
1576
+ if name == "swin_tiny_patch4_window7_224":
1577
+ model = swin_tiny_patch4_window7_224(
1578
+ img_size=img_size,
1579
+ drop_path_rate=drop_path_rate,
1580
+ drop_rate=drop_rate,
1581
+ attn_drop_rate=attn_drop_rate,
1582
+ pretrained=pretrained,
1583
+ convert_weights=convert_weights,
1584
+ semantic_weight=semantic_weight,
1585
+ )
1586
+
1587
+ elif name == "swin_small_patch4_window7_224":
1588
+ model = swin_small_patch4_window7_224(
1589
+ img_size=img_size,
1590
+ drop_path_rate=drop_path_rate,
1591
+ drop_rate=drop_rate,
1592
+ attn_drop_rate=attn_drop_rate,
1593
+ pretrained=pretrained,
1594
+ convert_weights=convert_weights,
1595
+ semantic_weight=semantic_weight,
1596
+ )
1597
+
1598
+ elif name == "swin_base_patch4_window7_224":
1599
+ model = swin_base_patch4_window7_224(
1600
+ img_size=img_size,
1601
+ drop_path_rate=drop_path_rate,
1602
+ drop_rate=drop_rate,
1603
+ attn_drop_rate=attn_drop_rate,
1604
+ pretrained=pretrained,
1605
+ convert_weights=convert_weights,
1606
+ semantic_weight=semantic_weight,
1607
+ )
1608
+
1609
+ else:
1610
+ raise RuntimeError(f"Not support model name: {name}")
1611
+
1612
+ if pretrained != "":
1613
+ if os.path.exists(pretrained):
1614
+ model.init_weights(pretrained)
1615
+ else:
1616
+ warnings.warn(f"pretrained: {pretrained} not exists")
1617
+
1618
+ return model
1619
+
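+ # Example cfg for build_solider (a minimal sketch; the empty "pretrained" path is a
+ # placeholder and skips weight loading):
+ #
+ #   backbone = build_solider({
+ #       "name": "swin_small_patch4_window7_224",
+ #       "img_size": [384, 128],
+ #       "pretrained": "",
+ #       "semantic_weight": 0.2,
+ #   })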
1620
+
1621
+ # BACKBONE_NAME2WIDTH = {
1622
+ # "swin_tiny_patch4_window7_224": 768,
1623
+ # "swin_small_patch4_window7_224": 768,
1624
+ # "swin_base_patch4_window7_224": 1024,
1625
+ # "solider_tiny": 768,
1626
+ # "solider_small": 768,
1627
+ # "solider_base": 1024,
1628
+ # }
1629
+
1630
+
1631
+
1632
+ SOLIDER_BASE_MODEL_CONFIG_PARAMETERS = {
1633
+ "pretrain_img_size": [224, 224],
1634
+ "in_channels": 3,
1635
+ "embed_dims": 128,
1636
+ "patch_size": 4,
1637
+ "window_size": 7,
1638
+ "mlp_ratio": 4,
1639
+ "depths": (2, 2, 18, 2),
1640
+ "num_heads": (4, 8, 16, 32),
1641
+ "strides": (4, 2, 2, 2),
1642
+ "out_indices": (0, 1, 2, 3),
1643
+ "qkv_bias": True,
1644
+ "qk_scale": None,
1645
+ "patch_norm": True,
1646
+ "drop_rate": 0.0,
1647
+ "attn_drop_rate": 0.0,
1648
+ "drop_path_rate": 0.0,
1649
+ "use_abs_pos_embed": False,
1650
+ "act_cfg": dict(type="GELU"),
1651
+ "norm_cfg": dict(type="LN"),
1652
+ "with_cp": False,
1653
+ "pretrained": None,
1654
+ "convert_weights": False,
1655
+ "frozen_stages": -1,
1656
+ "init_cfg": None,
1657
+ "semantic_weight": 0.2,
1658
+ "name": "solider_base",
1659
+ }
1660
+
1661
+ SOLIDER_SMALL_MODEL_CONFIG_PARAMETERS = {
1662
+ "pretrain_img_size": [224, 224],
1663
+ "in_channels": 3,
1664
+ "embed_dims": 96,
1665
+ "patch_size": 4,
1666
+ "window_size": 7,
1667
+ "mlp_ratio": 4,
1668
+ "depths": (2, 2, 18, 2),
1669
+ "num_heads": (3, 6, 12, 24),
1670
+ "strides": (4, 2, 2, 2),
1671
+ "out_indices": (0, 1, 2, 3),
1672
+ "qkv_bias": True,
1673
+ "qk_scale": None,
1674
+ "patch_norm": True,
1675
+ "drop_rate": 0.0,
1676
+ "attn_drop_rate": 0.0,
1677
+ "drop_path_rate": 0.0,
1678
+ "use_abs_pos_embed": False,
1679
+ "act_cfg": dict(type="GELU"),
1680
+ "norm_cfg": dict(type="LN"),
1681
+ "with_cp": False,
1682
+ "pretrained": None,
1683
+ "convert_weights": False,
1684
+ "frozen_stages": -1,
1685
+ "init_cfg": None,
1686
+ "semantic_weight": 0.2,
1687
+ "name": "solider_small",
1688
+ }
1689
+
1690
+ SOLIDER_TINY_MODEL_CONFIG_PARAMETERS = {
1691
+ "pretrain_img_size": [224, 224],
1692
+ "in_channels": 3,
1693
+ "embed_dims": 96,
1694
+ "patch_size": 4,
1695
+ "window_size": 7,
1696
+ "mlp_ratio": 4,
1697
+ "depths": (2, 2, 6, 2),
1698
+ "num_heads": (3, 6, 12, 24),
1699
+ "strides": (4, 2, 2, 2),
1700
+ "out_indices": (0, 1, 2, 3),
1701
+ "qkv_bias": True,
1702
+ "qk_scale": None,
1703
+ "patch_norm": True,
1704
+ "drop_rate": 0.0,
1705
+ "attn_drop_rate": 0.0,
1706
+ "drop_path_rate": 0.0,
1707
+ "use_abs_pos_embed": False,
1708
+ "act_cfg": dict(type="GELU"),
1709
+ "norm_cfg": dict(type="LN"),
1710
+ "with_cp": False,
1711
+ "pretrained": None,
1712
+ "convert_weights": False,
1713
+ "frozen_stages": -1,
1714
+ "init_cfg": None,
1715
+ "semantic_weight": 0.2,
1716
+ "name": "solider_tiny",
1717
+ }
1718
+
1719
+ SOLIDER_BASE_CONFIG = SOLIDERConfig(**SOLIDER_BASE_MODEL_CONFIG_PARAMETERS)
1720
+ SOLIDER_SMALL_CONFIG = SOLIDERConfig(**SOLIDER_SMALL_MODEL_CONFIG_PARAMETERS)
1721
+ SOLIDER_TINY_CONFIG = SOLIDERConfig(**SOLIDER_TINY_MODEL_CONFIG_PARAMETERS)
1722
+
1723
+
1724
+ def build_solider_vision_encoder(weight_path, name="swin_small_patch4_window7_224"):
1725
+ vision_width = BACKBONE_NAME2WIDTH[name]
1726
+ return (
1727
+ build_solider(
1728
+ {
1729
+ "name": name,
1730
+ "img_size": [384, 128],
1731
+ "pretrained": weight_path,
1732
+ "semantic_weight": 0.2,
1733
+ }
1734
+ ),
1735
+ vision_width,
1736
+ )
1737
+
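+ # Hedged usage sketch: build_solider_vision_encoder returns a (backbone, vision_width)
+ # pair, with the width taken from BACKBONE_NAME2WIDTH (e.g. 768 for the small variant):
+ #
+ #   encoder, width = build_solider_vision_encoder("", name="swin_small_patch4_window7_224")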
1738
+
1739
+ class SOLIDERModel(PreTrainedModel):
1740
+ config_class = SOLIDERConfig
1741
+ base_model_prefix = "solider"
1742
+
1743
+ def __init__(self, config: SOLIDERConfig):
1744
+ super().__init__(config)
1745
+ self.solider = SwinTransformer(
1746
+ pretrain_img_size=config.pretrain_img_size,
1747
+ embed_dims=config.embed_dims,
1748
+ patch_size=config.patch_size,
1749
+ window_size=config.window_size,
1750
+ mlp_ratio=config.mlp_ratio,
1751
+ depths=config.depths,
1752
+ num_heads=config.num_heads,
1753
+ strides=config.strides,
1754
+ out_indices=config.out_indices,
1755
+ qkv_bias=config.qkv_bias,
1756
+ qk_scale=config.qk_scale,
1757
+ patch_norm=config.patch_norm,
1758
+ drop_rate=config.drop_rate,
1759
+ attn_drop_rate=config.attn_drop_rate,
1760
+ drop_path_rate=config.drop_path_rate,
1761
+ use_abs_pos_embed=config.use_abs_pos_embed,
1762
+ act_cfg=config.act_cfg,
1763
+ norm_cfg=config.norm_cfg,
1764
+ with_cp=config.with_cp,
1765
+ pretrained=config.pretrained,
1766
+ convert_weights=config.convert_weights,
1767
+ frozen_stages=config.frozen_stages,
1768
+ init_cfg=config.init_cfg,
1769
+ semantic_weight=config.semantic_weight,
1770
+ )
1771
+ self.solider_name = config.name
1772
+ self.vision_width = BACKBONE_NAME2WIDTH[self.solider_name]
1773
+ self.hidden_size = self.vision_width
1774
+
1775
+ self.config = config
1776
+ # self.init_weights()
1777
+
1778
+ def forward(self, x, semantic_weight=None):
1779
+ # if semantic_weight is None, use the default value from config
1780
+ return self.solider(x, semantic_weight)
1781
+
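+ # Illustrative sketch of the HF-style wrapper with one of the configs above
+ # (assumes a 3x224x224 batch; the tiny config yields 768-dim tokens):
+ #
+ #   model = SOLIDERModel(SOLIDER_TINY_CONFIG)
+ #   tokens = model(torch.randn(1, 3, 224, 224))   # (1, 50, 768)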
1782
+
1783
+ class SoliderEncoder(SwinTransformer):
1784
+ options = [
1785
+ "swin_tiny_patch4_window7_224",
1786
+ "swin_small_patch4_window7_224",
1787
+ "swin_base_patch4_window7_224",
1788
+ ]
1789
+
1790
+ @classmethod
1791
+ def from_config(cls, cfg, from_pretrained=None):
1792
+ name = cfg.get("name", "swin_small_patch4_window7_224")
1793
+ img_size = cfg.get("img_size", [384, 128])
1794
+ drop_path_rate = cfg.get("drop_path_rate", 0.1)
1795
+ drop_rate = cfg.get("drop_rate", 0.0)
1796
+ attn_drop_rate = cfg.get("attn_drop_rate", 0.0)
1797
+ pretrained = cfg.get("pretrained", None)
1798
+ convert_weights = cfg.get("convert_weights", False)
1799
+ semantic_weight = cfg.get("semantic_weight", 0.2)
1800
+ if name == "swin_tiny_patch4_window7_224" or name == "tiny":
1801
+ model = swin_tiny_patch4_window7_224(
1802
+ img_size=img_size,
1803
+ drop_path_rate=drop_path_rate,
1804
+ drop_rate=drop_rate,
1805
+ attn_drop_rate=attn_drop_rate,
1806
+ pretrained=pretrained,
1807
+ convert_weights=convert_weights,
1808
+ semantic_weight=semantic_weight,
1809
+ )
1810
+ elif name == "swin_small_patch4_window7_224" or name == "small":
1811
+ model = swin_small_patch4_window7_224(
1812
+ img_size=img_size,
1813
+ drop_path_rate=drop_path_rate,
1814
+ drop_rate=drop_rate,
1815
+ attn_drop_rate=attn_drop_rate,
1816
+ pretrained=pretrained,
1817
+ convert_weights=convert_weights,
1818
+ semantic_weight=semantic_weight,
1819
+ )
1820
+
1821
+ elif name == "swin_base_patch4_window7_224" or name == "base":
1822
+ model = swin_base_patch4_window7_224(
1823
+ img_size=img_size,
1824
+ drop_path_rate=drop_path_rate,
1825
+ drop_rate=drop_rate,
1826
+ attn_drop_rate=attn_drop_rate,
1827
+ pretrained=pretrained,
1828
+ convert_weights=convert_weights,
1829
+ semantic_weight=semantic_weight,
1830
+ )
+ else:
+ raise ValueError(f"Unsupported model name: {name}")
1831
+ model.vision_width = BACKBONE_NAME2WIDTH[name]
1832
+ if from_pretrained is not None:
1833
+ print("begin load pretrained model solider")
1834
+ state_dict_vision_encoder = torch.load(from_pretrained, map_location="cpu")
1835
+ msg = model.load_state_dict(state_dict_vision_encoder)
1836
+ print(msg)
1837
+ return model
1838
+
1839
+ def forward_features(self, x, semantic_weight=None):
1840
+ return SwinTransformer.forward(self, x, semantic_weight)
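+
+
+ # Hedged example of SoliderEncoder.from_config (the cfg keys mirror the defaults read
+ # above; pass from_pretrained=<path> to load a saved state dict instead of None):
+ #
+ #   encoder = SoliderEncoder.from_config(
+ #       {"name": "swin_small_patch4_window7_224", "img_size": [384, 128], "semantic_weight": 0.2},
+ #       from_pretrained=None,
+ #   )
+ #   feats = encoder.forward_features(torch.randn(1, 3, 384, 128))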