Spaces:
Runtime error
Initial commit
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- app.py +81 -0
- configs/chunk000.yaml +89 -0
- configs/chunkseq003.yaml +67 -0
- configs/pseudoseg000.yaml +110 -0
- examples/1.2.826.0.1.3680043.15773.nii.gz +3 -0
- packages.txt +1 -0
- requirements.txt +7 -0
- seg.ckpt +3 -0
- seq.ckpt +3 -0
- skp/.DS_Store +0 -0
- skp/__init__.py +0 -0
- skp/__pycache__/__init__.cpython-39.pyc +0 -0
- skp/__pycache__/builder.cpython-39.pyc +0 -0
- skp/builder.py +187 -0
- skp/models/__init__.py +1 -0
- skp/models/__pycache__/__init__.cpython-39.pyc +0 -0
- skp/models/__pycache__/backbones.cpython-39.pyc +0 -0
- skp/models/__pycache__/engine.cpython-39.pyc +0 -0
- skp/models/__pycache__/sequence.cpython-39.pyc +0 -0
- skp/models/__pycache__/tools.cpython-39.pyc +0 -0
- skp/models/backbones.py +114 -0
- skp/models/engine.py +257 -0
- skp/models/pooling/__init__.py +3 -0
- skp/models/pooling/__pycache__/__init__.cpython-39.pyc +0 -0
- skp/models/pooling/__pycache__/gem.cpython-39.pyc +0 -0
- skp/models/pooling/__pycache__/pool1d.cpython-39.pyc +0 -0
- skp/models/pooling/__pycache__/pool2d.cpython-39.pyc +0 -0
- skp/models/pooling/__pycache__/pool3d.cpython-39.pyc +0 -0
- skp/models/pooling/gem.py +35 -0
- skp/models/pooling/pool1d.py +107 -0
- skp/models/pooling/pool2d.py +16 -0
- skp/models/pooling/pool3d.py +107 -0
- skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml +109 -0
- skp/models/rev_mvit/__init__.py +0 -0
- skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/common.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc +0 -0
- skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc +0 -0
- skp/models/rev_mvit/attention.py +568 -0
- skp/models/rev_mvit/batchnorm_helper.py +112 -0
- skp/models/rev_mvit/common.py +154 -0
- skp/models/rev_mvit/head_helper.py +140 -0
- skp/models/rev_mvit/reversible_mvit.py +696 -0
- skp/models/rev_mvit/stem_helper.py +325 -0
- skp/models/rev_mvit/utils.py +221 -0
app.py
ADDED
@@ -0,0 +1,81 @@
import cv2
import glob
import gradio as gr
import mediapy
import nibabel
import numpy as np
import shutil
import torch
import torch.nn.functional as F

from omegaconf import OmegaConf
from skp import builder


def window(x, WL=400, WW=2500):
    lower, upper = WL - WW // 2, WL + WW // 2
    x = np.clip(x, lower, upper)
    x = x - lower
    x = x / (upper - lower)
    return (x * 255).astype("uint8")


def rescale(x):
    x = x / 255.
    x = x - 0.5
    x = x * 2.0
    return x


def generate_segmentation_video(study):
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)

    X = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
    X = F.interpolate(X, size=(192, 192, 192), mode="nearest")
    X = rescale(X)
    with torch.no_grad():
        seg_output = seg_model(X)

    seg_output = torch.sigmoid(seg_output)
    p_spine = seg_output[:, :7].sum(1)
    seg_output = torch.argmax(seg_output, dim=1) + 1
    seg_output[p_spine < 0.5] = 0
    seg_output = F.interpolate(seg_output.unsqueeze(0).float(), size=img.shape, mode="nearest")
    seg_output = seg_output.squeeze(0).squeeze(0).numpy()
    seg_output = (seg_output * 255 / 7).astype("uint8")
    seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output])

    frames = []
    skip = 8
    for idx in range(0, img.shape[2], skip):
        i = img[:, :, idx]
        o = seg_output[:, :, idx]
        i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
        frame = np.concatenate((i, o), 1)
        frames.append(frame)
    mediapy.write_video("video.mp4", frames, fps=30)
    return "video.mp4"


ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)

config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"
seg_model = builder.build_model(config).eval()
examples = glob.glob("examples/*.nii.gz")

with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
    button_predict = gr.Button("Predict")
    video_output = gr.Video()
    button_predict.click(fn=generate_segmentation_video,
                         inputs=select_study,
                         outputs=video_output)


if __name__ == "__main__":
    demo.launch(debug=True, share=True)
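For intuition on the preprocessing above: window() clips intensities to the [WL - WW/2, WL + WW/2] = [-850, 1650] range and rescales to uint8, and rescale() maps that back to [-1, 1] for the model. A standalone numeric check (simplified copies of the two helpers, not part of the commit):

import numpy as np

def window(x, WL=400, WW=2500):
    # clip to [WL - WW/2, WL + WW/2] = [-850, 1650], then scale into uint8 [0, 255]
    lower, upper = WL - WW // 2, WL + WW // 2
    x = np.clip(x, lower, upper)
    return ((x - lower) / (upper - lower) * 255).astype("uint8")

def rescale(x):
    # uint8 [0, 255] -> float [-1, 1]
    return (x / 255. - 0.5) * 2.0

vol = np.array([-1000., 0., 400., 2000.])   # example intensities: air, water, bone-ish, clipped
print(window(vol))                           # [  0  86 127 255]
print(rescale(window(vol)))                  # approx [-1.0, -0.33, -0.0, 1.0]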
configs/chunk000.yaml
ADDED
@@ -0,0 +1,89 @@
experiment:
  seed: 88
  save_dir: ../experiments/


data:
  annotations: ../data/train_vertebra_chunks_kfold.csv
  data_dir: ../data/train-numpy-vertebra-chunks
  input: filename
  target: fracture
  outer_fold: 0
  dataset:
    name: NumpyChunkDataset
    params:
      flip: true
      invert: false
      channels: grayscale
      z_lt: resample_resample
      z_gt: resample_resample
      num_images: 64


transform:
  resize:
    name: resize_ignore_3d
    params:
      imsize: [64, 288, 288]
  augment:
    null
  crop:
    null
  preprocess:
    name: Preprocessor
    params:
      image_range: [0, 255]
      input_range: [0, 1]
      mean: [0.5]
      sdev: [0.5]


task:
  name: ClassificationTask
  params:


model:
  name: Net3D
  params:
    backbone: x3d_l
    backbone_params:
      z_strides: [1, 1, 1, 1, 1]
    pretrained: true
    num_classes: 1
    dropout: 0.2
    pool: avg
    in_channels: 1
    multisample_dropout: true


loss:
  name: BCEWithLogitsLoss
  params:


optimizer:
  name: AdamW
  params:
    lr: 3.0e-4
    weight_decay: 5.0e-4


scheduler:
  name: CosineAnnealingLR
  params:
    final_lr: 0.0


train:
  batch_size: 4
  num_epochs: 10


evaluate:
  metrics: [AUROC]
  monitor: auc_mean
  mode: max
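Configs like this one are consumed through OmegaConf by skp/builder.py (further down in this commit). A small sketch of how the sections map onto get_name_and_params:

# Sketch: reading a config section the way builder.get_name_and_params does.
# Paths and values mirror configs/chunk000.yaml above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/chunk000.yaml")
print(cfg.model.name)                                   # Net3D
print(cfg.model.params.backbone)                        # x3d_l
print(cfg.optimizer.name, dict(cfg.optimizer.params))   # AdamW {'lr': 0.0003, 'weight_decay': 0.0005}

# Empty `params:` blocks (e.g. under task/loss) load as None, which is why
# get_name_and_params falls back to `or {}`.
name = getattr(cfg.loss, "name")
params = getattr(cfg.loss, "params") or {}
print(name, params)                                     # BCEWithLogitsLoss {}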
configs/chunkseq003.yaml
ADDED
@@ -0,0 +1,67 @@
experiment:
  seed: 88
  save_dir: ../experiments/


data:
  annotations: ../data/train_chunk_features_kfold.csv
  data_dir: ../data/train-chunk000-features/foldx
  input: filename
  target: [C1, C2, C3, C4, C5, C6, C7, patient_overall]
  outer_fold: 0
  dataset:
    name: FeatureDataset
    params:
      seq_len: 7
      reverse: false
      normalize: false
      exam_level_label: true


task:
  name: ClassificationTask
  params:


model:
  name: DualTransformer
  params:
    num_classes: 1
    embedding_dim: 432
    hidden_dim: 864
    n_layers: 3
    n_heads: 16


loss:
  name: MultilabelWeightedBCE
  params:
    weights: [1, 1, 1, 1, 1, 1, 1, 7]
    pos_weight: 2.0


optimizer:
  name: AdamW
  params:
    lr: 1.0e-5
    weight_decay: 5.0e-4


scheduler:
  name: CosineAnnealingLR
  params:
    final_lr: 0


train:
  batch_size: 32
  num_epochs: 25


evaluate:
  batch_size: 1
  metrics: [CompetitionMetric, AUROC]
  monitor: comp_metric
  mode: min
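The target list above pairs seven per-vertebra labels (C1-C7) with an exam-level patient_overall label, and the loss weights [1, 1, 1, 1, 1, 1, 1, 7] up-weight that last label. MultilabelWeightedBCE itself is not among the files in this Space, so the following is only a hypothetical sketch of a loss with that behavior; the class name, reduction, and exact weighting scheme are assumptions:

import torch
import torch.nn as nn

class MultilabelWeightedBCESketch(nn.Module):
    """Hypothetical stand-in for the MultilabelWeightedBCE referenced above
    (the actual skp loss module is not part of this commit's files)."""
    def __init__(self, weights, pos_weight=1.0):
        super().__init__()
        self.register_buffer("weights", torch.tensor(weights, dtype=torch.float))
        self.bce = nn.BCEWithLogitsLoss(reduction="none",
                                        pos_weight=torch.tensor(pos_weight))

    def forward(self, logits, targets):
        # per-label BCE, then a weighted mean so patient_overall (weight 7) dominates
        loss = self.bce(logits, targets.float())                    # (N, 8)
        return (loss * self.weights).sum(1) / self.weights.sum()    # (N,)

loss_fn = MultilabelWeightedBCESketch(weights=[1, 1, 1, 1, 1, 1, 1, 7], pos_weight=2.0)
logits = torch.randn(4, 8)
targets = torch.randint(0, 2, (4, 8))
print(loss_fn(logits, targets).mean())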
configs/pseudoseg000.yaml
ADDED
@@ -0,0 +1,110 @@
experiment:
  seed: 88
  save_dir: ../experiments/


data:
  annotations: ../data/train_seg_whole_192_kfold_with_pseudo.csv
  data_dir: ../data/
  input: filename
  target: label
  outer_fold: 0
  dataset:
    name: NumpyChunkSegmentDataset
    params:
      segmentation_format: numpy
      channels: grayscale
      flip: true
      transpose: true
      invert: false
      verbose: true
      num_images: 192
      z_lt: resample_resample
      z_gt: resample_resample
      one_hot_encode: true
      num_classes: 8
      add_foreground_channel: false


transform:
  resize:
    name: resize_ignore_3d
    params:
      imsize: [192, 192, 192]
  augment:
    null
  crop:
    null
  preprocess:
    name: Preprocessor
    params:
      image_range: [0, 255]
      input_range: [0, 1]
      mean: [0.5]
      sdev: [0.5]


task:
  name: SegmentationTask3D
  params:
    chunk_validation: true


model:
  name: NetSegment3D
  params:
    architecture: DeepLabV3Plus_3D
    encoder_name: x3d_l
    encoder_params:
      pretrained: true
      output_stride: 16
      z_strides: [2, 2, 2, 2, 2]
    decoder_params:
      upsampling: 4
    deep_supervision: true
    num_classes: 8
    in_channels: 1
    dropout: 0.2


loss:
  name: SupervisorLoss
  params:
    segmentation_loss: DiceBCELoss
    scale_factors: [0.25, 0.25]
    loss_weights: [1.0, 0.25, 0.25]
    loss_params:
      dice_loss_params:
        mode: multilabel
        exponent: 2
        smooth: 1.0
      bce_loss_params:
        smooth_factor: 0.01
        pos_weight: 1.0
      dice_loss_weight: 1.0
      bce_loss_weight: 0.2


optimizer:
  name: AdamW
  params:
    lr: 3.0e-4
    weight_decay: 5.0e-4


scheduler:
  name: CosineAnnealingLR
  params:
    final_lr: 0.0


train:
  batch_size: 4
  num_epochs: 10


evaluate:
  batch_size: 1
  metrics: [DSC]
  monitor: dsc_ignore_mean
  mode: max
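This is the config app.py loads at startup; the only runtime override is the checkpoint path, after which builder.build_model (defined in skp/builder.py below) instantiates NetSegment3D and loads the seg.ckpt weights:

# How the demo ties this config to the checkpoint (mirrors app.py above).
from omegaconf import OmegaConf
from skp import builder

config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"   # checkpoint committed alongside this config
seg_model = builder.build_model(config).eval()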
examples/1.2.826.0.1.3680043.15773.nii.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a316d4cdb9534c662a209dea2b50fd57168398b1a658d14937ce285d3b792917
size 65868417
packages.txt
ADDED
@@ -0,0 +1 @@
ffmpeg
requirements.txt
ADDED
@@ -0,0 +1,7 @@
omegaconf
mediapy
nibabel
opencv-python
timm
torch
transformers
seg.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa6ee0036af98df68621b5cacbed9b4cd290eb1b59c6af7785a7e9c81ed74afa
size 21569386
seq.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:653637b500e3ae5ffab8b07f34d36662396ab3eacec8024e5ecea952d7c2c07e
size 18011334
skp/.DS_Store
ADDED
Binary file (6.15 kB).
skp/__init__.py
ADDED
File without changes
skp/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (169 Bytes).
skp/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (5.29 kB).
skp/builder.py
ADDED
@@ -0,0 +1,187 @@
import numpy as np
import torch

from . import models


def get_name_and_params(base):
    name = getattr(base, 'name')
    params = getattr(base, 'params') or {}
    return name, params


def get_transform(base, transform, mode=None):
    if not base: return None
    transform = getattr(base, transform)
    if not transform: return None
    name, params = get_name_and_params(transform)
    if mode:
        params.update({'mode': mode})
    return getattr(data.transforms, name)(**params)


def build_transforms(cfg, mode):
    # 1-Resize
    resizer = get_transform(cfg.transform, 'resize')
    # 2-(Optional) Data augmentation
    augmenter = None
    if mode == "train":
        augmenter = get_transform(cfg.transform, 'augment')
    # 3-(Optional) Crop
    cropper = get_transform(cfg.transform, 'crop', mode=mode)
    # 4-Preprocess
    preprocessor = get_transform(cfg.transform, 'preprocess')
    return {
        'resize': resizer,
        'augment': augmenter,
        'crop': cropper,
        'preprocess': preprocessor
    }


def build_dataset(cfg, data_info, mode):
    dataset_class = getattr(data.datasets, cfg.data.dataset.name)
    dataset_params = cfg.data.dataset.params
    dataset_params.test_mode = mode != 'train'
    dataset_params = dict(dataset_params)
    if "FeatureDataset" not in cfg.data.dataset.name:
        transforms = build_transforms(cfg, mode)
        dataset_params.update(transforms)
    dataset_params.update(data_info)
    return dataset_class(**dataset_params)


def build_dataloader(cfg, dataset, mode):

    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    dataloader_params = {}
    dataloader_params['num_workers'] = cfg.data.num_workers
    dataloader_params['drop_last'] = mode == 'train'
    dataloader_params['shuffle'] = mode == 'train'
    dataloader_params["pin_memory"] = cfg.data.get("pin_memory", True)
    if mode in ('train', 'valid'):
        if mode == "train":
            dataloader_params['batch_size'] = cfg.train.batch_size
        elif mode == "valid":
            dataloader_params["batch_size"] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
        sampler = None
        if cfg.data.get("sampler") and mode == 'train':
            name, params = get_name_and_params(cfg.data.sampler)
            sampler = getattr(data.samplers, name)(dataset, **params)
        if sampler:
            dataloader_params['shuffle'] = False
            if cfg.strategy == 'ddp':
                sampler = data.samplers.DistributedSamplerWrapper(sampler)
            dataloader_params['sampler'] = sampler
            print(f'Using sampler {sampler} for training ...')
        elif cfg.strategy == 'ddp':
            dataloader_params["shuffle"] = False
            dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=mode=="train")
    else:
        assert cfg.strategy != "ddp", "DDP currently not supported for inference"
        dataloader_params['batch_size'] = cfg.evaluate.get("batch_size") or cfg.train.batch_size

    loader = DataLoader(dataset,
                        **dataloader_params,
                        worker_init_fn=worker_init_fn)
    return loader


def build_model(cfg):
    name, params = get_name_and_params(cfg.model)
    if cfg.model.params.get("cnn_params", None):
        cnn_params = cfg.model.params.cnn_params
        if cnn_params.get("load_pretrained_backbone", None):
            if "foldx" in cnn_params.load_pretrained_backbone:
                cfg.model.params.cnn_params.load_pretrained_backbone = cnn_params.load_pretrained_backbone.\
                    replace("foldx", f"fold{cfg.data.outer_fold}")
    print(f'Creating model <{name}> ...')
    model = getattr(models.engine, name)(**params)
    if 'backbone' in cfg.model.params:
        print(f' Using backbone <{cfg.model.params.backbone}> ...')
    if 'pretrained' in cfg.model.params:
        print(f' Pretrained : {cfg.model.params.pretrained}')
    if "load_pretrained" in cfg.model:
        import re
        if "foldx" in cfg.model.load_pretrained:
            cfg.model.load_pretrained = cfg.model.load_pretrained.replace("foldx", f"fold{cfg.data.outer_fold}")
        print(f" Loading pretrained checkpoint from {cfg.model.load_pretrained}")
        weights = torch.load(cfg.model.load_pretrained, map_location=lambda storage, loc: storage)['state_dict']
        weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items() if "loss_fn" not in k}
        model.load_state_dict(weights)
    return model


def build_loss(cfg):
    name, params = get_name_and_params(cfg.loss)
    print(f'Using loss function <{name}> ...')
    params = dict(params)
    if "pos_weight" in params:
        params["pos_weight"] = torch.tensor(params["pos_weight"])
    criterion = getattr(losses, name)(**params)
    return criterion


def build_scheduler(cfg, optimizer):
    # Some schedulers will require manipulation of config params
    # My specifications were to make it more intuitive for me
    name, params = get_name_and_params(cfg.scheduler)
    print(f'Using learning rate schedule <{name}> ...')

    if name == 'CosineAnnealingLR':
        # eta_min <-> final_lr
        # Set T_max as 100000 ... this is changed in on_train_start() method
        # of the LightningModule task

        params = {
            'T_max': 100000,
            'eta_min': max(params.final_lr, 1.0e-8)
        }

    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        # Use learning rate from optimizer parameters as initial learning rate
        lr_0 = cfg.optimizer.params.lr
        lr_1 = params.max_lr
        lr_2 = params.final_lr
        # lr_0 -> lr_1 -> lr_2
        pct_start = params.pct_start
        params = {}
        params['steps_per_epoch'] = 100000 # see above- will fix in task
        params['epochs'] = cfg.train.num_epochs
        params['max_lr'] = lr_1
        params['pct_start'] = pct_start
        params['div_factor'] = lr_1 / lr_0 # max/init
        params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8) # init/final

    scheduler = getattr(optim, name)(optimizer=optimizer, **params)

    # Some schedulers might need more manipulation after instantiation
    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        scheduler.pct_start = params['pct_start']

    # Set update frequency
    if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
        scheduler.update_frequency = 'on_batch'
    elif name in ('ReduceLROnPlateau'):
        scheduler.update_frequency = 'on_valid'
    else:
        scheduler.update_frequency = 'on_epoch'

    return scheduler


def build_optimizer(cfg, parameters):
    name, params = get_name_and_params(cfg.optimizer)
    print(f'Using optimizer <{name}> ...')
    optimizer = getattr(optim, name)(parameters, **params)
    return optimizer


def build_task(cfg, model):
    name, params = get_name_and_params(cfg.task)
    print(f'Building task <{name}> ...')
    return getattr(tasks, name)(cfg, model, **params)
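Only build_model is exercised by app.py in this Space; the dataset, dataloader, loss, and optimizer builders refer to skp submodules (data, losses, optim, tasks) that are not among the 50 files shown here. A sketch of the build_model path with a minimal hand-written config (values mirror configs/chunk000.yaml; pretrained is disabled so nothing is downloaded):

from omegaconf import OmegaConf
from skp import builder

# Minimal config mirroring chunk000.yaml; only the keys build_model reads.
cfg = OmegaConf.create({
    "model": {
        "name": "Net3D",
        "params": {
            "backbone": "x3d_l",
            "backbone_params": {"z_strides": [1, 1, 1, 1, 1]},
            "pretrained": False,     # skip the torch.hub download for this sketch
            "num_classes": 1,
            "dropout": 0.2,
            "pool": "avg",
            "in_channels": 1,
            "multisample_dropout": True,
        },
    },
})
model = builder.build_model(cfg)     # no `load_pretrained`, so no checkpoint is loaded
print(sum(p.numel() for p in model.parameters()))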
skp/models/__init__.py
ADDED
@@ -0,0 +1 @@
from . import engine
skp/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (207 Bytes).
skp/models/__pycache__/backbones.cpython-39.pyc
ADDED
Binary file (3.9 kB).
skp/models/__pycache__/engine.cpython-39.pyc
ADDED
Binary file (10.1 kB).
skp/models/__pycache__/sequence.cpython-39.pyc
ADDED
Binary file (5.67 kB).
skp/models/__pycache__/tools.cpython-39.pyc
ADDED
Binary file (932 Bytes).
skp/models/backbones.py
ADDED
@@ -0,0 +1,114 @@
import re
import timm
import torch
import torch.nn as nn  # nn.Module / nn.Identity / nn.Conv3d / nn.Sequential are used below

from functools import partial
from timm.models.vision_transformer import VisionTransformer
from timm.models.swin_transformer_v2 import SwinTransformerV2

from .vmz.backbones import *


def check_name(name, s):
    return bool(re.search(s, name))


def create_backbone(name, pretrained, features_only=False, **kwargs):
    try:
        model = timm.create_model(name, pretrained=pretrained,
                                  features_only=features_only,
                                  num_classes=0, global_pool="")
    except Exception as e:
        assert name in BACKBONES, f"{name} is not a valid backbone"
        model = BACKBONES[name](pretrained=pretrained, features_only=features_only, **kwargs)
    with torch.no_grad():
        if check_name(name, r"x3d|csn|r2plus1d|i3d"):
            dim_feats = model(torch.randn((2, 3, 64, 64, 64))).size(1)
        elif isinstance(model, (VisionTransformer, SwinTransformerV2)):
            dim_feats = model.norm.normalized_shape[0]
        else:
            dim_feats = model(torch.randn((2, 3, 128, 128))).size(1)
    return model, dim_feats


def create_csn(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs):
    if features_only:
        raise Exception("features_only is currently not supported")
    if not pretrained:
        from pytorchvideo.models import hub
        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load("facebookresearch/pytorchvideo:main", model=name, pretrained=pretrained)
    model.blocks[5] = nn.Identity()
    return model


def create_x3d(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs):
    if not pretrained:
        from pytorchvideo.models import hub
        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load("facebookresearch/pytorchvideo", model=name, pretrained=pretrained)
    for idx, z in enumerate(z_strides):
        assert z in [1, 2], "Only z-strides of 1 or 2 are supported"
        if z == 2:
            if idx == 0:
                stem_layer = model.blocks[0].conv.conv_t
                w = stem_layer.weight
                w = w.repeat(1, 1, 3, 1, 1)
                in_channels, out_channels = stem_layer.in_channels, stem_layer.out_channels
                model.blocks[0].conv.conv_t = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
            else:
                model.blocks[idx].res_blocks[0].branch1_conv.stride = (2, 2, 2)
                model.blocks[idx].res_blocks[0].branch2.conv_b.stride = (2, 2, 2)

    if features_only:
        model.blocks[-1] = nn.Identity()
        model = X3D_Features(model)
    else:
        model.blocks[-1] = nn.Sequential(
            model.blocks[-1].pool.pre_conv,
            model.blocks[-1].pool.pre_norm,
            model.blocks[-1].pool.pre_act,
        )

    return model


def create_i3d(name, pretrained, features_only=False, **kwargs):
    from pytorchvideo.models import hub
    model = getattr(hub, name)(pretrained=pretrained)
    model.blocks[-1] = nn.Identity()
    return model


class X3D_Features(nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.out_channels = [24, 24, 48, 96, 192]

    def forward(self, x):
        features = []
        for idx in range(len(self.model.blocks) - 1):
            x = self.model.blocks[idx](x)
            features.append(x)
        return features


BACKBONES = {
    "x3d_xs": partial(create_x3d, name="x3d_xs"),
    "x3d_s": partial(create_x3d, name="x3d_s"),
    "x3d_m": partial(create_x3d, name="x3d_m"),
    "x3d_l": partial(create_x3d, name="x3d_l"),
    "i3d_r50": partial(create_i3d, name="i3d_r50"),
    "csn_r101": partial(create_csn, name="csn_r101"),
    "ir_csn_50": ir_csn_50,
    "ir_csn_101": ir_csn_101,
    "ir_csn_152": ir_csn_152,
    "ip_csn_50": ip_csn_50,
    "ip_csn_101": ip_csn_101,
    "ip_csn_152": ip_csn_152,
    "r2plus1d_34": r2plus1d_34
}
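create_backbone probes the output feature dimension with a dummy forward pass and returns it alongside the trunk; Net2D/Net3D use it to size the classifier head. A sketch of calling it directly (assumes pytorchvideo is installed, since the x3d_* models come from the pytorchvideo hub):

import torch
from skp.models import backbones

# Returns the trunk plus its probed channel dimension.
model, dim_feats = backbones.create_backbone("x3d_l", pretrained=False,
                                             z_strides=[1, 1, 1, 1, 1])
x = torch.randn(1, 3, 64, 64, 64)        # (N, C, T, H, W) clip
with torch.no_grad():
    feats = model(x)
print(dim_feats, feats.shape)            # dim_feats matches the channel dim of feats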
skp/models/engine.py
ADDED
@@ -0,0 +1,257 @@
import math
import numpy as np
import re
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorchvideo.models.x3d import create_x3d_stem
from timm.models.vision_transformer import VisionTransformer
from timm.models.swin_transformer_v2 import SwinTransformerV2
from . import backbones
from . import segmentation
from .pooling import create_pool2d_layer, create_pool3d_layer
from .sequence import Transformer, DualTransformer, DualTransformerV2
from .tools import change_initial_stride, change_num_input_channels


class Net2D(nn.Module):

    def __init__(self,
                 backbone,
                 pretrained,
                 num_classes,
                 dropout,
                 pool,
                 in_channels=3,
                 change_stride=None,
                 feature_reduction=None,
                 multisample_dropout=False,
                 load_pretrained_backbone=None,
                 freeze_backbone=False,
                 backbone_params={},
                 pool_layer_params={}):

        super().__init__()
        self.backbone, dim_feats = backbones.create_backbone(name=backbone, pretrained=pretrained, **backbone_params)
        if isinstance(pool, str):
            self.pool_layer = create_pool2d_layer(name=pool, **pool_layer_params)
        else:
            self.pool_layer = nn.Identity()
        if pool == "catavgmax":
            dim_feats *= 2
        self.msdo = multisample_dropout
        if in_channels != 3:
            self.backbone = change_num_input_channels(self.backbone, in_channels)
        if change_stride:
            self.backbone = change_initial_stride(self.backbone, tuple(change_stride), in_channels)
        self.dropout = nn.Dropout(p=dropout)
        if isinstance(feature_reduction, int):
            # Use 1D grouped convolution to reduce # of parameters
            groups = math.gcd(dim_feats, feature_reduction)
            self.feature_reduction = nn.Conv1d(dim_feats, feature_reduction, groups=groups, kernel_size=1,
                                               stride=1, bias=False)
            dim_feats = feature_reduction
        self.classifier = nn.Linear(dim_feats, num_classes)

        if load_pretrained_backbone:
            # Assumes that model has a `backbone` attribute
            # Note: if you want to load the entire pretrained model, this is done via the
            # builder.build_model function
            print(f"Loading pretrained backbone from {load_pretrained_backbone} ...")
            weights = torch.load(load_pretrained_backbone, map_location=lambda storage, loc: storage)['state_dict']
            weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items()}
            # Get feature_reduction, if present
            feat_reduce_weight = {re.sub(r"^feature_reduction.", "", k): v
                                  for k, v in weights.items() if "feature_reduction" in k}
            # Get backbone only
            weights = {re.sub(r'^backbone.', '', k) : v for k,v in weights.items() if 'backbone' in k}
            self.backbone.load_state_dict(weights)
            if len(feat_reduce_weight) > 0:
                print("Also loading feature reduction layer ...")
                self.feature_reduction.load_state_dict(feat_reduce_weight)

        if freeze_backbone:
            print("Freezing backbone ...")
            for param in self.backbone.parameters():
                param.requires_grad = False

    def extract_features(self, x):
        features = self.backbone(x)
        features = self.pool_layer(features)
        if isinstance(self.backbone, VisionTransformer):
            features = features[:, self.backbone.num_prefix_tokens:].mean(dim=1)
        if isinstance(self.backbone, SwinTransformerV2):
            features = features.mean(dim=1)
        if hasattr(self, "feature_reduction"):
            features = self.feature_reduction(features.unsqueeze(-1)).squeeze(-1)
        return features

    def forward(self, x):
        features = self.extract_features(x)
        if self.msdo:
            x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
        else:
            x = self.classifier(self.dropout(features))
        # Important nuance:
        # For binary classification, the model returns a tensor of shape (N,)
        # Otherwise, (N,C)
        return x[:, 0] if self.classifier.out_features == 1 else x


class SeqNet2D(Net2D):

    def forward(self, x):
        # x.shape = (N, C, Z, H, W)
        features = torch.stack([self.extract_features(x[:, :, _]) for _ in range(x.size(2))], dim=2)
        features = features.max(2)[0]

        if self.msdo:
            x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
        else:
            x = self.classifier(self.dropout(features))
        # Important nuance:
        # For binary classification, the model returns a tensor of shape (N,)
        # Otherwise, (N,C)
        return x[:, 0] if self.classifier.out_features == 1 else x


class TDCNN(nn.Module):

    def __init__(self, cnn_params, transformer_params, freeze_cnn=False, freeze_transformer=False):
        super().__init__()
        self.cnn = Net2D(**cnn_params)
        del self.cnn.dropout
        del self.cnn.classifier
        self.transformer = Transformer(**transformer_params)

        if freeze_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

        if freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

    def extract_features(self, x):
        N, C, Z, H, W = x.size()
        assert N == 1, "For feature extraction, batch size must be 1"
        features = self.cnn.extract_features(x.squeeze(0).transpose(0, 1)).unsqueeze(0)
        # features.shape = (1, Z, dim_feats)
        return self.transformer.extract_features((features, torch.ones((features.size(0), features.size(1))).to(features.device)))

    def forward(self, x):
        # BCZHW
        features = torch.stack([self.cnn.extract_features(x[:, :, i]) for i in range(x.size(2))], dim=1)
        # B, seq_len, dim_feat
        return self.transformer((features, torch.ones((features.size(0), features.size(1))).to(features.device)))


class Net2DWith3DStem(Net2D):

    def __init__(self, *args, **kwargs):
        stem_out_channels = kwargs.pop("stem_out_channels", 24)
        load_pretrained_stem = kwargs.pop("load_pretrained_stem", None)
        conv_kernel_size = tuple(kwargs.pop("conv_kernel_size", (5, 3, 3)))
        conv_stride = tuple(kwargs.pop("conv_stride", (1, 2, 2)))
        in_channels = kwargs.pop("in_channels", 3)
        kwargs["in_channels"] = stem_out_channels
        super().__init__(*args, **kwargs)
        self.stem_layer = create_x3d_stem(in_channels=in_channels,
                                          out_channels=stem_out_channels,
                                          conv_kernel_size=conv_kernel_size,
                                          conv_stride=conv_stride)
        if kwargs["pretrained"]:
            from pytorchvideo.models.hub import x3d_l
            self.stem_layer.load_state_dict(x3d_l(pretrained=True).blocks[0].state_dict())

        if load_pretrained_stem:
            import re
            print(f" Loading pretrained stem from {load_pretrained_stem} ...")
            weights = torch.load(load_pretrained_stem, map_location=lambda storage, loc: storage)['state_dict']
            stem_weights = {k.replace("model.backbone.blocks.0.", ""): v for k, v in weights.items() if "backbone.blocks.0" in k}
            self.stem_layer.load_state_dict(stem_weights)

    def forward(self, x):
        x = self.stem_layer(x)
        x = x.mean(3)
        features = self.extract_features(x)
        if self.msdo:
            x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
        else:
            x = self.classifier(self.dropout(features))
        # Important nuance:
        # For binary classification, the model returns a tensor of shape (N,)
        # Otherwise, (N,C)
        return x[:, 0] if self.classifier.out_features == 1 else x


class Net3D(Net2D):

    def __init__(self, *args, **kwargs):
        z_strides = kwargs.pop("z_strides", [1,1,1,1,1])
        super().__init__(*args, **kwargs)
        self.pool_layer = create_pool3d_layer(name=kwargs["pool"], **kwargs.pop("pool_layer_params", {}))


class NetSegment2D(nn.Module):
    """ For now, this class essentially serves as a wrapper for the
    segmentation model which is mostly defined in the segmentation submodule,
    similar to the original segmentation_models.pytorch.

    It may be worth refactoring it in the future, such that you define this as
    a general class, then select your choice of encoder and decoder. The encoder
    is pretty much the same across all the segmentation models currently
    implemented (DeepLabV3+, FPN, Unet).
    """
    def __init__(self,
                 architecture,
                 encoder_name,
                 encoder_params,
                 decoder_params,
                 num_classes,
                 dropout,
                 in_channels,
                 load_pretrained_encoder=None,
                 freeze_encoder=False,
                 deep_supervision=False,
                 pool_layer_params={},
                 aux_head_params={}):

        super().__init__()

        self.segmentation_model = getattr(segmentation, architecture)(
            encoder_name=encoder_name,
            encoder_params=encoder_params,
            dropout=dropout,
            classes=num_classes,
            deep_supervision=deep_supervision,
            in_channels=in_channels,
            **decoder_params
        )

        if load_pretrained_encoder:
            # Assumes that model has an `encoder` attribute
            # Note: if you want to load the entire pretrained model, this is done via the
            # builder.build_model function
            print(f"Loading pretrained encoder from {load_pretrained_encoder} ...")
            weights = torch.load(load_pretrained_encoder, map_location=lambda storage, loc: storage)['state_dict']
            weights = {re.sub(r'^model.segmentation_model', '', k) : v for k,v in weights.items()}
            # Get encoder only
            weights = {re.sub(r'^encoder.', '', k) : v for k,v in weights.items() if 'backbone' in k}
            self.segmentation_model.encoder.load_state_dict(weights)

        if freeze_encoder:
            print("Freezing encoder ...")
            for param in self.segmentation_model.encoder.parameters():
                param.requires_grad = False

    def forward(self, x):
        return self.segmentation_model(x)


class NetSegment3D(NetSegment2D):

    pass
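Net3D is the classifier that configs/chunk000.yaml builds. A sketch instantiating it directly with the same parameters (pretrained disabled here to avoid the hub download; the input size follows that config's imsize of [64, 288, 288]):

import torch
from skp.models.engine import Net3D

net = Net3D(backbone="x3d_l",
            backbone_params={"z_strides": [1, 1, 1, 1, 1]},
            pretrained=False,              # skip the pretrained download in this sketch
            num_classes=1,
            dropout=0.2,
            pool="avg",
            in_channels=1,
            multisample_dropout=True)

x = torch.randn(2, 1, 64, 288, 288)        # (N, C=1, Z, H, W) vertebra chunk
with torch.no_grad():
    logits = net.eval()(x)
print(logits.shape)                        # torch.Size([2]) -- num_classes == 1 returns (N,)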
skp/models/pooling/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .pool3d import create_pool3d_layer
from .pool2d import create_pool2d_layer
from .pool1d import create_pool1d_layer
skp/models/pooling/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (334 Bytes).
skp/models/pooling/__pycache__/gem.cpython-39.pyc
ADDED
Binary file (1.67 kB).
skp/models/pooling/__pycache__/pool1d.cpython-39.pyc
ADDED
Binary file (4.28 kB).
skp/models/pooling/__pycache__/pool2d.cpython-39.pyc
ADDED
Binary file (678 Bytes).
skp/models/pooling/__pycache__/pool3d.cpython-39.pyc
ADDED
Binary file (4.29 kB).
skp/models/pooling/gem.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# From: https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/layers/pooling.py
def gem_1d(x, p=3, eps=1e-6):
    return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1),)).pow(1./p)


def gem_2d(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)


def gem_3d(x, p=3, eps=1e-6):
    return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(-3), x.size(-2), x.size(-1))).pow(1./p)


_GEM_FN = {
    1: gem_1d, 2: gem_2d, 3: gem_3d
}


class GeM(nn.Module):

    def __init__(self, p=3, eps=1e-6, dim=2):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps
        self.dim = dim
        self.flatten = nn.Flatten(1)

    def forward(self, x):
        pooled = _GEM_FN[self.dim](x, p=self.p, eps=self.eps)
        return self.flatten(pooled)
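GeM (generalized-mean) pooling takes the p-th root of the mean of x**p, so p=1 reduces to plain average pooling and large p is pulled toward max pooling. A quick numeric check of gem_2d and the GeM module above:

import torch
from skp.models.pooling.gem import GeM, gem_2d

x = torch.tensor([[[[1., 2.], [3., 4.]]]])   # (N=1, C=1, H=2, W=2)
print(gem_2d(x, p=1))                        # 2.5 -- same as the mean
print(gem_2d(x, p=20))                       # ~3.7 -- pulled toward the max of 4

layer = GeM(p=3, dim=2)                      # learnable p, flattens to (N, C)
print(layer(x).shape)                        # torch.Size([1, 1])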
skp/models/pooling/pool1d.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gem import GeM


def adaptive_avgmax_pool1d(x, output_size=1):
    x_avg = F.adaptive_avg_pool1d(x, output_size)
    x_max = F.adaptive_max_pool1d(x, output_size)
    return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool1d(x, output_size=1):
    x_avg = F.adaptive_avg_pool1d(x, output_size)
    x_max = F.adaptive_max_pool1d(x, output_size)
    return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool1d(x, pool_type='avg', output_size=1):
    """Selectable global pooling function with dynamic input kernel size
    """
    if pool_type == 'avg':
        x = F.adaptive_avg_pool1d(x, output_size)
    elif pool_type == 'avgmax':
        x = adaptive_avgmax_pool1d(x, output_size)
    elif pool_type == 'catavgmax':
        x = adaptive_catavgmax_pool1d(x, output_size)
    elif pool_type == 'max':
        x = F.adaptive_max_pool1d(x, output_size)
    else:
        assert False, 'Invalid pool type: %s' % pool_type
    return x


class FastAdaptiveAvgPool1d(nn.Module):
    def __init__(self, flatten=False):
        super(FastAdaptiveAvgPool1d, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        return x.mean(2, keepdim=not self.flatten)


class AdaptiveAvgMaxPool1d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveAvgMaxPool1d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_avgmax_pool1d(x, self.output_size)


class AdaptiveCatAvgMaxPool1d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveCatAvgMaxPool1d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_catavgmax_pool1d(x, self.output_size)


class SelectAdaptivePool1d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size
    """
    def __init__(self, output_size=1, pool_type='fast', flatten=False):
        super(SelectAdaptivePool1d, self).__init__()
        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if pool_type == '':
            self.pool = nn.Identity()  # pass through
        elif pool_type == 'fast':
            assert output_size == 1
            self.pool = FastAdaptiveAvgPool1d(flatten)
            self.flatten = nn.Identity()
        elif pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool1d(output_size)
        elif pool_type == 'avgmax':
            self.pool = AdaptiveAvgMaxPool1d(output_size)
        elif pool_type == 'catavgmax':
            self.pool = AdaptiveCatAvgMaxPool1d(output_size)
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool1d(output_size)
        else:
            assert False, 'Invalid pool type: %s' % pool_type

    def is_identity(self):
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'pool_type=' + self.pool_type \
            + ', flatten=' + str(self.flatten) + ')'


def create_pool1d_layer(name, **kwargs):
    assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"]
    if name != "gem":
        pool1d_layer = SelectAdaptivePool1d(pool_type=name, flatten=True)
    elif name == "gem":
        pool1d_layer = GeM(dim=1, **kwargs)
    return pool1d_layer
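create_pool1d_layer is the factory used by the sequence-level models; a short sketch of its two branches (import path per the pooling package above):

import torch
from skp.models.pooling import create_pool1d_layer

x = torch.randn(4, 256, 64)                  # (batch, channels, length)
avg_pool = create_pool1d_layer("avg")        # SelectAdaptivePool1d with flatten
gem_pool = create_pool1d_layer("gem", p=3)   # generalized-mean pooling from gem.py
print(avg_pool(x).shape, gem_pool(x).shape)  # torch.Size([4, 256]) torch.Size([4, 256])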
skp/models/pooling/pool2d.py
ADDED
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import SelectAdaptivePool2d

from .gem import GeM


def create_pool2d_layer(name, **kwargs):
    assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"]
    if name != "gem":
        pool2d_layer = SelectAdaptivePool2d(pool_type=name, flatten=True)
    elif name == "gem":
        pool2d_layer = GeM(dim=2, **kwargs)
    return pool2d_layer
skp/models/pooling/pool3d.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gem import GeM


def adaptive_avgmax_pool3d(x, output_size=1):
    x_avg = F.adaptive_avg_pool3d(x, output_size)
    x_max = F.adaptive_max_pool3d(x, output_size)
    return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool3d(x, output_size=1):
    x_avg = F.adaptive_avg_pool3d(x, output_size)
    x_max = F.adaptive_max_pool3d(x, output_size)
    return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool3d(x, pool_type='avg', output_size=1):
    """Selectable global pooling function with dynamic input kernel size
    """
    if pool_type == 'avg':
        x = F.adaptive_avg_pool3d(x, output_size)
    elif pool_type == 'avgmax':
        x = adaptive_avgmax_pool3d(x, output_size)
    elif pool_type == 'catavgmax':
        x = adaptive_catavgmax_pool3d(x, output_size)
    elif pool_type == 'max':
        x = F.adaptive_max_pool3d(x, output_size)
    else:
        assert False, 'Invalid pool type: %s' % pool_type
    return x


class FastAdaptiveAvgPool3d(nn.Module):
    def __init__(self, flatten=False):
        super(FastAdaptiveAvgPool3d, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        return x.mean((2,3,4), keepdim=not self.flatten)


class AdaptiveAvgMaxPool3d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveAvgMaxPool3d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_avgmax_pool3d(x, self.output_size)


class AdaptiveCatAvgMaxPool3d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveCatAvgMaxPool3d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_catavgmax_pool3d(x, self.output_size)


class SelectAdaptivePool3d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size
    """
    def __init__(self, output_size=1, pool_type='fast', flatten=False):
        super(SelectAdaptivePool3d, self).__init__()
        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if pool_type == '':
            self.pool = nn.Identity()  # pass through
        elif pool_type == 'fast':
            assert output_size == 1
            self.pool = FastAdaptiveAvgPool3d(flatten)
            self.flatten = nn.Identity()
        elif pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool3d(output_size)
        elif pool_type == 'avgmax':
            self.pool = AdaptiveAvgMaxPool3d(output_size)
        elif pool_type == 'catavgmax':
            self.pool = AdaptiveCatAvgMaxPool3d(output_size)
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool3d(output_size)
        else:
            assert False, 'Invalid pool type: %s' % pool_type

    def is_identity(self):
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'pool_type=' + self.pool_type \
            + ', flatten=' + str(self.flatten) + ')'


def create_pool3d_layer(name, **kwargs):
    assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"]
    if name != "gem":
        pool1d_layer = SelectAdaptivePool3d(pool_type=name, flatten=True)
    elif name == "gem":
        pool1d_layer = GeM(dim=3, **kwargs)
    return pool1d_layer
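The 3-D variants mirror the 1-D ones; for instance, the "fast" path is just a mean over the (D, H, W) dims, which agrees with adaptive average pooling to output size 1 (quick check, plain torch only):

import torch
import torch.nn.functional as F

x = torch.randn(2, 192, 4, 6, 6)              # (N, C, D, H, W) feature map
fast = x.mean((2, 3, 4))                      # what FastAdaptiveAvgPool3d(flatten=True) computes
avg = F.adaptive_avg_pool3d(x, 1).flatten(1)  # the 'avg' branch of select_adaptive_pool3d, flattened
print(torch.allclose(fast, avg, atol=1e-6))   # True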
skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml
ADDED
@@ -0,0 +1,109 @@
TRAIN:
  ENABLE: True
  DATASET: imagenet
  BATCH_SIZE: 256
  EVAL_PERIOD: 10
  CHECKPOINT_PERIOD: 1
  AUTO_RESUME: True

DATA:
  # PATH_TO_DATA_DIR: path-to-imagenet-dir
  MEAN: [0.485, 0.456, 0.406]
  STD: [0.229, 0.224, 0.225]
  NUM_FRAMES: 64
  TRAIN_CROP_SIZE: 224
  TEST_CROP_SIZE: 224
  INPUT_CHANNEL_NUM: [3]
MVIT:
  PATCH_2D: False
  ZERO_DECAY_POS_CLS: False
  MODE: "conv"
  CLS_EMBED_ON: False
  PATCH_KERNEL: [3, 7, 7]
  PATCH_STRIDE: [2, 4, 4]
  PATCH_PADDING: [1, 3, 3]
  EMBED_DIM: 96
  NUM_HEADS: 1
  MLP_RATIO: 4.0
  QKV_BIAS: True
  DROPPATH_RATE: 0.1
  DROPOUT_RATE: 0.0
  DEPTH: 16
  LAYER_SCALE_INIT_VALUE: 0.0
  HEAD_INIT_SCALE: 1.0
  USE_MEAN_POOLING: False
  USE_ABS_POS: True
  USE_FIXED_SINCOS_POS: False
  SEP_POS_EMBED: False
  REL_POS_SPATIAL: False
  REL_POS_TEMPORAL: False
  REL_POS_ZERO_INIT: False
  RESIDUAL_POOLING: False
  NORM: "layernorm"
  NORM_STEM: False
  DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]]
  HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]]
  POOL_FIRST: null
  POOL_KVQ_KERNEL: [1, 3, 3]
  POOL_KV_STRIDE_ADAPTIVE: [1, 4, 4]
  POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]]
  SEPARATE_QKV: True
  REV:
    ENABLE: True
    RESPATH_FUSE: "concat"
    BUFFER_LAYERS: [1, 3, 14]
    RES_PATH: "conv"
    PRE_Q_FUSION: "concat_linear_2"
DETECTION:
  ENABLE: False
AUG:
  ENABLE: True
  COLOR_JITTER: 0.4
  AA_TYPE: rand-m9-n6-mstd0.5-inc1
  INTERPOLATION: bicubic
  RE_PROB: 0.25
  RE_MODE: pixel
  RE_COUNT: 1
  RE_SPLIT: False
MIXUP:
  ENABLE: True
  ALPHA: 0.8
  CUTMIX_ALPHA: 1.0
  PROB: 1.0
  SWITCH_PROB: 0.5
  LABEL_SMOOTH_VALUE: 0.1
SOLVER:
  BASE_LR_SCALE_NUM_SHARDS: True
  BASE_LR: 0.00025
  LR_POLICY: cosine
  MAX_EPOCH: 300
  MOMENTUM: 0.9
  WEIGHT_DECAY: 0.05
  WARMUP_EPOCHS: 70.0
  WARMUP_START_LR: 1e-8
  OPTIMIZING_METHOD: adamw
  COSINE_AFTER_WARMUP: True
  COSINE_END_LR: 1e-6
  ZERO_WD_1D_PARAM: True
  CLIP_GRAD_L2NORM: 1.0
MODEL:
  NUM_CLASSES: 1000
  ARCH: mvit
  MODEL_NAME: MViT
  LOSS_FUNC: soft_cross_entropy
  DROPOUT_RATE: 0.0
  HEAD_ACT: "softmax"
  DETACH_FINAL_FC: False
CONTRASTIVE:
  NUM_MLP_LAYERS: 1
TEST:
  ENABLE: False
  DATASET: imagenet
  BATCH_SIZE: 256
DATA_LOADER:
  NUM_WORKERS: 8
  PIN_MEMORY: True
NUM_GPUS: 2
NUM_SHARDS: 1
RNG_SEED: 0
OUTPUT_DIR: .
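This YAML follows the SlowFast-style configuration consumed by the reversible MViT code vendored under skp/models/rev_mvit. A minimal way to inspect it (a sketch; the actual model builder consumes a full config object, and PyYAML is assumed to be available):

import yaml

with open("skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["MVIT"]["DEPTH"], cfg["MVIT"]["EMBED_DIM"])   # 16 96
print(cfg["MVIT"]["REV"]["ENABLE"])                     # True -- reversible blocks enabled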
skp/models/rev_mvit/__init__.py
ADDED
File without changes
skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (161 Bytes).
skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc
ADDED
Binary file (10.4 kB).
skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc
ADDED
Binary file (3.65 kB).
skp/models/rev_mvit/__pycache__/common.cpython-39.pyc
ADDED
Binary file (4.94 kB).
skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc
ADDED
Binary file (3.46 kB).
skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc
ADDED
Binary file (13.8 kB).
skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc
ADDED
Binary file (8.37 kB).
skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (5.42 kB).
skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc
ADDED
Binary file (10.8 kB).
skp/models/rev_mvit/attention.py
ADDED
@@ -0,0 +1,568 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
3 |
+
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.nn.init import trunc_normal_
|
10 |
+
|
11 |
+
from .common import DropPath, Mlp
|
12 |
+
|
13 |
+
|
14 |
+
def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
|
15 |
+
if pool is None:
|
16 |
+
return tensor, thw_shape
|
17 |
+
tensor_dim = tensor.ndim
|
18 |
+
if tensor_dim == 4:
|
19 |
+
pass
|
20 |
+
elif tensor_dim == 3:
|
21 |
+
tensor = tensor.unsqueeze(1)
|
22 |
+
else:
|
23 |
+
raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
|
24 |
+
|
25 |
+
if has_cls_embed:
|
26 |
+
cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
|
27 |
+
|
28 |
+
B, N, L, C = tensor.shape
|
29 |
+
T, H, W = thw_shape
|
30 |
+
tensor = (
|
31 |
+
tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
|
32 |
+
)
|
33 |
+
|
34 |
+
tensor = pool(tensor)
|
35 |
+
|
36 |
+
thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
|
37 |
+
L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
|
38 |
+
tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
|
39 |
+
if has_cls_embed:
|
40 |
+
tensor = torch.cat((cls_tok, tensor), dim=2)
|
41 |
+
if norm is not None:
|
42 |
+
tensor = norm(tensor)
|
43 |
+
# Assert tensor_dim in [3, 4]
|
44 |
+
if tensor_dim == 4:
|
45 |
+
pass
|
46 |
+
else: # tensor_dim == 3:
|
47 |
+
tensor = tensor.squeeze(1)
|
48 |
+
return tensor, thw_shape
|
49 |
+
|
50 |
+
|
51 |
+
def get_rel_pos(rel_pos, d):
|
52 |
+
if isinstance(d, int):
|
53 |
+
ori_d = rel_pos.shape[0]
|
54 |
+
if ori_d == d:
|
55 |
+
return rel_pos
|
56 |
+
else:
|
57 |
+
# Interpolate rel pos.
|
58 |
+
new_pos_embed = F.interpolate(
|
59 |
+
rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
|
60 |
+
size=d,
|
61 |
+
mode="linear",
|
62 |
+
)
|
63 |
+
|
64 |
+
return new_pos_embed.reshape(-1, d).permute(1, 0)
|
65 |
+
|
66 |
+
|
67 |
+
def cal_rel_pos_spatial(
|
68 |
+
attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Decomposed Spatial Relative Positional Embeddings.
|
72 |
+
"""
|
73 |
+
sp_idx = 1 if has_cls_embed else 0
|
74 |
+
q_t, q_h, q_w = q_shape
|
75 |
+
k_t, k_h, k_w = k_shape
|
76 |
+
dh = int(2 * max(q_h, k_h) - 1)
|
77 |
+
dw = int(2 * max(q_w, k_w) - 1)
|
78 |
+
|
79 |
+
# Scale up rel pos if shapes for q and k are different.
|
80 |
+
q_h_ratio = max(k_h / q_h, 1.0)
|
81 |
+
k_h_ratio = max(q_h / k_h, 1.0)
|
82 |
+
dist_h = (
|
83 |
+
torch.arange(q_h)[:, None] * q_h_ratio
|
84 |
+
- torch.arange(k_h)[None, :] * k_h_ratio
|
85 |
+
)
|
86 |
+
dist_h += (k_h - 1) * k_h_ratio
|
87 |
+
q_w_ratio = max(k_w / q_w, 1.0)
|
88 |
+
k_w_ratio = max(q_w / k_w, 1.0)
|
89 |
+
dist_w = (
|
90 |
+
torch.arange(q_w)[:, None] * q_w_ratio
|
91 |
+
- torch.arange(k_w)[None, :] * k_w_ratio
|
92 |
+
)
|
93 |
+
dist_w += (k_w - 1) * k_w_ratio
|
94 |
+
|
95 |
+
# Intepolate rel pos if needed.
|
96 |
+
rel_pos_h = get_rel_pos(rel_pos_h, dh)
|
97 |
+
rel_pos_w = get_rel_pos(rel_pos_w, dw)
|
98 |
+
Rh = rel_pos_h[dist_h.long()]
|
99 |
+
Rw = rel_pos_w[dist_w.long()]
|
100 |
+
|
101 |
+
B, n_head, q_N, dim = q.shape
|
102 |
+
|
103 |
+
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
|
104 |
+
rel_h_q = torch.einsum(
|
105 |
+
"bythwc,hkc->bythwk", r_q, Rh
|
106 |
+
) # [B, H, q_t, qh, qw, k_h]
|
107 |
+
rel_w_q = torch.einsum(
|
108 |
+
"bythwc,wkc->bythwk", r_q, Rw
|
109 |
+
) # [B, H, q_t, qh, qw, k_w]
|
110 |
+
|
111 |
+
attn[:, :, sp_idx:, sp_idx:] = (
|
112 |
+
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
|
113 |
+
+ rel_h_q[:, :, :, :, :, None, :, None]
|
114 |
+
+ rel_w_q[:, :, :, :, :, None, None, :]
|
115 |
+
).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)
|
116 |
+
|
117 |
+
return attn
|
118 |
+
|
119 |
+
|
120 |
+
def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
|
121 |
+
"""
|
122 |
+
Temporal Relative Positional Embeddings.
|
123 |
+
"""
|
124 |
+
sp_idx = 1 if has_cls_embed else 0
|
125 |
+
q_t, q_h, q_w = q_shape
|
126 |
+
k_t, k_h, k_w = k_shape
|
127 |
+
dt = int(2 * max(q_t, k_t) - 1)
|
128 |
+
# Intepolate rel pos if needed.
|
129 |
+
rel_pos_t = get_rel_pos(rel_pos_t, dt)
|
130 |
+
|
131 |
+
# Scale up rel pos if shapes for q and k are different.
|
132 |
+
q_t_ratio = max(k_t / q_t, 1.0)
|
133 |
+
k_t_ratio = max(q_t / k_t, 1.0)
|
134 |
+
dist_t = (
|
135 |
+
torch.arange(q_t)[:, None] * q_t_ratio
|
136 |
+
- torch.arange(k_t)[None, :] * k_t_ratio
|
137 |
+
)
|
138 |
+
dist_t += (k_t - 1) * k_t_ratio
|
139 |
+
Rt = rel_pos_t[dist_t.long()]
|
140 |
+
|
141 |
+
B, n_head, q_N, dim = q.shape
|
142 |
+
|
143 |
+
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
|
144 |
+
# [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
|
145 |
+
r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(
|
146 |
+
q_t, B * n_head * q_h * q_w, dim
|
147 |
+
)
|
148 |
+
|
149 |
+
# [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
|
150 |
+
rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
|
151 |
+
# [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
|
152 |
+
rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
|
153 |
+
|
154 |
+
attn[:, :, sp_idx:, sp_idx:] = (
|
155 |
+
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
|
156 |
+
+ rel[:, :, :, :, :, :, None, None]
|
157 |
+
).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)
|
158 |
+
|
159 |
+
return attn
|
160 |
+
|
161 |
+
|
162 |
+
class MultiScaleAttention(nn.Module):
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
dim,
|
166 |
+
dim_out,
|
167 |
+
input_size,
|
168 |
+
num_heads=8,
|
169 |
+
qkv_bias=False,
|
170 |
+
drop_rate=0.0,
|
171 |
+
kernel_q=(1, 1, 1),
|
172 |
+
kernel_kv=(1, 1, 1),
|
173 |
+
stride_q=(1, 1, 1),
|
174 |
+
stride_kv=(1, 1, 1),
|
175 |
+
norm_layer=nn.LayerNorm,
|
176 |
+
has_cls_embed=True,
|
177 |
+
# Options include `conv`, `avg`, and `max`.
|
178 |
+
mode="conv",
|
179 |
+
# If True, perform pool before projection.
|
180 |
+
pool_first=False,
|
181 |
+
rel_pos_spatial=False,
|
182 |
+
rel_pos_temporal=False,
|
183 |
+
rel_pos_zero_init=False,
|
184 |
+
residual_pooling=False,
|
185 |
+
separate_qkv=False,
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
self.pool_first = pool_first
|
189 |
+
self.separate_qkv = separate_qkv
|
190 |
+
self.drop_rate = drop_rate
|
191 |
+
self.num_heads = num_heads
|
192 |
+
self.dim_out = dim_out
|
193 |
+
head_dim = dim_out // num_heads
|
194 |
+
self.scale = head_dim**-0.5
|
195 |
+
self.has_cls_embed = has_cls_embed
|
196 |
+
self.mode = mode
|
197 |
+
padding_q = [int(q // 2) for q in kernel_q]
|
198 |
+
padding_kv = [int(kv // 2) for kv in kernel_kv]
|
199 |
+
|
200 |
+
if pool_first or separate_qkv:
|
201 |
+
self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
|
202 |
+
self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
|
203 |
+
self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
|
204 |
+
else:
|
205 |
+
self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
|
206 |
+
|
207 |
+
self.proj = nn.Linear(dim_out, dim_out)
|
208 |
+
if drop_rate > 0.0:
|
209 |
+
self.proj_drop = nn.Dropout(drop_rate)
|
210 |
+
|
211 |
+
# Skip pooling with kernel and stride size of (1, 1, 1).
|
212 |
+
if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
|
213 |
+
kernel_q = ()
|
214 |
+
if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
|
215 |
+
kernel_kv = ()
|
216 |
+
|
217 |
+
if mode in ("avg", "max"):
|
218 |
+
pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
|
219 |
+
self.pool_q = (
|
220 |
+
pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
|
221 |
+
if len(kernel_q) > 0
|
222 |
+
else None
|
223 |
+
)
|
224 |
+
self.pool_k = (
|
225 |
+
pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
|
226 |
+
if len(kernel_kv) > 0
|
227 |
+
else None
|
228 |
+
)
|
229 |
+
self.pool_v = (
|
230 |
+
pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
|
231 |
+
if len(kernel_kv) > 0
|
232 |
+
else None
|
233 |
+
)
|
234 |
+
elif mode == "conv" or mode == "conv_unshared":
|
235 |
+
if pool_first:
|
236 |
+
dim_conv = dim // num_heads if mode == "conv" else dim
|
237 |
+
else:
|
238 |
+
dim_conv = dim_out // num_heads if mode == "conv" else dim_out
|
239 |
+
self.pool_q = (
|
240 |
+
nn.Conv3d(
|
241 |
+
dim_conv,
|
242 |
+
dim_conv,
|
243 |
+
kernel_q,
|
244 |
+
stride=stride_q,
|
245 |
+
padding=padding_q,
|
246 |
+
groups=dim_conv,
|
247 |
+
bias=False,
|
248 |
+
)
|
249 |
+
if len(kernel_q) > 0
|
250 |
+
else None
|
251 |
+
)
|
252 |
+
self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
|
253 |
+
self.pool_k = (
|
254 |
+
nn.Conv3d(
|
255 |
+
dim_conv,
|
256 |
+
dim_conv,
|
257 |
+
kernel_kv,
|
258 |
+
stride=stride_kv,
|
259 |
+
padding=padding_kv,
|
260 |
+
groups=dim_conv,
|
261 |
+
bias=False,
|
262 |
+
)
|
263 |
+
if len(kernel_kv) > 0
|
264 |
+
else None
|
265 |
+
)
|
266 |
+
self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
|
267 |
+
self.pool_v = (
|
268 |
+
nn.Conv3d(
|
269 |
+
dim_conv,
|
270 |
+
dim_conv,
|
271 |
+
kernel_kv,
|
272 |
+
stride=stride_kv,
|
273 |
+
padding=padding_kv,
|
274 |
+
groups=dim_conv,
|
275 |
+
bias=False,
|
276 |
+
)
|
277 |
+
if len(kernel_kv) > 0
|
278 |
+
else None
|
279 |
+
)
|
280 |
+
self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
|
281 |
+
else:
|
282 |
+
raise NotImplementedError(f"Unsupported model {mode}")
|
283 |
+
|
284 |
+
self.rel_pos_spatial = rel_pos_spatial
|
285 |
+
self.rel_pos_temporal = rel_pos_temporal
|
286 |
+
if self.rel_pos_spatial:
|
287 |
+
assert input_size[1] == input_size[2]
|
288 |
+
size = input_size[1]
|
289 |
+
q_size = size // stride_q[1] if len(stride_q) > 0 else size
|
290 |
+
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
|
291 |
+
rel_sp_dim = 2 * max(q_size, kv_size) - 1
|
292 |
+
|
293 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
|
294 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
|
295 |
+
if not rel_pos_zero_init:
|
296 |
+
trunc_normal_(self.rel_pos_h, std=0.02)
|
297 |
+
trunc_normal_(self.rel_pos_w, std=0.02)
|
298 |
+
if self.rel_pos_temporal:
|
299 |
+
self.rel_pos_t = nn.Parameter(
|
300 |
+
torch.zeros(2 * input_size[0] - 1, head_dim)
|
301 |
+
)
|
302 |
+
if not rel_pos_zero_init:
|
303 |
+
trunc_normal_(self.rel_pos_t, std=0.02)
|
304 |
+
|
305 |
+
self.residual_pooling = residual_pooling
|
306 |
+
|
307 |
+
def forward(self, x, thw_shape):
|
308 |
+
B, N, _ = x.shape
|
309 |
+
|
310 |
+
if self.pool_first:
|
311 |
+
if self.mode == "conv_unshared":
|
312 |
+
fold_dim = 1
|
313 |
+
else:
|
314 |
+
fold_dim = self.num_heads
|
315 |
+
x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
|
316 |
+
q = k = v = x
|
317 |
+
else:
|
318 |
+
assert self.mode != "conv_unshared"
|
319 |
+
if not self.separate_qkv:
|
320 |
+
qkv = (
|
321 |
+
self.qkv(x)
|
322 |
+
.reshape(B, N, 3, self.num_heads, -1)
|
323 |
+
.permute(2, 0, 3, 1, 4)
|
324 |
+
)
|
325 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
326 |
+
else:
|
327 |
+
q = k = v = x
|
328 |
+
q = (
|
329 |
+
self.q(q)
|
330 |
+
.reshape(B, N, self.num_heads, -1)
|
331 |
+
.permute(0, 2, 1, 3)
|
332 |
+
)
|
333 |
+
k = (
|
334 |
+
self.k(k)
|
335 |
+
.reshape(B, N, self.num_heads, -1)
|
336 |
+
.permute(0, 2, 1, 3)
|
337 |
+
)
|
338 |
+
v = (
|
339 |
+
self.v(v)
|
340 |
+
.reshape(B, N, self.num_heads, -1)
|
341 |
+
.permute(0, 2, 1, 3)
|
342 |
+
)
|
343 |
+
|
344 |
+
q, q_shape = attention_pool(
|
345 |
+
q,
|
346 |
+
self.pool_q,
|
347 |
+
thw_shape,
|
348 |
+
has_cls_embed=self.has_cls_embed,
|
349 |
+
norm=self.norm_q if hasattr(self, "norm_q") else None,
|
350 |
+
)
|
351 |
+
k, k_shape = attention_pool(
|
352 |
+
k,
|
353 |
+
self.pool_k,
|
354 |
+
thw_shape,
|
355 |
+
has_cls_embed=self.has_cls_embed,
|
356 |
+
norm=self.norm_k if hasattr(self, "norm_k") else None,
|
357 |
+
)
|
358 |
+
v, v_shape = attention_pool(
|
359 |
+
v,
|
360 |
+
self.pool_v,
|
361 |
+
thw_shape,
|
362 |
+
has_cls_embed=self.has_cls_embed,
|
363 |
+
norm=self.norm_v if hasattr(self, "norm_v") else None,
|
364 |
+
)
|
365 |
+
|
366 |
+
if self.pool_first:
|
367 |
+
q_N = (
|
368 |
+
numpy.prod(q_shape) + 1
|
369 |
+
if self.has_cls_embed
|
370 |
+
else numpy.prod(q_shape)
|
371 |
+
)
|
372 |
+
k_N = (
|
373 |
+
numpy.prod(k_shape) + 1
|
374 |
+
if self.has_cls_embed
|
375 |
+
else numpy.prod(k_shape)
|
376 |
+
)
|
377 |
+
v_N = (
|
378 |
+
numpy.prod(v_shape) + 1
|
379 |
+
if self.has_cls_embed
|
380 |
+
else numpy.prod(v_shape)
|
381 |
+
)
|
382 |
+
|
383 |
+
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
|
384 |
+
q = (
|
385 |
+
self.q(q)
|
386 |
+
.reshape(B, q_N, self.num_heads, -1)
|
387 |
+
.permute(0, 2, 1, 3)
|
388 |
+
)
|
389 |
+
|
390 |
+
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
|
391 |
+
v = (
|
392 |
+
self.v(v)
|
393 |
+
.reshape(B, v_N, self.num_heads, -1)
|
394 |
+
.permute(0, 2, 1, 3)
|
395 |
+
)
|
396 |
+
|
397 |
+
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
|
398 |
+
k = (
|
399 |
+
self.k(k)
|
400 |
+
.reshape(B, k_N, self.num_heads, -1)
|
401 |
+
.permute(0, 2, 1, 3)
|
402 |
+
)
|
403 |
+
|
404 |
+
N = q.shape[2]
|
405 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
406 |
+
if self.rel_pos_spatial:
|
407 |
+
attn = cal_rel_pos_spatial(
|
408 |
+
attn,
|
409 |
+
q,
|
410 |
+
k,
|
411 |
+
self.has_cls_embed,
|
412 |
+
q_shape,
|
413 |
+
k_shape,
|
414 |
+
self.rel_pos_h,
|
415 |
+
self.rel_pos_w,
|
416 |
+
)
|
417 |
+
|
418 |
+
if self.rel_pos_temporal:
|
419 |
+
attn = cal_rel_pos_temporal(
|
420 |
+
attn,
|
421 |
+
q,
|
422 |
+
self.has_cls_embed,
|
423 |
+
q_shape,
|
424 |
+
k_shape,
|
425 |
+
self.rel_pos_t,
|
426 |
+
)
|
427 |
+
attn = attn.softmax(dim=-1)
|
428 |
+
|
429 |
+
x = attn @ v
|
430 |
+
|
431 |
+
if self.residual_pooling:
|
432 |
+
if self.has_cls_embed:
|
433 |
+
x[:, :, 1:, :] += q[:, :, 1:, :]
|
434 |
+
else:
|
435 |
+
x = x + q
|
436 |
+
|
437 |
+
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
|
438 |
+
x = self.proj(x)
|
439 |
+
|
440 |
+
if self.drop_rate > 0.0:
|
441 |
+
x = self.proj_drop(x)
|
442 |
+
return x, q_shape
|
443 |
+
|
444 |
+
|
445 |
+
class MultiScaleBlock(nn.Module):
|
446 |
+
def __init__(
|
447 |
+
self,
|
448 |
+
dim,
|
449 |
+
dim_out,
|
450 |
+
num_heads,
|
451 |
+
input_size,
|
452 |
+
mlp_ratio=4.0,
|
453 |
+
qkv_bias=False,
|
454 |
+
qk_scale=None,
|
455 |
+
drop_rate=0.0,
|
456 |
+
drop_path=0.0,
|
457 |
+
layer_scale_init_value=0.0,
|
458 |
+
act_layer=nn.GELU,
|
459 |
+
norm_layer=nn.LayerNorm,
|
460 |
+
up_rate=None,
|
461 |
+
kernel_q=(1, 1, 1),
|
462 |
+
kernel_kv=(1, 1, 1),
|
463 |
+
stride_q=(1, 1, 1),
|
464 |
+
stride_kv=(1, 1, 1),
|
465 |
+
mode="conv",
|
466 |
+
has_cls_embed=True,
|
467 |
+
pool_first=False,
|
468 |
+
rel_pos_spatial=False,
|
469 |
+
rel_pos_temporal=False,
|
470 |
+
rel_pos_zero_init=False,
|
471 |
+
residual_pooling=False,
|
472 |
+
dim_mul_in_att=False,
|
473 |
+
separate_qkv=False,
|
474 |
+
):
|
475 |
+
super().__init__()
|
476 |
+
self.dim = dim
|
477 |
+
self.dim_out = dim_out
|
478 |
+
self.norm1 = norm_layer(dim)
|
479 |
+
self.dim_mul_in_att = dim_mul_in_att
|
480 |
+
kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
|
481 |
+
stride_skip = stride_q
|
482 |
+
padding_skip = [int(skip // 2) for skip in kernel_skip]
|
483 |
+
att_dim = dim_out if dim_mul_in_att else dim
|
484 |
+
self.attn = MultiScaleAttention(
|
485 |
+
dim,
|
486 |
+
att_dim,
|
487 |
+
num_heads=num_heads,
|
488 |
+
input_size=input_size,
|
489 |
+
qkv_bias=qkv_bias,
|
490 |
+
drop_rate=drop_rate,
|
491 |
+
kernel_q=kernel_q,
|
492 |
+
kernel_kv=kernel_kv,
|
493 |
+
stride_q=stride_q,
|
494 |
+
stride_kv=stride_kv,
|
495 |
+
norm_layer=norm_layer,
|
496 |
+
has_cls_embed=has_cls_embed,
|
497 |
+
mode=mode,
|
498 |
+
pool_first=pool_first,
|
499 |
+
rel_pos_spatial=rel_pos_spatial,
|
500 |
+
rel_pos_temporal=rel_pos_temporal,
|
501 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
502 |
+
residual_pooling=residual_pooling,
|
503 |
+
separate_qkv=separate_qkv,
|
504 |
+
)
|
505 |
+
self.drop_path = (
|
506 |
+
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
507 |
+
)
|
508 |
+
self.norm2 = norm_layer(att_dim)
|
509 |
+
mlp_hidden_dim = int(att_dim * mlp_ratio)
|
510 |
+
self.has_cls_embed = has_cls_embed
|
511 |
+
# TODO: check the use case for up_rate, and merge the following lines
|
512 |
+
if up_rate is not None and up_rate > 1:
|
513 |
+
mlp_dim_out = dim * up_rate
|
514 |
+
else:
|
515 |
+
mlp_dim_out = dim_out
|
516 |
+
self.mlp = Mlp(
|
517 |
+
in_features=att_dim,
|
518 |
+
hidden_features=mlp_hidden_dim,
|
519 |
+
out_features=mlp_dim_out,
|
520 |
+
act_layer=act_layer,
|
521 |
+
drop_rate=drop_rate,
|
522 |
+
)
|
523 |
+
if layer_scale_init_value > 0:
|
524 |
+
self.gamma_1 = nn.Parameter(
|
525 |
+
layer_scale_init_value * torch.ones((dim)), requires_grad=True
|
526 |
+
)
|
527 |
+
self.gamma_2 = nn.Parameter(
|
528 |
+
layer_scale_init_value * torch.ones((dim_out)),
|
529 |
+
requires_grad=True,
|
530 |
+
)
|
531 |
+
else:
|
532 |
+
self.gamma_1, self.gamma_2 = None, None
|
533 |
+
|
534 |
+
if dim != dim_out:
|
535 |
+
self.proj = nn.Linear(dim, dim_out)
|
536 |
+
|
537 |
+
self.pool_skip = (
|
538 |
+
nn.MaxPool3d(
|
539 |
+
kernel_skip, stride_skip, padding_skip, ceil_mode=False
|
540 |
+
)
|
541 |
+
if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
|
542 |
+
else None
|
543 |
+
)
|
544 |
+
|
545 |
+
def forward(self, x, thw_shape=None):
|
546 |
+
x_norm = self.norm1(x)
|
547 |
+
x_block, thw_shape_new = self.attn(x_norm, thw_shape)
|
548 |
+
if self.dim_mul_in_att and self.dim != self.dim_out:
|
549 |
+
x = self.proj(x_norm)
|
550 |
+
x_res, _ = attention_pool(
|
551 |
+
x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed
|
552 |
+
)
|
553 |
+
if self.gamma_1 is not None:
|
554 |
+
x = x_res + self.drop_path(self.gamma_1 * x_block)
|
555 |
+
else:
|
556 |
+
x = x_res + self.drop_path(x_block)
|
557 |
+
x_norm = self.norm2(x)
|
558 |
+
x_mlp = self.mlp(x_norm)
|
559 |
+
if not self.dim_mul_in_att and self.dim != self.dim_out:
|
560 |
+
x = self.proj(x_norm)
|
561 |
+
if self.gamma_2 is not None:
|
562 |
+
x = x + self.drop_path(self.gamma_2 * x_mlp)
|
563 |
+
else:
|
564 |
+
x = x + self.drop_path(x_mlp)
|
565 |
+
if thw_shape:
|
566 |
+
return x, thw_shape_new
|
567 |
+
else:
|
568 |
+
return x
|
skp/models/rev_mvit/batchnorm_helper.py
ADDED
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""BatchNorm (BN) utility functions and custom batch-size BN implementations"""

from functools import partial
import torch
import torch.nn as nn

from pytorchvideo.layers.batch_norm import (
    NaiveSyncBatchNorm1d,
    NaiveSyncBatchNorm3d,
)  # noqa


def get_norm(cfg):
    """
    Args:
        cfg (CfgNode): model building configs, details are in the comments of
            the config file.
    Returns:
        nn.Module: the normalization layer.
    """
    if cfg.BN.NORM_TYPE in {"batchnorm", "sync_batchnorm_apex"}:
        return nn.BatchNorm3d
    elif cfg.BN.NORM_TYPE == "sub_batchnorm":
        return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
    elif cfg.BN.NORM_TYPE == "sync_batchnorm":
        return partial(
            NaiveSyncBatchNorm3d,
            num_sync_devices=cfg.BN.NUM_SYNC_DEVICES,
            global_sync=cfg.BN.GLOBAL_SYNC,
        )
    else:
        raise NotImplementedError(
            "Norm type {} is not supported".format(cfg.BN.NORM_TYPE)
        )


class SubBatchNorm3d(nn.Module):
    """
    The standard BN layer computes stats across all examples in a GPU. In some
    cases it is desirable to compute stats across only a subset of examples
    (e.g., in multigrid training https://arxiv.org/abs/1912.00998).
    SubBatchNorm3d splits the batch dimension into N splits, and run BN on
    each of them separately (so that the stats are computed on each subset of
    examples (1/N of batch) independently. During evaluation, it aggregates
    the stats from all splits into one BN.
    """

    def __init__(self, num_splits, **args):
        """
        Args:
            num_splits (int): number of splits.
            args (list): other arguments.
        """
        super(SubBatchNorm3d, self).__init__()
        self.num_splits = num_splits
        num_features = args["num_features"]
        # Keep only one set of weight and bias.
        if args.get("affine", True):
            self.affine = True
            args["affine"] = False
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))
        else:
            self.affine = False
        self.bn = nn.BatchNorm3d(**args)
        args["num_features"] = num_features * num_splits
        self.split_bn = nn.BatchNorm3d(**args)

    def _get_aggregated_mean_std(self, means, stds, n):
        """
        Calculate the aggregated mean and stds.
        Args:
            means (tensor): mean values.
            stds (tensor): standard deviations.
            n (int): number of sets of means and stds.
        """
        mean = means.view(n, -1).sum(0) / n
        std = (
            stds.view(n, -1).sum(0) / n
            + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n
        )
        return mean.detach(), std.detach()

    def aggregate_stats(self):
        """
        Synchronize running_mean, and running_var. Call this before eval.
        """
        if self.split_bn.track_running_stats:
            (
                self.bn.running_mean.data,
                self.bn.running_var.data,
            ) = self._get_aggregated_mean_std(
                self.split_bn.running_mean,
                self.split_bn.running_var,
                self.num_splits,
            )

    def forward(self, x):
        if self.training:
            n, c, t, h, w = x.shape
            x = x.view(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.view(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            x = x * self.weight.view((-1, 1, 1, 1))
            x = x + self.bias.view((-1, 1, 1, 1))
        return x
skp/models/rev_mvit/common.py
ADDED
@@ -0,0 +1,154 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import torch
import torch.nn as nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop_rate=0.0,
    ):
        super().__init__()
        self.drop_rate = drop_rate
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        if self.drop_rate > 0.0:
            self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        if self.drop_rate > 0.0:
            x = self.drop(x)
        x = self.fc2(x)
        if self.drop_rate > 0.0:
            x = self.drop(x)
        return x


class Permute(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Stochastic Depth per sample.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    output = x.div(keep_prob) * mask
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class TwoStreamFusion(nn.Module):
    def __init__(self, mode, dim=None, kernel=3, padding=1):
        """
        A general constructor for neural modules fusing two equal sized tensors
        in forward. Following options are supported:

        "add" / "max" / "min" / "avg" : respective operations on the two halves.
        "concat" : NOOP.
        "concat_linear_{dim_mult}_{drop_rate}" : MLP to fuse with hidden dim "dim_mult"
            (optional, def 1.) higher than input dim
            with optional dropout "drop_rate" (def: 0.)
        "ln+concat_linear_{dim_mult}_{drop_rate}" : perform MLP after layernorm on the input.

        """
        super().__init__()
        self.mode = mode
        if mode == "add":
            self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).sum(
                dim=0
            )
        elif mode == "max":
            self.fuse_fn = (
                lambda x: torch.stack(torch.chunk(x, 2, dim=2))
                .max(dim=0)
                .values
            )
        elif mode == "min":
            self.fuse_fn = (
                lambda x: torch.stack(torch.chunk(x, 2, dim=2))
                .min(dim=0)
                .values
            )
        elif mode == "avg":
            self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).mean(
                dim=0
            )
        elif mode == "concat":
            # x itself is the channel concat version
            self.fuse_fn = lambda x: x
        elif "concat_linear" in mode:
            if len(mode.split("_")) == 2:
                dim_mult = 1.0
                drop_rate = 0.0
            elif len(mode.split("_")) == 3:
                dim_mult = float(mode.split("_")[-1])
                drop_rate = 0.0

            elif len(mode.split("_")) == 4:
                dim_mult = float(mode.split("_")[-2])
                drop_rate = float(mode.split("_")[-1])
            else:
                raise NotImplementedError

            if mode.split("+")[0] == "ln":
                self.fuse_fn = nn.Sequential(
                    nn.LayerNorm(dim),
                    Mlp(
                        in_features=dim,
                        hidden_features=int(dim * dim_mult),
                        act_layer=nn.GELU,
                        out_features=dim,
                        drop_rate=drop_rate,
                    ),
                )
            else:
                self.fuse_fn = Mlp(
                    in_features=dim,
                    hidden_features=int(dim * dim_mult),
                    act_layer=nn.GELU,
                    out_features=dim,
                    drop_rate=drop_rate,
                )

        else:
            raise NotImplementedError

    def forward(self, x):
        if "concat_linear" in self.mode:
            return self.fuse_fn(x) + x

        else:
            return self.fuse_fn(x)
skp/models/rev_mvit/head_helper.py
ADDED
@@ -0,0 +1,140 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""ResNe(X)t Head helper."""

import torch
import torch.nn as nn

from .batchnorm_helper import (
    NaiveSyncBatchNorm1d as NaiveSyncBatchNorm1d,
)


class MLPHead(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        mlp_dim,
        num_layers,
        bn_on=False,
        bias=True,
        flatten=False,
        xavier_init=True,
        bn_sync_num=1,
        global_sync=False,
    ):
        super(MLPHead, self).__init__()
        self.flatten = flatten
        b = False if bn_on else bias
        # assert bn_on or bn_sync_num=1
        mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)]
        mlp_layers[-1].xavier_init = xavier_init
        for i in range(1, num_layers):
            if bn_on:
                if global_sync or bn_sync_num > 1:
                    mlp_layers.append(
                        NaiveSyncBatchNorm1d(
                            num_sync_devices=bn_sync_num,
                            global_sync=global_sync,
                            num_features=mlp_dim,
                        )
                    )
                else:
                    mlp_layers.append(nn.BatchNorm1d(num_features=mlp_dim))
            mlp_layers.append(nn.ReLU(inplace=True))
            if i == num_layers - 1:
                d = dim_out
                b = bias
            else:
                d = mlp_dim
            mlp_layers.append(nn.Linear(mlp_dim, d, bias=b))
            mlp_layers[-1].xavier_init = xavier_init
        self.projection = nn.Sequential(*mlp_layers)

    def forward(self, x):
        if x.ndim == 5:
            x = x.permute((0, 2, 3, 4, 1))
        if self.flatten:
            x = x.reshape(-1, x.shape[-1])

        return self.projection(x)


class TransformerBasicHead(nn.Module):
    """
    BasicHead. No pool.
    """

    def __init__(
        self,
        dim_in,
        num_classes,
        dropout_rate=0.0,
        act_func="softmax",
        cfg=None,
    ):
        """
        Perform linear projection and activation as head for tranformers.
        Args:
            dim_in (int): the channel dimension of the input to the head.
            num_classes (int): the channel dimensions of the output to the head.
            dropout_rate (float): dropout rate. If equal to 0.0, perform no
                dropout.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the output.
        """
        super(TransformerBasicHead, self).__init__()
        if dropout_rate > 0.0:
            self.dropout = nn.Dropout(dropout_rate)
        self.projection = nn.Linear(dim_in, num_classes, bias=True)

        if cfg.CONTRASTIVE.NUM_MLP_LAYERS == 1:
            self.projection = nn.Linear(dim_in, num_classes, bias=True)
        else:
            self.projection = MLPHead(
                dim_in,
                num_classes,
                cfg.CONTRASTIVE.MLP_DIM,
                cfg.CONTRASTIVE.NUM_MLP_LAYERS,
                bn_on=cfg.CONTRASTIVE.BN_MLP,
                bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                if cfg.CONTRASTIVE.BN_SYNC_MLP
                else 1,
                global_sync=(
                    cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC
                ),
            )
        self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC

        # Softmax for evaluation and testing.
        if act_func == "softmax":
            self.act = nn.Softmax(dim=1)
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        elif act_func == "none":
            self.act = None
        else:
            raise NotImplementedError(
                "{} is not supported as an activation"
                "function.".format(act_func)
            )

    def forward(self, x):
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        if self.detach_final_fc:
            x = x.detach()
        x = self.projection(x)

        if not self.training:
            if self.act is not None:
                x = self.act(x)
            # Performs fully convolutional inference.
            if x.ndim == 5 and x.shape[1:4] > torch.Size([1, 1, 1]):
                x = x.mean([1, 2, 3])

            x = x.view(x.shape[0], -1)

        return x
skp/models/rev_mvit/reversible_mvit.py
ADDED
@@ -0,0 +1,696 @@
1 |
+
import sys
|
2 |
+
from functools import partial
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.autograd import Function as Function
|
6 |
+
|
7 |
+
from .attention import MultiScaleAttention, attention_pool
|
8 |
+
from .common import Mlp, TwoStreamFusion, drop_path
|
9 |
+
from .utils import round_width
|
10 |
+
|
11 |
+
|
12 |
+
class ReversibleMViT(nn.Module):
|
13 |
+
"""
|
14 |
+
Reversible model builder. This builds the reversible transformer encoder
|
15 |
+
and allows reversible training.
|
16 |
+
|
17 |
+
Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong,
|
18 |
+
Christoph Feichtenhofer, Jitendra Malik
|
19 |
+
"Reversible Vision Transformers"
|
20 |
+
|
21 |
+
https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, config, model):
|
25 |
+
"""
|
26 |
+
The `__init__` method of any subclass should also contain these
|
27 |
+
arguments.
|
28 |
+
Args:
|
29 |
+
cfg (CfgNode): model building configs, details are in the
|
30 |
+
comments of the config file.
|
31 |
+
model (nn.Module): parent MViT module this module forms
|
32 |
+
a reversible encoder in.
|
33 |
+
"""
|
34 |
+
|
35 |
+
super().__init__()
|
36 |
+
self.cfg = config
|
37 |
+
|
38 |
+
embed_dim = self.cfg.MVIT.EMBED_DIM
|
39 |
+
depth = self.cfg.MVIT.DEPTH
|
40 |
+
num_heads = self.cfg.MVIT.NUM_HEADS
|
41 |
+
mlp_ratio = self.cfg.MVIT.MLP_RATIO
|
42 |
+
qkv_bias = self.cfg.MVIT.QKV_BIAS
|
43 |
+
|
44 |
+
drop_path_rate = self.cfg.MVIT.DROPPATH_RATE
|
45 |
+
self.dropout = config.MVIT.DROPOUT_RATE
|
46 |
+
self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION
|
47 |
+
dpr = [
|
48 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
49 |
+
] # stochastic depth decay rule
|
50 |
+
|
51 |
+
input_size = model.patch_dims
|
52 |
+
|
53 |
+
self.layers = nn.ModuleList([])
|
54 |
+
self.no_custom_backward = False
|
55 |
+
|
56 |
+
if self.cfg.MVIT.NORM == "layernorm":
|
57 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
58 |
+
else:
|
59 |
+
raise NotImplementedError("Only supports layernorm.")
|
60 |
+
|
61 |
+
dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
|
62 |
+
for i in range(len(self.cfg.MVIT.DIM_MUL)):
|
63 |
+
dim_mul[self.cfg.MVIT.DIM_MUL[i][0]] = self.cfg.MVIT.DIM_MUL[i][1]
|
64 |
+
for i in range(len(self.cfg.MVIT.HEAD_MUL)):
|
65 |
+
head_mul[self.cfg.MVIT.HEAD_MUL[i][0]] = self.cfg.MVIT.HEAD_MUL[i][
|
66 |
+
1
|
67 |
+
]
|
68 |
+
|
69 |
+
pool_q = model.pool_q
|
70 |
+
pool_kv = model.pool_kv
|
71 |
+
stride_q = model.stride_q
|
72 |
+
stride_kv = model.stride_kv
|
73 |
+
|
74 |
+
for i in range(depth):
|
75 |
+
|
76 |
+
num_heads = round_width(num_heads, head_mul[i])
|
77 |
+
|
78 |
+
# Upsampling inside the MHPA, input to the Q-pooling block is lower C dimension
|
79 |
+
# This localizes the feature changes in a single block, making more computation reversible.
|
80 |
+
embed_dim = round_width(
|
81 |
+
embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads
|
82 |
+
)
|
83 |
+
dim_out = round_width(
|
84 |
+
embed_dim,
|
85 |
+
dim_mul[i],
|
86 |
+
divisor=round_width(num_heads, head_mul[i + 1]),
|
87 |
+
)
|
88 |
+
|
89 |
+
if i in self.cfg.MVIT.REV.BUFFER_LAYERS:
|
90 |
+
layer_type = StageTransitionBlock
|
91 |
+
input_mult = 2 if "concat" in self.pre_q_fusion else 1
|
92 |
+
else:
|
93 |
+
layer_type = ReversibleBlock
|
94 |
+
input_mult = 1
|
95 |
+
|
96 |
+
dimout_correction = (
|
97 |
+
2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1
|
98 |
+
)
|
99 |
+
|
100 |
+
self.layers.append(
|
101 |
+
layer_type(
|
102 |
+
dim=embed_dim
|
103 |
+
* input_mult, # added only for concat fusion before Qpooling layers
|
104 |
+
input_size=input_size,
|
105 |
+
dim_out=dim_out * input_mult // dimout_correction,
|
106 |
+
num_heads=num_heads,
|
107 |
+
cfg=self.cfg,
|
108 |
+
mlp_ratio=mlp_ratio,
|
109 |
+
qkv_bias=qkv_bias,
|
110 |
+
drop_path=dpr[i],
|
111 |
+
norm_layer=norm_layer,
|
112 |
+
kernel_q=pool_q[i] if len(pool_q) > i else [],
|
113 |
+
kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
|
114 |
+
stride_q=stride_q[i] if len(stride_q) > i else [],
|
115 |
+
stride_kv=stride_kv[i] if len(stride_kv) > i else [],
|
116 |
+
layer_id=i,
|
117 |
+
pre_q_fusion=self.pre_q_fusion,
|
118 |
+
)
|
119 |
+
)
|
120 |
+
# F is the attention block
|
121 |
+
self.layers[-1].F.thw = input_size
|
122 |
+
|
123 |
+
if len(stride_q[i]) > 0:
|
124 |
+
input_size = [
|
125 |
+
size // stride
|
126 |
+
for size, stride in zip(input_size, stride_q[i])
|
127 |
+
]
|
128 |
+
|
129 |
+
embed_dim = dim_out
|
130 |
+
|
131 |
+
@staticmethod
|
132 |
+
def vanilla_backward(h, layers, buffer):
|
133 |
+
"""
|
134 |
+
Using rev layers without rev backpropagation. Debugging purposes only.
|
135 |
+
Activated with self.no_custom_backward.
|
136 |
+
"""
|
137 |
+
|
138 |
+
# split into hidden states (h) and attention_output (a)
|
139 |
+
h, a = torch.chunk(h, 2, dim=-1)
|
140 |
+
for _, layer in enumerate(layers):
|
141 |
+
a, h = layer(a, h)
|
142 |
+
|
143 |
+
return torch.cat([a, h], dim=-1)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
|
147 |
+
# process the layers in a reversible stack and an irreversible stack.
|
148 |
+
stack = []
|
149 |
+
for l_i in range(len(self.layers)):
|
150 |
+
if isinstance(self.layers[l_i], StageTransitionBlock):
|
151 |
+
stack.append(("StageTransition", l_i))
|
152 |
+
else:
|
153 |
+
if len(stack) == 0 or stack[-1][0] == "StageTransition":
|
154 |
+
stack.append(("Reversible", []))
|
155 |
+
stack[-1][1].append(l_i)
|
156 |
+
|
157 |
+
for layer_seq in stack:
|
158 |
+
|
159 |
+
if layer_seq[0] == "StageTransition":
|
160 |
+
x = self.layers[layer_seq[1]](x)
|
161 |
+
|
162 |
+
else:
|
163 |
+
x = torch.cat([x, x], dim=-1)
|
164 |
+
|
165 |
+
# no need for custom backprop in eval/model stat log
|
166 |
+
if not self.training or self.no_custom_backward:
|
167 |
+
executing_fn = ReversibleMViT.vanilla_backward
|
168 |
+
else:
|
169 |
+
executing_fn = RevBackProp.apply
|
170 |
+
|
171 |
+
x = executing_fn(
|
172 |
+
x,
|
173 |
+
self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1],
|
174 |
+
[], # buffer activations
|
175 |
+
)
|
176 |
+
|
177 |
+
# Apply dropout
|
178 |
+
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class RevBackProp(Function):
|
184 |
+
"""
|
185 |
+
Custom Backpropagation function to allow (A) flusing memory in foward
|
186 |
+
and (B) activation recomputation reversibly in backward for gradient calculation.
|
187 |
+
|
188 |
+
Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
|
189 |
+
"""
|
190 |
+
|
191 |
+
@staticmethod
|
192 |
+
def forward(
|
193 |
+
ctx,
|
194 |
+
x,
|
195 |
+
layers,
|
196 |
+
buffer_layers, # List of layer ids for int activation to buffer
|
197 |
+
):
|
198 |
+
"""
|
199 |
+
Reversible Forward pass. Any intermediate activations from `buffer_layers` are
|
200 |
+
cached in ctx for forward pass. This is not necessary for standard usecases.
|
201 |
+
Each reversible layer implements its own forward pass logic.
|
202 |
+
"""
|
203 |
+
buffer_layers.sort()
|
204 |
+
|
205 |
+
X_1, X_2 = torch.chunk(x, 2, dim=-1)
|
206 |
+
|
207 |
+
intermediate = []
|
208 |
+
|
209 |
+
for layer in layers:
|
210 |
+
|
211 |
+
X_1, X_2 = layer(X_1, X_2)
|
212 |
+
|
213 |
+
if layer.layer_id in buffer_layers:
|
214 |
+
intermediate.extend([X_1.detach(), X_2.detach()])
|
215 |
+
|
216 |
+
if len(buffer_layers) == 0:
|
217 |
+
all_tensors = [X_1.detach(), X_2.detach()]
|
218 |
+
else:
|
219 |
+
intermediate = [torch.LongTensor(buffer_layers), *intermediate]
|
220 |
+
all_tensors = [X_1.detach(), X_2.detach(), *intermediate]
|
221 |
+
|
222 |
+
ctx.save_for_backward(*all_tensors)
|
223 |
+
ctx.layers = layers
|
224 |
+
|
225 |
+
return torch.cat([X_1, X_2], dim=-1)
|
226 |
+
|
227 |
+
@staticmethod
|
228 |
+
def backward(ctx, dx):
|
229 |
+
"""
|
230 |
+
Reversible Backward pass. Any intermediate activations from `buffer_layers` are
|
231 |
+
recovered from ctx. Each layer implements its own loic for backward pass (both
|
232 |
+
activation recomputation and grad calculation).
|
233 |
+
"""
|
234 |
+
dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)
|
235 |
+
|
236 |
+
# retrieve params from ctx for backward
|
237 |
+
X_1, X_2, *int_tensors = ctx.saved_tensors
|
238 |
+
|
239 |
+
# no buffering
|
240 |
+
if len(int_tensors) != 0:
|
241 |
+
buffer_layers = int_tensors[0].tolist()
|
242 |
+
|
243 |
+
else:
|
244 |
+
buffer_layers = []
|
245 |
+
|
246 |
+
layers = ctx.layers
|
247 |
+
|
248 |
+
for _, layer in enumerate(layers[::-1]):
|
249 |
+
|
250 |
+
if layer.layer_id in buffer_layers:
|
251 |
+
|
252 |
+
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
|
253 |
+
Y_1=int_tensors[
|
254 |
+
buffer_layers.index(layer.layer_id) * 2 + 1
|
255 |
+
],
|
256 |
+
Y_2=int_tensors[
|
257 |
+
buffer_layers.index(layer.layer_id) * 2 + 2
|
258 |
+
],
|
259 |
+
dY_1=dX_1,
|
260 |
+
dY_2=dX_2,
|
261 |
+
)
|
262 |
+
|
263 |
+
else:
|
264 |
+
|
265 |
+
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
|
266 |
+
Y_1=X_1,
|
267 |
+
Y_2=X_2,
|
268 |
+
dY_1=dX_1,
|
269 |
+
dY_2=dX_2,
|
270 |
+
)
|
271 |
+
|
272 |
+
dx = torch.cat([dX_1, dX_2], dim=-1)
|
273 |
+
|
274 |
+
del int_tensors
|
275 |
+
del dX_1, dX_2, X_1, X_2
|
276 |
+
|
277 |
+
return dx, None, None
|
278 |
+
|
279 |
+
|
280 |
+
class StageTransitionBlock(nn.Module):
|
281 |
+
"""
|
282 |
+
Blocks for changing the feature dimensions in MViT (using Q-pooling).
|
283 |
+
See Section 3.3.1 in paper for details.
|
284 |
+
"""
|
285 |
+
|
286 |
+
def __init__(
|
287 |
+
self,
|
288 |
+
dim,
|
289 |
+
input_size,
|
290 |
+
dim_out,
|
291 |
+
num_heads,
|
292 |
+
mlp_ratio,
|
293 |
+
qkv_bias,
|
294 |
+
drop_path,
|
295 |
+
kernel_q,
|
296 |
+
kernel_kv,
|
297 |
+
stride_q,
|
298 |
+
stride_kv,
|
299 |
+
cfg,
|
300 |
+
norm_layer=nn.LayerNorm,
|
301 |
+
pre_q_fusion=None,
|
302 |
+
layer_id=0,
|
303 |
+
):
|
304 |
+
"""
|
305 |
+
Uses the same structure of F and G functions as Reversible Block except
|
306 |
+
without using reversible forward (and backward) pass.
|
307 |
+
"""
|
308 |
+
super().__init__()
|
309 |
+
|
310 |
+
self.drop_path_rate = drop_path
|
311 |
+
|
312 |
+
embed_dim = dim
|
313 |
+
|
314 |
+
self.F = AttentionSubBlock(
|
315 |
+
dim=embed_dim,
|
316 |
+
input_size=input_size,
|
317 |
+
num_heads=num_heads,
|
318 |
+
cfg=cfg,
|
319 |
+
dim_out=dim_out,
|
320 |
+
kernel_q=kernel_q,
|
321 |
+
kernel_kv=kernel_kv,
|
322 |
+
stride_q=stride_q,
|
323 |
+
stride_kv=stride_kv,
|
324 |
+
norm_layer=norm_layer,
|
325 |
+
)
|
326 |
+
|
327 |
+
self.G = MLPSubblock(
|
328 |
+
dim=dim_out,
|
329 |
+
mlp_ratio=mlp_ratio,
|
330 |
+
norm_layer=norm_layer,
|
331 |
+
)
|
332 |
+
|
333 |
+
self.layer_id = layer_id
|
334 |
+
|
335 |
+
self.is_proj = False
|
336 |
+
self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON
|
337 |
+
|
338 |
+
self.is_conv = False
|
339 |
+
self.pool_first = cfg.MVIT.POOL_FIRST
|
340 |
+
self.mode = cfg.MVIT.MODE
|
341 |
+
self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim)
|
342 |
+
|
343 |
+
if cfg.MVIT.REV.RES_PATH == "max":
|
344 |
+
self.res_conv = False
|
345 |
+
self.pool_skip = nn.MaxPool3d(
|
346 |
+
# self.attention.attn.pool_q.kernel_size,
|
347 |
+
[s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride],
|
348 |
+
self.F.attn.pool_q.stride,
|
349 |
+
[int(k // 2) for k in self.F.attn.pool_q.stride],
|
350 |
+
# self.attention.attn.pool_q.padding,
|
351 |
+
ceil_mode=False,
|
352 |
+
)
|
353 |
+
|
354 |
+
elif cfg.MVIT.REV.RES_PATH == "conv":
|
355 |
+
self.res_conv = True
|
356 |
+
else:
|
357 |
+
raise NotImplementedError
|
358 |
+
|
359 |
+
# Add a linear projection in residual branch
|
360 |
+
if embed_dim != dim_out:
|
361 |
+
self.is_proj = True
|
362 |
+
self.res_proj = nn.Linear(embed_dim, dim_out, bias=True)
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
x,
|
367 |
+
):
|
368 |
+
"""
|
369 |
+
Forward logic is similar to MultiScaleBlock with Q-pooling.
|
370 |
+
"""
|
371 |
+
x = self.pre_q_fuse(x)
|
372 |
+
|
373 |
+
# fork tensor for residual connections
|
374 |
+
x_res = x
|
375 |
+
|
376 |
+
# This uses conv to pool the residual hidden features
|
377 |
+
# but done before pooling only if not pool_first
|
378 |
+
if self.is_proj and not self.pool_first:
|
379 |
+
x_res = self.res_proj(x_res)
|
380 |
+
|
381 |
+
if self.res_conv:
|
382 |
+
|
383 |
+
# Pooling the hidden features with the same conv as Q
|
384 |
+
N, L, C = x_res.shape
|
385 |
+
|
386 |
+
# This handling is the same as that of q in MultiScaleAttention
|
387 |
+
if self.mode == "conv_unshared":
|
388 |
+
fold_dim = 1
|
389 |
+
else:
|
390 |
+
fold_dim = self.F.attn.num_heads
|
391 |
+
|
392 |
+
# Output is (B, N, L, C)
|
393 |
+
x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute(
|
394 |
+
0, 2, 1, 3
|
395 |
+
)
|
396 |
+
|
397 |
+
x_res, _ = attention_pool(
|
398 |
+
x_res,
|
399 |
+
self.F.attn.pool_q,
|
400 |
+
# thw_shape = self.attention.attn.thw,
|
401 |
+
thw_shape=self.F.thw,
|
402 |
+
has_cls_embed=self.has_cls_embed,
|
403 |
+
norm=self.F.attn.norm_q
|
404 |
+
if hasattr(self.F.attn, "norm_q")
|
405 |
+
else None,
|
406 |
+
)
|
407 |
+
x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C)
|
408 |
+
|
409 |
+
else:
|
410 |
+
# Pooling the hidden features with max op
|
411 |
+
x_res, _ = attention_pool(
|
412 |
+
x_res,
|
413 |
+
self.pool_skip,
|
414 |
+
thw_shape=self.F.attn.thw,
|
415 |
+
has_cls_embed=self.has_cls_embed,
|
416 |
+
)
|
417 |
+
|
418 |
+
# If pool_first then project to higher dim now
|
419 |
+
if self.is_proj and self.pool_first:
|
420 |
+
x_res = self.res_proj(x_res)
|
421 |
+
|
422 |
+
x = self.F(x)
|
423 |
+
x = x_res + x
|
424 |
+
x = x + self.G(x)
|
425 |
+
|
426 |
+
x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training)
|
427 |
+
|
428 |
+
return x
|
429 |
+
|
430 |
+
|
431 |
+
class ReversibleBlock(nn.Module):
|
432 |
+
"""
|
433 |
+
Reversible Blocks for Reversible Vision Transformer and also
|
434 |
+
for state-preserving blocks in Reversible MViT. See Section
|
435 |
+
3.3.2 in paper for details.
|
436 |
+
"""
|
437 |
+
|
438 |
+
def __init__(
|
439 |
+
self,
|
440 |
+
dim,
|
441 |
+
input_size,
|
442 |
+
dim_out,
|
443 |
+
num_heads,
|
444 |
+
mlp_ratio,
|
445 |
+
qkv_bias,
|
446 |
+
drop_path,
|
447 |
+
kernel_q,
|
448 |
+
kernel_kv,
|
449 |
+
stride_q,
|
450 |
+
stride_kv,
|
451 |
+
cfg,
|
452 |
+
norm_layer=nn.LayerNorm,
|
453 |
+
layer_id=0,
|
454 |
+
**kwargs
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Block is composed entirely of function F (Attention
|
458 |
+
sub-block) and G (MLP sub-block) including layernorm.
|
459 |
+
"""
|
460 |
+
super().__init__()
|
461 |
+
|
462 |
+
self.drop_path_rate = drop_path
|
463 |
+
|
464 |
+
self.F = AttentionSubBlock(
|
465 |
+
dim=dim,
|
466 |
+
input_size=input_size,
|
467 |
+
num_heads=num_heads,
|
468 |
+
cfg=cfg,
|
469 |
+
dim_out=dim_out,
|
470 |
+
kernel_q=kernel_q,
|
471 |
+
kernel_kv=kernel_kv,
|
472 |
+
stride_q=stride_q,
|
473 |
+
stride_kv=stride_kv,
|
474 |
+
norm_layer=norm_layer,
|
475 |
+
)
|
476 |
+
|
477 |
+
self.G = MLPSubblock(
|
478 |
+
            dim=dim,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        self.seeds = {}

    def seed_cuda(self, key):
        """
        Fix seeds to allow for stochastic elements such as
        dropout to be reproduced exactly in activation
        recomputation in the backward pass.
        """

        # randomize seeds
        # use cuda generator if available
        if (
            hasattr(torch.cuda, "default_generators")
            and len(torch.cuda.default_generators) > 0
        ):
            # GPU
            device_idx = torch.cuda.current_device()
            seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            seed = int(torch.seed() % sys.maxsize)

        self.seeds[key] = seed
        torch.manual_seed(self.seeds[key])

    def forward(self, X_1, X_2):
        """
        forward pass equations:
        Y_1 = X_1 + Attention(X_2), F = Attention
        Y_2 = X_2 + MLP(Y_1), G = MLP
        """

        self.seed_cuda("attn")
        # Y_1 : attn_output
        f_X_2 = self.F(X_2)

        self.seed_cuda("droppath")
        f_X_2_dropped = drop_path(
            f_X_2, drop_prob=self.drop_path_rate, training=self.training
        )

        # Y_1 = X_1 + f(X_2)
        Y_1 = X_1 + f_X_2_dropped

        # free memory
        del X_1

        self.seed_cuda("FFN")
        g_Y_1 = self.G(Y_1)

        torch.manual_seed(self.seeds["droppath"])
        g_Y_1_dropped = drop_path(
            g_Y_1, drop_prob=self.drop_path_rate, training=self.training
        )

        # Y_2 = X_2 + g(Y_1)
        Y_2 = X_2 + g_Y_1_dropped

        del X_2

        return Y_1, Y_2

    def backward_pass(
        self,
        Y_1,
        Y_2,
        dY_1,
        dY_2,
    ):
        """
        equations for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention
        """

        # temporarily record intermediate activations for G
        # and use them for gradient calculation of G
        with torch.enable_grad():

            Y_1.requires_grad = True

            torch.manual_seed(self.seeds["FFN"])
            g_Y_1 = self.G(Y_1)

            torch.manual_seed(self.seeds["droppath"])
            g_Y_1 = drop_path(
                g_Y_1, drop_prob=self.drop_path_rate, training=self.training
            )

            g_Y_1.backward(dY_2, retain_graph=True)

        # activation recomputation is by design and not part of
        # the computation graph in the forward pass.
        with torch.no_grad():

            X_2 = Y_2 - g_Y_1
            del g_Y_1

            dY_1 = dY_1 + Y_1.grad
            Y_1.grad = None

        # record F activations and calc gradients on F
        with torch.enable_grad():
            X_2.requires_grad = True

            torch.manual_seed(self.seeds["attn"])
            f_X_2 = self.F(X_2)

            torch.manual_seed(self.seeds["droppath"])
            f_X_2 = drop_path(
                f_X_2, drop_prob=self.drop_path_rate, training=self.training
            )

            f_X_2.backward(dY_1, retain_graph=True)

        # propagate the reverse-computed activations at the start of
        # the previous block for backprop
        with torch.no_grad():

            X_1 = Y_1 - f_X_2

            del f_X_2, Y_1
            dY_2 = dY_2 + X_2.grad

            X_2.grad = None
            X_2 = X_2.detach()

        return X_1, X_2, dY_1, dY_2

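
# Illustrative sketch (not part of the original file): with drop_path disabled,
# the forward equations Y_1 = X_1 + F(X_2) and Y_2 = X_2 + G(Y_1) are exactly
# invertible, which is what backward_pass exploits to recompute activations
# instead of caching them. Any deterministic callables can stand in for F and G.
def _reversible_identity_sketch(dim=8):
    """Check that X_1, X_2 are recoverable from Y_1, Y_2 (illustration only)."""
    F_fn, G_fn = nn.Linear(dim, dim), nn.Linear(dim, dim)
    X_1, X_2 = torch.randn(2, dim), torch.randn(2, dim)
    Y_1 = X_1 + F_fn(X_2)
    Y_2 = X_2 + G_fn(Y_1)
    X_2_rec = Y_2 - G_fn(Y_1)      # recompute X_2, as in backward_pass
    X_1_rec = Y_1 - F_fn(X_2_rec)  # then recompute X_1
    assert torch.allclose(X_2_rec, X_2, atol=1e-5)
    assert torch.allclose(X_1_rec, X_1, atol=1e-5)
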
class MLPSubblock(nn.Module):
    """
    This creates the function G such that the entire block can be
    expressed as F(G(X)). Includes pre-LayerNorm.
    """

    def __init__(
        self,
        dim,
        mlp_ratio,
        norm_layer=nn.LayerNorm,
    ):

        super().__init__()
        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=nn.GELU,
        )

    def forward(self, x):
        return self.mlp(self.norm(x))

class AttentionSubBlock(nn.Module):
    """
    This creates the function F such that the entire block can be
    expressed as F(G(X)). Includes pre-LayerNorm.
    """

    def __init__(
        self,
        dim,
        input_size,
        num_heads,
        cfg,
        dim_out=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
    ):

        super().__init__()
        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)

        # This will be set externally during init
        self.thw = None

        # The actual attention details are the same as multiscale attention
        # for MViTv2 (with channel up-projection inside the block);
        # a no-up-projection attention could also be implemented for vanilla ViT.
        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            input_size=input_size,
            num_heads=num_heads,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            drop_rate=cfg.MVIT.DROPOUT_RATE,
            qkv_bias=cfg.MVIT.QKV_BIAS,
            has_cls_embed=cfg.MVIT.CLS_EMBED_ON,
            mode=cfg.MVIT.MODE,
            pool_first=cfg.MVIT.POOL_FIRST,
            rel_pos_spatial=cfg.MVIT.REL_POS_SPATIAL,
            rel_pos_temporal=cfg.MVIT.REL_POS_TEMPORAL,
            rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
            residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
            separate_qkv=cfg.MVIT.SEPARATE_QKV,
        )

    def forward(self, x):
        out, _ = self.attn(self.norm(x), self.thw)
        return out
skp/models/rev_mvit/stem_helper.py
ADDED
@@ -0,0 +1,325 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""ResNe(X)t 3D stem helper."""

import torch
import torch.nn as nn


def get_stem_func(name):
    """
    Retrieves the stem module by name.
    """
    trans_funcs = {"x3d_stem": X3DStem, "basic_stem": ResNetBasicStem}
    assert (
        name in trans_funcs.keys()
    ), "Transformation function '{}' not supported".format(name)
    return trans_funcs[name]


class VideoModelStem(nn.Module):
    """
    Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool
    on input data tensor for one or multiple pathways.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
        stem_func_name="basic_stem",
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments. Use list size of 1 for single-pathway models (C2D, I3D,
        Slow, etc.) and list size of 2 for two-pathway models (SlowFast).

        Args:
            dim_in (list): the list of channel dimensions of the inputs.
            dim_out (list): the output dimension of the convolution in the stem
                layer.
            kernel (list): the kernel sizes of the convolutions in the stem
                layers. Temporal kernel size, height kernel size, width kernel
                size in order.
            stride (list): the stride sizes of the convolutions in the stem
                layer. Temporal stride, height stride, width stride in order.
            padding (list): the padding sizes of the convolutions in the stem
                layer. Temporal padding size, height padding size, width
                padding size in order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            stem_func_name (string): name of the stem function applied on
                input to the network.
        """
        super(VideoModelStem, self).__init__()

        assert (
            len(
                {
                    len(dim_in),
                    len(dim_out),
                    len(kernel),
                    len(stride),
                    len(padding),
                }
            )
            == 1
        ), "Input pathway dimensions are not consistent. {} {} {} {} {}".format(
            len(dim_in),
            len(dim_out),
            len(kernel),
            len(stride),
            len(padding),
        )

        self.num_pathways = len(dim_in)
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module, stem_func_name)

    def _construct_stem(self, dim_in, dim_out, norm_module, stem_func_name):
        trans_func = get_stem_func(stem_func_name)

        for pathway in range(len(dim_in)):
            stem = trans_func(
                dim_in[pathway],
                dim_out[pathway],
                self.kernel[pathway],
                self.stride[pathway],
                self.padding[pathway],
                self.inplace_relu,
                self.eps,
                self.bn_mmt,
                norm_module,
            )
            self.add_module("pathway{}_stem".format(pathway), stem)

    def forward(self, x):
        assert (
            len(x) == self.num_pathways
        ), "Input tensor does not contain {} pathway".format(self.num_pathways)
        # Build a new list; do not modify the x list in-place, which is bad
        # for activation checkpointing.
        y = []
        for pathway in range(len(x)):
            m = getattr(self, "pathway{}_stem".format(pathway))
            y.append(m(x[pathway]))
        return y

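
# Illustrative sketch (not part of the original file): a single-pathway stem
# built through VideoModelStem with made-up sizes. The stride-2 conv halves H
# and W, and the 1x3x3 max pool halves them again, so an 8x64x64 clip comes
# out as 8x16x16 feature maps with 64 channels.
def _video_model_stem_sketch():
    stem = VideoModelStem(
        dim_in=[3],
        dim_out=[64],
        kernel=[[1, 7, 7]],
        stride=[[1, 2, 2]],
        padding=[[0, 3, 3]],
    )
    x = [torch.randn(1, 3, 8, 64, 64)]  # one pathway: (B, C, T, H, W)
    y = stem(x)
    assert y[0].shape == (1, 64, 8, 16, 16)
    return y
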
class ResNetBasicStem(nn.Module):
    """
    ResNe(X)t 3D stem module.
    Performs spatiotemporal convolution, BN, and ReLU, followed by a
    spatiotemporal pooling.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            dim_in (int): the channel dimension of the input. Normally 3 is used
                for rgb input, and 2 or 3 is used for optical flow input.
            dim_out (int): the output dimension of the convolution in the stem
                layer.
            kernel (list): the kernel size of the convolution in the stem layer.
                Temporal kernel size, height kernel size, width kernel size in
                order.
            stride (list): the stride size of the convolution in the stem layer.
                Temporal stride, height stride, width stride in order.
            padding (list): the padding size of the convolution in the stem
                layer. Temporal padding size, height padding size, width
                padding size in order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
        """
        super(ResNetBasicStem, self).__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        self.conv = nn.Conv3d(
            dim_in,
            dim_out,
            self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )
        self.bn = norm_module(
            num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
        )
        self.relu = nn.ReLU(self.inplace_relu)
        self.pool_layer = nn.MaxPool3d(
            kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool_layer(x)
        return x

class X3DStem(nn.Module):
    """
    X3D's 3D stem module.
    Performs a spatial convolution followed by a depthwise temporal
    convolution, BN, and ReLU, followed by a spatiotemporal pooling.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            dim_in (int): the channel dimension of the input. Normally 3 is used
                for rgb input, and 2 or 3 is used for optical flow input.
            dim_out (int): the output dimension of the convolution in the stem
                layer.
            kernel (list): the kernel size of the convolution in the stem layer.
                Temporal kernel size, height kernel size, width kernel size in
                order.
            stride (list): the stride size of the convolution in the stem layer.
                Temporal stride, height stride, width stride in order.
            padding (list): the padding size of the convolution in the stem
                layer. Temporal padding size, height padding size, width
                padding size in order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
        """
        super(X3DStem, self).__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        self.conv_xy = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=(1, self.kernel[1], self.kernel[2]),
            stride=(1, self.stride[1], self.stride[2]),
            padding=(0, self.padding[1], self.padding[2]),
            bias=False,
        )
        self.conv = nn.Conv3d(
            dim_out,
            dim_out,
            kernel_size=(self.kernel[0], 1, 1),
            stride=(self.stride[0], 1, 1),
            padding=(self.padding[0], 0, 0),
            bias=False,
            groups=dim_out,
        )

        self.bn = norm_module(
            num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
        )
        self.relu = nn.ReLU(self.inplace_relu)

    def forward(self, x):
        x = self.conv_xy(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

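
# Illustrative note (not part of the original file): X3DStem factorizes the
# stem into a 1xKhxKw spatial conv followed by a depthwise Ktx1x1 temporal
# conv, which needs far fewer weights than a single full 3D conv. The sizes
# below are made up for the comparison.
def _x3d_stem_param_sketch():
    stem = X3DStem(
        dim_in=3, dim_out=24, kernel=[5, 3, 3], stride=[1, 2, 2], padding=[2, 1, 1]
    )
    factorized = stem.conv_xy.weight.numel() + stem.conv.weight.numel()  # 648 + 120
    full = 3 * 24 * 5 * 3 * 3  # a single dense 5x3x3 conv would need 3240 weights
    return factorized, full
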
class PatchEmbed(nn.Module):
    """
    PatchEmbed.
    """

    def __init__(
        self,
        dim_in=3,
        dim_out=768,
        kernel=(1, 16, 16),
        stride=(1, 4, 4),
        padding=(1, 7, 7),
        conv_2d=False,
    ):
        super().__init__()
        if conv_2d:
            conv = nn.Conv2d
        else:
            conv = nn.Conv3d
        self.proj = conv(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, keep_spatial=False):
        x = self.proj(x)
        if keep_spatial:
            return x, x.shape
        # B C (T) H W -> B (T)HW C
        return x.flatten(2).transpose(1, 2), x.shape
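
# Illustrative sketch (not part of the original file): PatchEmbed projects a
# (B, C, T, H, W) clip into a token sequence of length T' * H' * W' with
# dim_out channels. The kernel/stride/padding below are made up for the
# example; the sizes follow ordinary Conv3d output-size arithmetic.
def _patch_embed_sketch():
    patch = PatchEmbed(
        dim_in=3, dim_out=96, kernel=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3)
    )
    x = torch.randn(1, 3, 8, 224, 224)
    tokens, shape = patch(x)
    # shape is (1, 96, 4, 56, 56); tokens is (1, 4 * 56 * 56, 96)
    assert tokens.shape == (1, 4 * 56 * 56, 96)
    return tokens, shape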
skp/models/rev_mvit/utils.py
ADDED
@@ -0,0 +1,221 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import numpy as np
import torch


def round_width(width, multiplier, min_width=1, divisor=1, verbose=False):
    if not multiplier:
        return width
    width *= multiplier
    min_width = min_width or divisor
    if verbose:
        print(f"min width {min_width}")
        print(f"width {width} divisor {divisor}")
        print(f"other {int(width + divisor / 2) // divisor * divisor}")

    width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
    if width_out < 0.9 * width:
        width_out += divisor
    return int(width_out)

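
# Illustrative worked examples (not part of the original file): round_width
# scales a channel count and snaps it to the nearest multiple of `divisor`,
# bumping up by one divisor if the rounded value falls below 90% of the
# scaled width.
def _round_width_examples():
    assert round_width(64, 2.5) == 160             # 64 * 2.5, divisor 1
    assert round_width(24, 1.5, divisor=8) == 40   # 36 -> nearest multiple of 8
    assert round_width(24, 1.5, divisor=16) == 48  # 32 < 0.9 * 36, so bump to 48
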
def validate_checkpoint_wrapper_import(checkpoint_wrapper):
    """
    Check if checkpoint_wrapper is imported.
    """
    if checkpoint_wrapper is None:
        raise ImportError("Please install fairscale.")


def get_gkern(kernlen, std):
    """Returns a 2D Gaussian kernel array."""

    def _gaussian_fn(kernlen, std):
        n = torch.arange(0, kernlen).float()
        n -= n.mean()
        n /= std
        w = torch.exp(-0.5 * n**2)
        return w

    gkern1d = _gaussian_fn(kernlen, std)
    gkern2d = torch.outer(gkern1d, gkern1d)
    return gkern2d / gkern2d.sum()


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    t_size: int of the temporal size
    return:
    pos_embed: [t_size*grid_size*grid_size, embed_dim] or
    [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    assert embed_dim % 4 == 0
    embed_dim_spatial = embed_dim // 4 * 3
    embed_dim_temporal = embed_dim // 4

    # spatial
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
        embed_dim_spatial, grid
    )

    # temporal
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
        embed_dim_temporal, grid_t
    )

    # concatenate in [T, H, W] order
    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(
        pos_embed_temporal, grid_size**2, axis=1
    )  # [T, H*W, D // 4]
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(
        pos_embed_spatial, t_size, axis=0
    )  # [T, H*W, D // 4 * 3]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0]
    )  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1]
    )  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # np.float was removed in recent NumPy releases; use np.float64 instead.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

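
# Illustrative sketch (not part of the original file): shapes returned by the
# sine-cosine embedding helpers above, using small made-up sizes.
def _sincos_shape_sketch():
    emb_2d = get_2d_sincos_pos_embed(embed_dim=64, grid_size=14)
    assert emb_2d.shape == (14 * 14, 64)
    emb_3d = get_3d_sincos_pos_embed(embed_dim=64, grid_size=14, t_size=8, cls_token=True)
    assert emb_3d.shape == (1 + 8 * 14 * 14, 64)
    return emb_2d, emb_3d
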
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
        )
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def calc_mvit_feature_geometry(cfg):
    feat_size = [
        [
            cfg.DATA.NUM_FRAMES // cfg.MVIT.PATCH_STRIDE[0]
            if len(cfg.MVIT.PATCH_STRIDE) > 2
            else 1,
            cfg.DATA.TRAIN_CROP_SIZE // cfg.MVIT.PATCH_STRIDE[-2],
            cfg.DATA.TRAIN_CROP_SIZE // cfg.MVIT.PATCH_STRIDE[-1],
        ]
        for i in range(cfg.MVIT.DEPTH)
    ]
    feat_stride = [
        [
            cfg.MVIT.PATCH_STRIDE[0] if len(cfg.MVIT.PATCH_STRIDE) > 2 else 1,
            cfg.MVIT.PATCH_STRIDE[-2],
            cfg.MVIT.PATCH_STRIDE[-1],
        ]
        for i in range(cfg.MVIT.DEPTH)
    ]
    for _, x in enumerate(cfg.MVIT.POOL_Q_STRIDE):
        for i in range(cfg.MVIT.DEPTH):
            if i >= x[0]:
                for j in range(len(feat_size[i])):
                    feat_size[i][j] = feat_size[i][j] // x[j + 1]
                    feat_stride[i][j] = feat_stride[i][j] * x[j + 1]
    return feat_size, feat_stride
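
# Illustrative sketch (not part of the original file): what
# calc_mvit_feature_geometry returns for a made-up config (a SimpleNamespace
# stands in for the real config object). With patch stride (2, 4, 4), 16
# frames and a 224 crop, every block starts from an 8 x 56 x 56 grid; each
# POOL_Q_STRIDE entry [layer, sT, sH, sW] divides the grid (and multiplies
# the stride) from that layer onward.
def _feature_geometry_sketch():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        DATA=SimpleNamespace(NUM_FRAMES=16, TRAIN_CROP_SIZE=224),
        MVIT=SimpleNamespace(
            PATCH_STRIDE=[2, 4, 4],
            DEPTH=4,
            POOL_Q_STRIDE=[[2, 1, 2, 2]],
        ),
    )
    feat_size, feat_stride = calc_mvit_feature_geometry(cfg)
    assert feat_size[0] == [8, 56, 56] and feat_stride[0] == [2, 4, 4]
    assert feat_size[3] == [8, 28, 28] and feat_stride[3] == [2, 8, 8]
    return feat_size, feat_stride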