anicolson committed
Commit: f1ed0bf
1 Parent(s): d7a0b0e

Upload model

Files changed (4):
  1. README.md +199 -0
  2. config.json +54 -0
  3. model.safetensors +3 -0
  4. modelling_uniformer.py +412 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
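+ In the meantime, a minimal sketch of loading this checkpoint with 🤗 Transformers. The repository id `anicolson/uniformer_base_tl_384` is a placeholder; substitute the actual Hub id of this repository. The config's `auto_map` entry routes `AutoModel` to `modelling_uniformer.UniFormerModel`, so `trust_remote_code=True` is required, and `timm` must be installed because the custom module imports it.
+
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ # Placeholder repository id -- replace with the actual Hub repository.
+ model = AutoModel.from_pretrained(
+     "anicolson/uniformer_base_tl_384",
+     trust_remote_code=True,  # resolves AutoModel to modelling_uniformer.UniFormerModel
+ )
+ model.eval()
+
+ # config.json specifies 384x384 inputs with 3 channels.
+ pixel_values = torch.randn(1, 3, 384, 384)
+ with torch.no_grad():
+     outputs = model(pixel_values)
+
+ # Spatial positions are flattened: 384 / (4*2*2*2) = 12 per side, i.e. 144 tokens of width 512.
+ print(outputs.last_hidden_state.shape)  # torch.Size([1, 144, 512])
+ ```
+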
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,54 @@
+ {
+   "_name_or_path": "/datasets/work/hb-mlaifsp-mm/work/checkpoints/uniformer_base_tl_384",
+   "architectures": [
+     "UniFormerModel"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "attn_drop_rate": 0.0,
+   "auto_map": {
+     "AutoModel": "modelling_uniformer.UniFormerModel"
+   },
+   "conv_stem": false,
+   "depth": [
+     5,
+     8,
+     20,
+     7
+   ],
+   "drop_path_rate": 0.3,
+   "drop_rate": 0.0,
+   "embed_dim": [
+     64,
+     128,
+     320,
+     512
+   ],
+   "encoder_stride": 16,
+   "head_dim": 64,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 768,
+   "image_size": 384,
+   "in_chans": 3,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-06,
+   "mlp_ratio": 4,
+   "model_type": "vit",
+   "num_attention_heads": 12,
+   "num_channels": 3,
+   "num_classes": 1000,
+   "num_hidden_layers": 12,
+   "patch_size": [
+     4,
+     2,
+     2,
+     2
+   ],
+   "projection_size": null,
+   "qk_scale": null,
+   "qkv_bias": true,
+   "representation_size": null,
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.2"
+ }
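
The `auto_map` entry above is what lets `AutoModel` resolve to the custom `UniFormerModel` class when `trust_remote_code=True` is passed. The encoder's output geometry follows directly from `image_size`, `patch_size`, and `embed_dim`; a small sketch of that arithmetic, using only the values above:

```python
from math import prod

image_size = 384                 # from config.json
patch_size = [4, 2, 2, 2]        # per-stage downsampling factors
embed_dim = [64, 128, 320, 512]  # per-stage channel widths

total_stride = prod(patch_size)               # 4 * 2 * 2 * 2 = 32
tokens_per_side = image_size // total_stride  # 384 // 32 = 12
num_tokens = tokens_per_side ** 2             # 144 spatial tokens
print(num_tokens, embed_dim[-1])              # 144 512 -> last_hidden_state is (batch, 144, 512)
```
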
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0571c333da6cbecab2e8f03af5ca2b87608c1a5e269fbb5b070e175f02e6fd34
+ size 197150032
modelling_uniformer.py ADDED
@@ -0,0 +1,412 @@
+ from collections import OrderedDict
+ from functools import partial
+ from typing import Optional, Tuple, Union
+ from math import isqrt
+
+ import torch
+ import torch.nn as nn
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+ from transformers import ViTConfig
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ layer_scale = False
+ init_value = 1e-6
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class CMlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class CBlock(nn.Module):
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = nn.BatchNorm2d(dim)
+         self.conv1 = nn.Conv2d(dim, dim, 1)
+         self.conv2 = nn.Conv2d(dim, dim, 1)
+         self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = nn.BatchNorm2d(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         x = x + self.module_1(x)
+         x = x + self.module_2(x)
+         return x
+
+     def module_1(self, x):
+         x = self.norm1(x.to(dtype=self.norm1.weight.dtype))  # Won't autocast to the dtype of the parameters of nn.BatchNorm2d.
+         x = self.conv1(x)
+         x = self.attn(x)
+         x = self.conv2(x)
+         x = self.drop_path(x)
+         return x
+
+     def module_2(self, x):
+         x = self.norm2(x.to(dtype=self.norm2.weight.dtype))  # Won't autocast to the dtype of the parameters of nn.BatchNorm2d.
+         x = self.mlp(x)
+         x = self.drop_path(x)
+         return x
+
+ class SABlock(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+         global layer_scale
+         self.ls = layer_scale
+         if self.ls:
+             global init_value
+             print(f"Use layer_scale: {layer_scale}, init_values: {init_value}")
+             self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
+             self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
+
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         B, N, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         if self.ls:
+             x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         else:
+             x = x + self.drop_path(self.attn(self.norm1(x)))
+             x = x + self.drop_path(self.mlp(self.norm2(x)))
+         x = x.transpose(1, 2).reshape(B, N, H, W)
+         return x
+
+
+ class HeadEmbedding(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(HeadEmbedding, self).__init__()
+
+         self.proj = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels // 2),
+             nn.GELU(),
+             nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels),
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class MiddleEmbedding(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(MiddleEmbedding, self).__init__()
+
+         self.proj = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels),
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, image_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         image_size = to_2tuple(image_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches_height = image_size[0] // patch_size[0]
+         num_patches_width = image_size[1] // patch_size[1]
+         num_patches = num_patches_height * num_patches_width
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         _, _, H, W = x.shape
+         assert H == self.image_size[0] and W == self.image_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+         x = self.proj(x)
+         B, _, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         x = self.norm(x)
+         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+         return x
+
+
+ class UniFormer(nn.Module):
+     def __init__(self, depth=[3, 4, 8, 3], image_size=224, in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
+                  head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, patch_size=[4, 2, 2, 2],
+                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., conv_stem=False, layer_norm_eps=1e-6, **kwargs):
+         super().__init__()
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         norm_layer = partial(nn.LayerNorm, eps=layer_norm_eps)
+         if conv_stem:
+             self.patch_embed1 = HeadEmbedding(in_channels=in_chans, out_channels=embed_dim[0])
+             self.patch_embed2 = MiddleEmbedding(in_channels=embed_dim[0], out_channels=embed_dim[1])
+             self.patch_embed3 = MiddleEmbedding(in_channels=embed_dim[1], out_channels=embed_dim[2])
+             self.patch_embed4 = MiddleEmbedding(in_channels=embed_dim[2], out_channels=embed_dim[3])
+         else:
+             self.patch_embed1 = PatchEmbed(
+                 image_size=image_size, patch_size=patch_size[0], in_chans=in_chans, embed_dim=embed_dim[0])
+             self.patch_embed2 = PatchEmbed(
+                 image_size=image_size // patch_size[0], patch_size=patch_size[1], in_chans=embed_dim[0], embed_dim=embed_dim[1])
+             self.patch_embed3 = PatchEmbed(
+                 image_size=image_size // (patch_size[0]*patch_size[1]), patch_size=patch_size[2], in_chans=embed_dim[1], embed_dim=embed_dim[2])
+             self.patch_embed4 = PatchEmbed(
+                 image_size=image_size // (patch_size[0]*patch_size[1]*patch_size[2]), patch_size=patch_size[3], in_chans=embed_dim[2], embed_dim=embed_dim[3])
+
+         self.pos_drop = nn.Dropout(p=drop_rate)
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # stochastic depth decay rule
+         num_heads = [dim // head_dim for dim in embed_dim]
+         self.blocks1 = nn.ModuleList([
+             CBlock(dim=embed_dim[0], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i])
+             for i in range(depth[0])])
+         self.blocks2 = nn.ModuleList([
+             CBlock(dim=embed_dim[1], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i+depth[0]])
+             for i in range(depth[1])])
+         self.blocks3 = nn.ModuleList([
+             SABlock(
+                 dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)
+             for i in range(depth[2])])
+         self.blocks4 = nn.ModuleList([
+             SABlock(
+                 dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)
+             for i in range(depth[3])])
+         self.norm = nn.BatchNorm2d(embed_dim[-1])
+
+         # Representation layer
+         if representation_size:
+             self.num_features = representation_size
+             self.pre_logits = nn.Sequential(OrderedDict([
+                 ('fc', nn.Linear(embed_dim[-1], representation_size)),  # embed_dim is a per-stage list; project from the final stage width.
+                 ('act', nn.Tanh())
+             ]))
+         else:
+             self.pre_logits = nn.Identity()
+
+     def forward_features(self, x):
+         x = self.patch_embed1(x)
+         x = self.pos_drop(x)
+         for blk in self.blocks1:
+             x = blk(x)
+         x = self.patch_embed2(x)
+         for blk in self.blocks2:
+             x = blk(x)
+         x = self.patch_embed3(x)
+         for blk in self.blocks3:
+             x = blk(x)
+         x = self.patch_embed4(x)
+         for blk in self.blocks4:
+             x = blk(x)
+         x = self.norm(x.to(dtype=self.norm.weight.dtype))  # Won't autocast to the dtype of the parameters of nn.BatchNorm2d.
+         x = self.pre_logits(x)
+         return x
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         return x
+
+
+ class UniFormerPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = ViTConfig
+     base_model_prefix = "vit"
+     main_input_name = "pixel_values"
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+
+ class UniFormerProjectionHead(torch.nn.Module):
+
+     def __init__(self, config) -> None:
+         super().__init__()
+
+         # Layer normalisation before projection:
+         self.layer_norm = torch.nn.LayerNorm(config.embed_dim[-1], eps=config.layer_norm_eps)
+
+         # No bias as following layer normalisation with bias:
+         self.projection = torch.nn.Linear(config.embed_dim[-1], config.projection_size, bias=False)
+
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.layer_norm(x)
+         x = self.projection(x)
+         return x
+
+
+ class UniFormerModel(UniFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.uniformer = UniFormer(**vars(config))
+
+         # Initialize weights and apply final processing:
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, ModelOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         last_hidden_state = self.uniformer(pixel_values)
+
+         # Flatten h x w:
+         last_hidden_state = torch.flatten(last_hidden_state, 2)
+
+         # Permute last hidden state:
+         last_hidden_state = torch.permute(last_hidden_state, [0, 2, 1])
+
+         if not return_dict:
+             return last_hidden_state
+
+         return ModelOutput(last_hidden_state=last_hidden_state)
+
+
+ class MultiUniFormerWithProjectionHead(UniFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.uniformer = UniFormer(**vars(config))
+         self.projection_head = UniFormerProjectionHead(config)
+
+         # Initialize weights and apply final processing:
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, ModelOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Flatten the batch and study_id dimensions:
+         assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
+         last_hidden_state = self.uniformer(pixel_values.view(-1, *pixel_values.shape[2:]))
+         # last_hidden_state = self.uniformer(pixel_values.flatten(start_dim=0, end_dim=1))
+
+         # Flatten h x w:
+         last_hidden_state = torch.flatten(last_hidden_state, 2)
+
+         # Project the features for each spatial position to the decoder's hidden size:
+         projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
+
+         # Concatenate the features for each chest X-ray:
+         projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])
+
+         # Derive the attention mask from the pixel values:
+         mask = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]
+         attention_mask = torch.ones(
+             [projection.shape[0], pixel_values.shape[1], projection.shape[1] // pixel_values.shape[1]],
+             dtype=torch.long,
+             device=mask.device,
+         )
+         attention_mask = attention_mask * mask
+         attention_mask = attention_mask.view(attention_mask.shape[0], -1)
+
+         if not return_dict:
+             return projection
+
+         return ModelOutput(last_hidden_state=projection, attention_mask=attention_mask)
+
+
+ if __name__ == '__main__':
+     y = PatchEmbed()
+     y(torch.randn(2, 3, 224, 224))
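
`MultiUniFormerWithProjectionHead` expects `pixel_values` of shape (B, S, C, H, W) and masks out zero-padded images. A minimal sketch of exercising it directly from this module, assuming the file is importable as `modelling_uniformer`; the configuration values below (224-pixel inputs, depth `[3, 4, 8, 3]`, `projection_size=768`) are illustrative rather than the shipped checkpoint's config:

```python
import torch
from transformers import ViTConfig
from modelling_uniformer import MultiUniFormerWithProjectionHead

# Illustrative config; the Hub checkpoint instead supplies image_size=384,
# depth=[5, 8, 20, 7], drop_path_rate=0.3, etc. via its config.json.
config = ViTConfig(
    image_size=224,
    layer_norm_eps=1e-6,
    embed_dim=[64, 128, 320, 512],  # per-stage widths (stored as extra config attributes)
    depth=[3, 4, 8, 3],
    patch_size=[4, 2, 2, 2],
    head_dim=64,
    projection_size=768,            # hypothetical decoder hidden size
)
model = MultiUniFormerWithProjectionHead(config).eval()

# Two studies, padded to three images each; the last image of the second
# study is all zeros, so the derived attention mask zeroes out its tokens.
pixel_values = torch.randn(2, 3, 3, 224, 224)
pixel_values[1, 2] = 0.0

with torch.no_grad():
    out = model(pixel_values)

# 224 / (4*2*2*2) = 7 per side -> 49 tokens per image, 3 images -> 147 tokens per study.
print(out.last_hidden_state.shape)  # torch.Size([2, 147, 768])
print(out.attention_mask.shape)     # torch.Size([2, 147])
```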