lieding1994 committed on
Commit f83ff13 · verified · 1 parent: b3597bc

Upload model

Browse files
Files changed (5)
  1. README.md +199 -0
  2. config.json +70 -0
  3. configuration_davit.py +50 -0
  4. model.safetensors +3 -0
  5. modeling_davit.py +661 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
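+ A minimal, untested sketch of loading this checkpoint (the repository id below is a placeholder for this repo's actual Hub path; `trust_remote_code=True` is required because the DaViT configuration and modeling code live in this repository):
+
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ # Placeholder id -- replace with this repository's actual Hub path.
+ model = AutoModel.from_pretrained("lieding1994/davit-model", trust_remote_code=True)
+ model.eval()
+
+ # DaViTModel.forward expects a (batch, 3, height, width) image tensor.
+ pixel_values = torch.randn(1, 3, 224, 224)
+ with torch.no_grad():
+     features = model(pixel_values)
+ print(features.shape)
+ ```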
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,70 @@
1
+ {
2
+ "architectures": [
3
+ "DaViTModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_davit.DaViTConfig",
7
+ "AutoModel": "modeling_davit.DaViTModel"
8
+ },
9
+ "conv_at_attn": true,
10
+ "conv_at_ffn": true,
11
+ "depths": [
12
+ 1,
13
+ 1,
14
+ 9,
15
+ 1
16
+ ],
17
+ "drop_path_rate": 0.1,
18
+ "embed_dims": [
19
+ 128,
20
+ 256,
21
+ 512,
22
+ 1024
23
+ ],
24
+ "enable_checkpoint": false,
25
+ "in_chans": 3,
26
+ "mlp_ratio": 4.0,
27
+ "model_type": "davit",
28
+ "norm_layer": "layer_norm",
29
+ "num_groups": [
30
+ 4,
31
+ 8,
32
+ 16,
33
+ 32
34
+ ],
35
+ "num_heads": [
36
+ 4,
37
+ 8,
38
+ 16,
39
+ 32
40
+ ],
41
+ "patch_padding": [
42
+ 3,
43
+ 1,
44
+ 1,
45
+ 1
46
+ ],
47
+ "patch_prenorm": [
48
+ false,
49
+ true,
50
+ true,
51
+ true
52
+ ],
53
+ "patch_size": [
54
+ 7,
55
+ 3,
56
+ 3,
57
+ 3
58
+ ],
59
+ "patch_stride": [
60
+ 4,
61
+ 2,
62
+ 2,
63
+ 2
64
+ ],
65
+ "projection_dim": 768,
66
+ "qkv_bias": true,
67
+ "torch_dtype": "float16",
68
+ "transformers_version": "4.44.2",
69
+ "window_size": 12
70
+ }
configuration_davit.py ADDED
@@ -0,0 +1,50 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ # Define configuration class
5
+ class DaViTConfig(PretrainedConfig):
6
+ model_type = "davit"
7
+
8
+ def __init__(
9
+ self,
10
+ in_chans=3,
11
+ # num_classes=1000,
12
+ depths=(1, 1, 9, 1),
13
+ patch_size=(7, 3, 3, 3),
14
+ patch_stride=(4, 2, 2, 2),
15
+ patch_padding=(3, 1, 1, 1),
16
+ patch_prenorm=(False, True, True, True),
17
+ embed_dims=(128, 256, 512, 1024),
18
+ num_heads=(4, 8, 16, 32),
19
+ num_groups=(4, 8, 16, 32),
20
+ window_size=12,
21
+ mlp_ratio=4.0,
22
+ qkv_bias=True,
23
+ drop_path_rate=0.1,
24
+ norm_layer="layer_norm",
25
+ enable_checkpoint=False,
26
+ conv_at_attn=True,
27
+ conv_at_ffn=True,
28
+ projection_dim=768,
29
+ **kwargs
30
+ ):
31
+ super().__init__(**kwargs)
32
+ self.in_chans = in_chans
33
+ # self.num_classes = num_classes # num_classes removed for AutoModel
34
+ self.depths = depths
35
+ self.patch_size = patch_size
36
+ self.patch_stride = patch_stride
37
+ self.patch_padding = patch_padding
38
+ self.patch_prenorm = patch_prenorm
39
+ self.embed_dims = embed_dims
40
+ self.num_heads = num_heads
41
+ self.num_groups = num_groups
42
+ self.window_size = window_size
43
+ self.mlp_ratio = mlp_ratio
44
+ self.qkv_bias = qkv_bias
45
+ self.drop_path_rate = drop_path_rate
46
+ self.norm_layer = norm_layer
47
+ self.enable_checkpoint = enable_checkpoint
48
+ self.conv_at_attn = conv_at_attn
49
+ self.conv_at_ffn = conv_at_ffn
50
+ self.projection_dim = projection_dim
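For reference, a minimal sketch of instantiating this configuration directly (assuming `configuration_davit.py` is importable from the working directory; the defaults above are intended to mirror `config.json` in this repository):

```python
from configuration_davit import DaViTConfig

config = DaViTConfig()             # defaults mirror config.json
print(config.depths)               # (1, 1, 9, 1)
print(config.embed_dims)           # (128, 256, 512, 1024)
config.save_pretrained("./davit")  # serializes back to a config.json
```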
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1643581e06858036d6ab364cdc49c0f7c564175c7b201a29b6b1dd179dba379c
3
+ size 182836040
modeling_davit.py ADDED
@@ -0,0 +1,661 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch DaViT model."""
17
+
18
+
19
+ import math
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint as checkpoint
25
+ from collections import OrderedDict
26
+ from einops import rearrange
27
+ from timm.models.layers import DropPath, trunc_normal_
28
+
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.utils import logging
31
+
32
+ # Ensure ConvEmbed, SpatialBlock, ChannelBlock, MySequential, etc., are defined before using them
33
+ from .configuration_davit import DaViTConfig
34
+
35
+ from transformers import AutoModel, AutoConfig
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ class LearnedAbsolutePositionEmbedding2D(nn.Module):
41
+ """
42
+ This module learns positional embeddings up to a fixed maximum size.
43
+ """
44
+
45
+ def __init__(self, embedding_dim=256, num_pos=50):
46
+ super().__init__()
47
+ self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
48
+ self.column_embeddings = nn.Embedding(
49
+ num_pos, embedding_dim - (embedding_dim // 2)
50
+ )
51
+
52
+ def forward(self, pixel_values):
53
+ """
54
+ pixel_values: (batch_size, height, width, num_channels)
55
+ returns: (batch_size, height, width, embedding_dim)
56
+ """
57
+ if len(pixel_values.shape) != 4:
58
+ raise ValueError("pixel_values must be a 4D tensor")
59
+ height, width = pixel_values.shape[1:3]
60
+ width_values = torch.arange(width, device=pixel_values.device)
61
+ height_values = torch.arange(height, device=pixel_values.device)
62
+ x_emb = self.column_embeddings(width_values)
63
+ y_emb = self.row_embeddings(height_values)
64
+ # (height, width, embedding_dim)
65
+ pos = torch.cat(
66
+ [
67
+ x_emb.unsqueeze(0).repeat(height, 1, 1),
68
+ y_emb.unsqueeze(1).repeat(1, width, 1),
69
+ ],
70
+ dim=-1,
71
+ )
72
+ # (embedding_dim, height, width)
73
+ pos = pos.permute(2, 0, 1)
74
+ pos = pos.unsqueeze(0)
75
+ # (batch_size, embedding_dim, height, width)
76
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
77
+ # (batch_size, height, width, embedding_dim)
78
+ pos = pos.permute(0, 2, 3, 1)
79
+ return pos
80
+
81
+
82
+ class PositionalEmbeddingCosine1D(nn.Module):
83
+ """
84
+ This class implements a very simple positional encoding. It follows closely
85
+ the encoder from the link below:
86
+ https://pytorch.org/tutorials/beginner/translation_transformer.html
87
+ Args:
88
+ embed_dim: The dimension of the embeddings.
89
90
+ max_seq_len: The maximum length to precompute the positional encodings.
91
+ """
92
+
93
+ def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
94
+ super(PositionalEmbeddingCosine1D, self).__init__()
95
+ self.embed_dim = embed_dim
96
+ self.max_seq_len = max_seq_len
97
+ # Generate the sinusoidal arrays.
98
+ factor = math.log(10000)
99
+ denominator = torch.exp(
100
+ -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim
101
+ )
102
+ # Matrix where rows correspond to a positional embedding as a function
103
+ # of the position index (i.e., the row index).
104
+ frequencies = (
105
+ torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1) * denominator
106
+ )
107
+ pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
108
+ # Even-indexed entries get sine, odd-indexed entries get cosine.
109
+ pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
110
+ pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
111
+ # Save the positional embeddings in a constant buffer.
112
+ self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
113
+
114
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
115
+ """
116
+ Args:
117
+ seq_embeds: The sequence embeddings in order. Allowed size:
118
+ 1. [T, D], where T is the length of the sequence, and D is the
119
+ frame embedding dimension.
120
+ 2. [B, T, D], where B is the batch size and T and D are the
121
+ same as above.
122
+ Returns a tensor with the same number of dimensions as the input, i.e.,
123
+ [1, T, D] or [T, D].
124
+ """
125
+ shape_len = len(seq_embeds.shape)
126
+ assert 2 <= shape_len <= 3
127
+ len_seq = seq_embeds.size(-2)
128
+ assert len_seq <= self.max_seq_len
129
+ pos_embeds = self.pos_idx_to_embed[0 : seq_embeds.size(-2), :]
130
+ # Adapt pre-computed positional embeddings to the input.
131
+ if shape_len == 3:
132
+ pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
133
+ return pos_embeds
134
+
135
+
136
+ class LearnedAbsolutePositionEmbedding1D(nn.Module):
137
+ """
138
+ Learnable absolute positional embeddings for 1D sequences.
139
+ Args:
140
+ embedding_dim: The dimension of the embeddings.
141
+ num_pos: The maximum number of positions to embed.
142
+ """
143
+
144
+ def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None:
145
+ super(LearnedAbsolutePositionEmbedding1D, self).__init__()
146
+ self.embeddings = nn.Embedding(num_pos, embedding_dim)
147
+ self.num_pos = num_pos
148
+
149
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
150
+ """
151
+ Args:
152
+ seq_embeds: The sequence embeddings in order. Allowed size:
153
+ 1. [T, D], where T is the length of the sequence, and D is the
154
+ frame embedding dimension.
155
+ 2. [B, T, D], where B is the batch size and T and D are the
156
+ same as above.
157
+ Returns a tensor with the same number of dimensions as the input, i.e.,
158
+ [1, T, D] or [T, D].
159
+ """
160
+ shape_len = len(seq_embeds.shape)
161
+ assert 2 <= shape_len <= 3
162
+ len_seq = seq_embeds.size(-2)
163
+ assert len_seq <= self.num_pos
164
+ # [T, D]
165
+ pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
166
+ # Adapt pre-computed positional embeddings to the input.
167
+ if shape_len == 3:
168
+ pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
169
+ return pos_embeds
170
+
171
+
172
+ class MySequential(nn.Sequential):
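+ # nn.Sequential variant that unpacks tuple inputs, so (x, size) pairs flow through every submodule.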
173
+ def forward(self, *inputs):
174
+ for module in self._modules.values():
175
+ if isinstance(inputs, tuple):
176
+ inputs = module(*inputs)
177
+ else:
178
+ inputs = module(inputs)
179
+ return inputs
180
+
181
+
182
+ class PreNorm(nn.Module):
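+ # Residual wrapper: applies the (optional) norm, the wrapped module, and the (optional) DropPath, then adds the shortcut, propagating the spatial size alongside x.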
183
+ def __init__(self, norm, fn, drop_path=None):
184
+ super().__init__()
185
+ self.norm = norm
186
+ self.fn = fn
187
+ self.drop_path = drop_path
188
+
189
+ def forward(self, x, *args, **kwargs):
190
+ shortcut = x
191
+ if self.norm != None:
192
+ x, size = self.fn(self.norm(x), *args, **kwargs)
193
+ else:
194
+ x, size = self.fn(x, *args, **kwargs)
195
+
196
+ if self.drop_path:
197
+ x = self.drop_path(x)
198
+
199
+ x = shortcut + x
200
+
201
+ return x, size
202
+
203
+
204
+ class Mlp(nn.Module):
205
+ def __init__(
206
+ self,
207
+ in_features,
208
+ hidden_features=None,
209
+ out_features=None,
210
+ act_layer=nn.GELU,
211
+ ):
212
+ super().__init__()
213
+ out_features = out_features or in_features
214
+ hidden_features = hidden_features or in_features
215
+ self.net = nn.Sequential(
216
+ OrderedDict(
217
+ [
218
+ ("fc1", nn.Linear(in_features, hidden_features)),
219
+ ("act", act_layer()),
220
+ ("fc2", nn.Linear(hidden_features, out_features)),
221
+ ]
222
+ )
223
+ )
224
+
225
+ def forward(self, x, size):
226
+ return self.net(x), size
227
+
228
+
229
+ class DepthWiseConv2d(nn.Module):
230
+ def __init__(
231
+ self,
232
+ dim_in,
233
+ kernel_size,
234
+ padding,
235
+ stride,
236
+ bias=True,
237
+ ):
238
+ super().__init__()
239
+ self.dw = nn.Conv2d(
240
+ dim_in,
241
+ dim_in,
242
+ kernel_size=kernel_size,
243
+ padding=padding,
244
+ groups=dim_in,
245
+ stride=stride,
246
+ bias=bias,
247
+ )
248
+
249
+ def forward(self, x, size):
250
+ B, N, C = x.shape
251
+ H, W = size
252
+ assert N == H * W
253
+
254
+ x = self.dw(x.transpose(1, 2).view(B, C, H, W))
255
+ size = (x.size(-2), x.size(-1))
256
+ x = x.flatten(2).transpose(1, 2)
257
+ return x, size
258
+
259
+
260
+ class ConvEmbed(nn.Module):
261
+ """Image to Patch Embedding"""
262
+
263
+ def __init__(
264
+ self,
265
+ patch_size=7,
266
+ in_chans=3,
267
+ embed_dim=64,
268
+ stride=4,
269
+ padding=2,
270
+ norm_layer=None,
271
+ pre_norm=True,
272
+ ):
273
+ super().__init__()
274
+ self.patch_size = patch_size
275
+
276
+ self.proj = nn.Conv2d(
277
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding
278
+ )
279
+
280
+ dim_norm = in_chans if pre_norm else embed_dim
281
+ self.norm = norm_layer(dim_norm) if norm_layer else None
282
+
283
+ self.pre_norm = pre_norm
284
+
285
+ def forward(self, x, size):
286
+ H, W = size
287
+ if len(x.size()) == 3:
288
+ if self.norm and self.pre_norm:
289
+ x = self.norm(x)
290
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
291
+
292
+ x = self.proj(x)
293
+
294
+ _, _, H, W = x.shape
295
+ x = rearrange(x, "b c h w -> b (h w) c")
296
+ if self.norm and not self.pre_norm:
297
+ x = self.norm(x)
298
+
299
+ return x, (H, W)
300
+
301
+
302
+ class ChannelAttention(nn.Module):
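+ # Channel group attention: the softmax affinity is computed between channel groups (a C/groups x C/groups matrix) rather than between spatial tokens, giving global context at cost linear in the token count.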
303
+
304
+ def __init__(self, dim, groups=8, qkv_bias=True):
305
+ super().__init__()
306
+
307
+ self.groups = groups
308
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
309
+ self.proj = nn.Linear(dim, dim)
310
+
311
+ def forward(self, x, size):
312
+ B, N, C = x.shape
313
+
314
+ qkv = (
315
+ self.qkv(x)
316
+ .reshape(B, N, 3, self.groups, C // self.groups)
317
+ .permute(2, 0, 3, 1, 4)
318
+ )
319
+ q, k, v = qkv[0], qkv[1], qkv[2]
320
+
321
+ q = q * (float(N) ** -0.5)
322
+ attention = q.transpose(-1, -2) @ k
323
+ attention = attention.softmax(dim=-1)
324
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
325
+ x = x.transpose(1, 2).reshape(B, N, C)
326
+ x = self.proj(x)
327
+ return x, size
328
+
329
+
330
+ class ChannelBlock(nn.Module):
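+ # ChannelBlock: optional depthwise conv (conv_at_attn) -> channel group attention -> optional depthwise conv (conv_at_ffn) -> MLP, each wrapped in PreNorm with a residual.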
331
+
332
+ def __init__(
333
+ self,
334
+ dim,
335
+ groups,
336
+ mlp_ratio=4.0,
337
+ qkv_bias=True,
338
+ drop_path_rate=0.0,
339
+ act_layer=nn.GELU,
340
+ norm_layer=nn.LayerNorm,
341
+ conv_at_attn=True,
342
+ conv_at_ffn=True,
343
+ ):
344
+ super().__init__()
345
+
346
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
347
+
348
+ self.conv1 = (
349
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
350
+ )
351
+ self.channel_attn = PreNorm(
352
+ norm_layer(dim),
353
+ ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
354
+ drop_path,
355
+ )
356
+ self.conv2 = (
357
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
358
+ )
359
+ self.ffn = PreNorm(
360
+ norm_layer(dim),
361
+ Mlp(
362
+ in_features=dim,
363
+ hidden_features=int(dim * mlp_ratio),
364
+ act_layer=act_layer,
365
+ ),
366
+ drop_path,
367
+ )
368
+
369
+ def forward(self, x, size):
370
+ if self.conv1:
371
+ x, size = self.conv1(x, size)
372
+ x, size = self.channel_attn(x, size)
373
+
374
+ if self.conv2:
375
+ x, size = self.conv2(x, size)
376
+ x, size = self.ffn(x, size)
377
+
378
+ return x, size
379
+
380
+
381
+ def window_partition(x, window_size: int):
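+ # Split a (B, H, W, C) feature map into non-overlapping window_size x window_size windows, flattened to (num_windows * B, window_size, window_size, C).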
382
+ B, H, W, C = x.shape
383
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
384
+ windows = (
385
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
386
+ )
387
+ return windows
388
+
389
+
390
+ def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
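+ # Inverse of window_partition: stitch (num_windows * B, window_size, window_size, C) windows back into a (B, H, W, C) feature map.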
391
+ B = batch_size
392
+ # deriving B from windows.shape[0] breaks ONNX export with dynamic axes, because the value is treated as a constant:
393
+ # int(windows.shape[0] / (H * W / window_size / window_size))
394
+ x = windows.view(
395
+ B, H // window_size, W // window_size, window_size, window_size, -1
396
+ )
397
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
398
+ return x
399
+
400
+
401
+ class WindowAttention(nn.Module):
402
+ def __init__(self, dim, num_heads, window_size, qkv_bias=True):
403
+
404
+ super().__init__()
405
+ self.dim = dim
406
+ self.window_size = window_size
407
+ self.num_heads = num_heads
408
+ head_dim = dim // num_heads
409
+ self.scale = float(head_dim) ** -0.5
410
+
411
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
412
+ self.proj = nn.Linear(dim, dim)
413
+
414
+ self.softmax = nn.Softmax(dim=-1)
415
+
416
+ def forward(self, x, size):
417
+
418
+ H, W = size
419
+ B, L, C = x.shape
420
+ assert L == H * W, "input feature has wrong size"
421
+
422
+ x = x.view(B, H, W, C)
423
+
424
+ pad_l = pad_t = 0
425
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
426
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
427
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
428
+ _, Hp, Wp, _ = x.shape
429
+
430
+ x = window_partition(x, self.window_size)
431
+ x = x.view(-1, self.window_size * self.window_size, C)
432
+
433
+ # W-MSA/SW-MSA
434
+ # attn_windows = self.attn(x_windows)
435
+
436
+ B_, N, C = x.shape
437
+ qkv = (
438
+ self.qkv(x)
439
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
440
+ .permute(2, 0, 3, 1, 4)
441
+ )
442
+ q, k, v = qkv[0], qkv[1], qkv[2]
443
+
444
+ q = q * self.scale
445
+ attn = q @ k.transpose(-2, -1)
446
+ attn = self.softmax(attn)
447
+
448
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
449
+ x = self.proj(x)
450
+
451
+ # merge windows
452
+ x = x.view(-1, self.window_size, self.window_size, C)
453
+ x = window_reverse(x, B, self.window_size, Hp, Wp)
454
+
455
+ if pad_r > 0 or pad_b > 0:
456
+ x = x[:, :H, :W, :].contiguous()
457
+
458
+ x = x.view(B, H * W, C)
459
+
460
+ return x, size
461
+
462
+
463
+ class SpatialBlock(nn.Module):
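+ # SpatialBlock: optional depthwise conv -> window attention within local windows -> optional depthwise conv -> MLP, each wrapped in PreNorm with a residual.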
464
+
465
+ def __init__(
466
+ self,
467
+ dim,
468
+ num_heads,
469
+ window_size,
470
+ mlp_ratio=4.0,
471
+ qkv_bias=True,
472
+ drop_path_rate=0.0,
473
+ act_layer=nn.GELU,
474
+ norm_layer=nn.LayerNorm,
475
+ conv_at_attn=True,
476
+ conv_at_ffn=True,
477
+ ):
478
+ super().__init__()
479
+
480
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
481
+
482
+ self.conv1 = (
483
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
484
+ )
485
+ self.window_attn = PreNorm(
486
+ norm_layer(dim),
487
+ WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
488
+ drop_path,
489
+ )
490
+ self.conv2 = (
491
+ PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
492
+ )
493
+ self.ffn = PreNorm(
494
+ norm_layer(dim),
495
+ Mlp(
496
+ in_features=dim,
497
+ hidden_features=int(dim * mlp_ratio),
498
+ act_layer=act_layer,
499
+ ),
500
+ drop_path,
501
+ )
502
+
503
+ def forward(self, x, size):
504
+ if self.conv1:
505
+ x, size = self.conv1(x, size)
506
+ x, size = self.window_attn(x, size)
507
+
508
+ if self.conv2:
509
+ x, size = self.conv2(x, size)
510
+ x, size = self.ffn(x, size)
511
+ return x, size
512
+
513
+
514
+ # Define DaViT model class
515
+ class DaViTModel(PreTrainedModel):
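+ # DaViT backbone: four stages of ConvEmbed downsampling, each followed by alternating SpatialBlock/ChannelBlock pairs (dual attention).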
516
+ config_class = DaViTConfig
517
+
518
+ def __init__(self, config: DaViTConfig):
519
+ super().__init__(config)
520
+
521
+ self.num_classes = 1000 # config.num_classes
522
+ self.embed_dims = config.embed_dims
523
+ self.num_heads = config.num_heads
524
+ self.num_groups = config.num_groups
525
+ self.num_stages = len(self.embed_dims)
526
+ self.enable_checkpoint = config.enable_checkpoint
527
+ assert self.num_stages == len(self.num_heads) == len(self.num_groups)
528
+
529
+ num_stages = len(config.embed_dims)
530
+ dpr = [
531
+ x.item()
532
+ for x in torch.linspace(0, config.drop_path_rate, sum(config.depths) * 2)
533
+ ]
534
+
535
+ depth_offset = 0
536
+ convs = []
537
+ blocks = []
538
+ for i in range(num_stages):
539
+ conv_embed = ConvEmbed(
540
+ patch_size=config.patch_size[i],
541
+ stride=config.patch_stride[i],
542
+ padding=config.patch_padding[i],
543
+ in_chans=config.in_chans if i == 0 else self.embed_dims[i - 1],
544
+ embed_dim=self.embed_dims[i],
545
+ norm_layer=(
546
+ nn.LayerNorm
547
+ if config.norm_layer == "layer_norm"
548
+ else nn.BatchNorm2d
549
+ ),
550
+ pre_norm=config.patch_prenorm[i],
551
+ )
552
+ convs.append(conv_embed)
553
+
554
+ block = MySequential(
555
+ *[
556
+ MySequential(
557
+ OrderedDict(
558
+ [
559
+ (
560
+ "spatial_block",
561
+ SpatialBlock(
562
+ self.embed_dims[i],
563
+ self.num_heads[i],
564
+ config.window_size,
565
+ drop_path_rate=dpr[depth_offset + j * 2],
566
+ qkv_bias=config.qkv_bias,
567
+ mlp_ratio=config.mlp_ratio,
568
+ conv_at_attn=config.conv_at_attn,
569
+ conv_at_ffn=config.conv_at_ffn,
570
+ ),
571
+ ),
572
+ (
573
+ "channel_block",
574
+ ChannelBlock(
575
+ self.embed_dims[i],
576
+ self.num_groups[i],
577
+ drop_path_rate=dpr[depth_offset + j * 2 + 1],
578
+ qkv_bias=config.qkv_bias,
579
+ mlp_ratio=config.mlp_ratio,
580
+ conv_at_attn=config.conv_at_attn,
581
+ conv_at_ffn=config.conv_at_ffn,
582
+ ),
583
+ ),
584
+ ]
585
+ )
586
+ )
587
+ for j in range(config.depths[i])
588
+ ]
589
+ )
590
+ blocks.append(block)
591
+ depth_offset += config.depths[i] * 2
592
+
593
+ self.convs = nn.ModuleList(convs)
594
+ self.blocks = nn.ModuleList(blocks)
595
+
596
+ self.norms = (
597
+ nn.LayerNorm(self.embed_dims[-1])
598
+ if config.norm_layer == "layer_norm"
599
+ else nn.BatchNorm2d(self.embed_dims[-1])
600
+ )
601
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
602
+ self.head = (
603
+ nn.Linear(self.embed_dims[-1], self.num_classes)
604
+ if self.num_classes > 0
605
+ else nn.Identity()
606
+ )
607
+
608
+ self.apply(self._init_weights)
609
+
610
+ def _init_weights(self, m):
611
+ if isinstance(m, nn.Linear):
612
+ trunc_normal_(m.weight, std=0.02)
613
+ if m.bias is not None:
614
+ nn.init.constant_(m.bias, 0)
615
+ elif isinstance(m, nn.Conv2d):
616
+ nn.init.normal_(m.weight, std=0.02)
617
+ for name, _ in m.named_parameters():
618
+ if name in ["bias"]:
619
+ nn.init.constant_(m.bias, 0)
620
+ elif isinstance(m, nn.LayerNorm):
621
+ nn.init.constant_(m.weight, 1.0)
622
+ nn.init.constant_(m.bias, 0)
623
+ elif isinstance(m, nn.BatchNorm2d):
624
+ nn.init.constant_(m.weight, 1.0)
625
+ nn.init.constant_(m.bias, 0)
626
+
627
+ def forward_features_unpool(self, x):
628
+ """
629
+ forward until avg pooling
630
+ Args:
631
+ x (_type_): input image tensor
632
+ """
633
+ input_size = (x.size(2), x.size(3))
634
+ for conv, block in zip(self.convs, self.blocks):
635
+ x, input_size = conv(x, input_size)
636
+ if self.enable_checkpoint:
637
+ x, input_size = checkpoint.checkpoint(block, x, input_size)
638
+ else:
639
+ x, input_size = block(x, input_size)
640
+ return x
641
+
642
+ def forward_features(self, x):
643
+ x = self.forward_features_unpool(x)
644
+
645
+ # (batch_size, num_tokens, token_dim)
646
+ x = self.avgpool(x.transpose(1, 2))
647
+ # (batch_size, token_dim, 1)
648
+ x = torch.flatten(x, 1)
649
+ x = self.norms(x)
650
+
651
+ return x
652
+
653
+ def forward(self, x):
654
+ x = self.forward_features(x)
655
+ x = self.head(x)
656
+ return x
657
+
658
+
659
+ # Register the configuration and model
660
+ AutoConfig.register("davit", DaViTConfig)
661
+ AutoModel.register(DaViTConfig, DaViTModel)