ianpan committed
Commit 231edce · 1 Parent(s): a1b5998

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +81 -0
  2. configs/chunk000.yaml +89 -0
  3. configs/chunkseq003.yaml +67 -0
  4. configs/pseudoseg000.yaml +110 -0
  5. examples/1.2.826.0.1.3680043.15773.nii.gz +3 -0
  6. packages.txt +1 -0
  7. requirements.txt +7 -0
  8. seg.ckpt +3 -0
  9. seq.ckpt +3 -0
  10. skp/.DS_Store +0 -0
  11. skp/__init__.py +0 -0
  12. skp/__pycache__/__init__.cpython-39.pyc +0 -0
  13. skp/__pycache__/builder.cpython-39.pyc +0 -0
  14. skp/builder.py +187 -0
  15. skp/models/__init__.py +1 -0
  16. skp/models/__pycache__/__init__.cpython-39.pyc +0 -0
  17. skp/models/__pycache__/backbones.cpython-39.pyc +0 -0
  18. skp/models/__pycache__/engine.cpython-39.pyc +0 -0
  19. skp/models/__pycache__/sequence.cpython-39.pyc +0 -0
  20. skp/models/__pycache__/tools.cpython-39.pyc +0 -0
  21. skp/models/backbones.py +114 -0
  22. skp/models/engine.py +257 -0
  23. skp/models/pooling/__init__.py +3 -0
  24. skp/models/pooling/__pycache__/__init__.cpython-39.pyc +0 -0
  25. skp/models/pooling/__pycache__/gem.cpython-39.pyc +0 -0
  26. skp/models/pooling/__pycache__/pool1d.cpython-39.pyc +0 -0
  27. skp/models/pooling/__pycache__/pool2d.cpython-39.pyc +0 -0
  28. skp/models/pooling/__pycache__/pool3d.cpython-39.pyc +0 -0
  29. skp/models/pooling/gem.py +35 -0
  30. skp/models/pooling/pool1d.py +107 -0
  31. skp/models/pooling/pool2d.py +16 -0
  32. skp/models/pooling/pool3d.py +107 -0
  33. skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml +109 -0
  34. skp/models/rev_mvit/__init__.py +0 -0
  35. skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc +0 -0
  36. skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc +0 -0
  37. skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc +0 -0
  38. skp/models/rev_mvit/__pycache__/common.cpython-39.pyc +0 -0
  39. skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc +0 -0
  40. skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc +0 -0
  41. skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc +0 -0
  42. skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc +0 -0
  43. skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc +0 -0
  44. skp/models/rev_mvit/attention.py +568 -0
  45. skp/models/rev_mvit/batchnorm_helper.py +112 -0
  46. skp/models/rev_mvit/common.py +154 -0
  47. skp/models/rev_mvit/head_helper.py +140 -0
  48. skp/models/rev_mvit/reversible_mvit.py +696 -0
  49. skp/models/rev_mvit/stem_helper.py +325 -0
  50. skp/models/rev_mvit/utils.py +221 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import cv2
+ import glob
+ import gradio as gr
+ import mediapy
+ import nibabel
+ import numpy as np
+ import shutil
+ import torch
+ import torch.nn.functional as F
+
+ from omegaconf import OmegaConf
+ from skp import builder
+
+
+ def window(x, WL=400, WW=2500):
+     lower, upper = WL - WW // 2, WL + WW // 2
+     x = np.clip(x, lower, upper)
+     x = x - lower
+     x = x / (upper - lower)
+     return (x * 255).astype("uint8")
+
+
+ def rescale(x):
+     x = x / 255.
+     x = x - 0.5
+     x = x * 2.0
+     return x
+
+
+ def generate_segmentation_video(study):
+     img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
+     img = window(img)
+
+     X = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
+     X = F.interpolate(X, size=(192, 192, 192), mode="nearest")
+     X = rescale(X)
+     with torch.no_grad():
+         seg_output = seg_model(X)
+
+     seg_output = torch.sigmoid(seg_output)
+     p_spine = seg_output[:, :7].sum(1)
+     seg_output = torch.argmax(seg_output, dim=1) + 1
+     seg_output[p_spine < 0.5] = 0
+     seg_output = F.interpolate(seg_output.unsqueeze(0).float(), size=img.shape, mode="nearest")
+     seg_output = seg_output.squeeze(0).squeeze(0).numpy()
+     seg_output = (seg_output * 255 / 7).astype("uint8")
+     seg_output = np.stack([cv2.applyColorMap(_, cv2.COLORMAP_JET) for _ in seg_output])
+
+     frames = []
+     skip = 8
+     for idx in range(0, img.shape[2], skip):
+         i = img[:, :, idx]
+         o = seg_output[:, :, idx]
+         i = cv2.cvtColor(i, cv2.COLOR_GRAY2RGB)
+         frame = np.concatenate((i, o), 1)
+         frames.append(frame)
+     mediapy.write_video("video.mp4", frames, fps=30)
+     return "video.mp4"
+
+
+ ffmpeg_path = shutil.which('ffmpeg')
+ mediapy.set_ffmpeg(ffmpeg_path)
+
+ config = OmegaConf.load("configs/pseudoseg000.yaml")
+ config.model.load_pretrained = "seg.ckpt"
+ seg_model = builder.build_model(config).eval()
+ examples = glob.glob("examples/*.nii.gz")
+
+ with gr.Blocks(theme="dark-peach") as demo:
+     select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
+     button_predict = gr.Button("Predict")
+     video_output = gr.Video()
+     button_predict.click(fn=generate_segmentation_video,
+                          inputs=select_study,
+                          outputs=video_output)
+
+
+ if __name__ == "__main__":
+     demo.launch(debug=True, share=True)
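Note: the demo wiring above can also be exercised without the Gradio UI. The following is a minimal, hypothetical driver (not part of the commit) that assumes seg.ckpt, configs/pseudoseg000.yaml, and the bundled example volume are present in the working directory; importing app builds seg_model exactly as the Space does at startup.

# Hypothetical sketch; reuses the functions and model defined in app.py above.
from app import generate_segmentation_video

video_path = generate_segmentation_video("examples/1.2.826.0.1.3680043.15773.nii.gz")
print(f"Segmentation video written to {video_path}")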
configs/chunk000.yaml ADDED
@@ -0,0 +1,89 @@
+ experiment:
+   seed: 88
+   save_dir: ../experiments/
+
+
+ data:
+   annotations: ../data/train_vertebra_chunks_kfold.csv
+   data_dir: ../data/train-numpy-vertebra-chunks
+   input: filename
+   target: fracture
+   outer_fold: 0
+   dataset:
+     name: NumpyChunkDataset
+     params:
+       flip: true
+       invert: false
+       channels: grayscale
+       z_lt: resample_resample
+       z_gt: resample_resample
+       num_images: 64
+
+
+ transform:
+   resize:
+     name: resize_ignore_3d
+     params:
+       imsize: [64, 288, 288]
+   augment:
+     null
+   crop:
+     null
+   preprocess:
+     name: Preprocessor
+     params:
+       image_range: [0, 255]
+       input_range: [0, 1]
+       mean: [0.5]
+       sdev: [0.5]
+
+
+ task:
+   name: ClassificationTask
+   params:
+
+
+ model:
+   name: Net3D
+   params:
+     backbone: x3d_l
+     backbone_params:
+       z_strides: [1, 1, 1, 1, 1]
+     pretrained: true
+     num_classes: 1
+     dropout: 0.2
+     pool: avg
+     in_channels: 1
+     multisample_dropout: true
+
+
+ loss:
+   name: BCEWithLogitsLoss
+   params:
+
+
+ optimizer:
+   name: AdamW
+   params:
+     lr: 3.0e-4
+     weight_decay: 5.0e-4
+
+
+ scheduler:
+   name: CosineAnnealingLR
+   params:
+     final_lr: 0.0
+
+
+ train:
+   batch_size: 4
+   num_epochs: 10
+
+
+ evaluate:
+   metrics: [AUROC]
+   monitor: auc_mean
+   mode: max
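These configs are plain OmegaConf YAML consumed by skp/builder.py (shown further down): build_model reads model.name and instantiates that class from skp.models.engine with model.params as keyword arguments. A minimal sketch, assuming omegaconf is installed and the file above is on disk:

# Illustrative only; mirrors what builder.build_model does with this file.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/chunk000.yaml")
print(cfg.model.name)             # Net3D
print(cfg.model.params.backbone)  # x3d_l
# builder.build_model(cfg) then calls models.engine.Net3D(**cfg.model.params)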
configs/chunkseq003.yaml ADDED
@@ -0,0 +1,67 @@
+ experiment:
+   seed: 88
+   save_dir: ../experiments/
+
+
+ data:
+   annotations: ../data/train_chunk_features_kfold.csv
+   data_dir: ../data/train-chunk000-features/foldx
+   input: filename
+   target: [C1, C2, C3, C4, C5, C6, C7, patient_overall]
+   outer_fold: 0
+   dataset:
+     name: FeatureDataset
+     params:
+       seq_len: 7
+       reverse: false
+       normalize: false
+       exam_level_label: true
+
+
+ task:
+   name: ClassificationTask
+   params:
+
+
+ model:
+   name: DualTransformer
+   params:
+     num_classes: 1
+     embedding_dim: 432
+     hidden_dim: 864
+     n_layers: 3
+     n_heads: 16
+
+
+ loss:
+   name: MultilabelWeightedBCE
+   params:
+     weights: [1, 1, 1, 1, 1, 1, 1, 7]
+     pos_weight: 2.0
+
+
+ optimizer:
+   name: AdamW
+   params:
+     lr: 1.0e-5
+     weight_decay: 5.0e-4
+
+
+ scheduler:
+   name: CosineAnnealingLR
+   params:
+     final_lr: 0
+
+
+ train:
+   batch_size: 32
+   num_epochs: 25
+
+
+ evaluate:
+   batch_size: 1
+   metrics: [CompetitionMetric, AUROC]
+   monitor: comp_metric
+   mode: min
configs/pseudoseg000.yaml ADDED
@@ -0,0 +1,110 @@
+ experiment:
+   seed: 88
+   save_dir: ../experiments/
+
+
+ data:
+   annotations: ../data/train_seg_whole_192_kfold_with_pseudo.csv
+   data_dir: ../data/
+   input: filename
+   target: label
+   outer_fold: 0
+   dataset:
+     name: NumpyChunkSegmentDataset
+     params:
+       segmentation_format: numpy
+       channels: grayscale
+       flip: true
+       transpose: true
+       invert: false
+       verbose: true
+       num_images: 192
+       z_lt: resample_resample
+       z_gt: resample_resample
+       one_hot_encode: true
+       num_classes: 8
+       add_foreground_channel: false
+
+
+ transform:
+   resize:
+     name: resize_ignore_3d
+     params:
+       imsize: [192, 192, 192]
+   augment:
+     null
+   crop:
+     null
+   preprocess:
+     name: Preprocessor
+     params:
+       image_range: [0, 255]
+       input_range: [0, 1]
+       mean: [0.5]
+       sdev: [0.5]
+
+
+ task:
+   name: SegmentationTask3D
+   params:
+     chunk_validation: true
+
+
+ model:
+   name: NetSegment3D
+   params:
+     architecture: DeepLabV3Plus_3D
+     encoder_name: x3d_l
+     encoder_params:
+       pretrained: true
+       output_stride: 16
+       z_strides: [2, 2, 2, 2, 2]
+     decoder_params:
+       upsampling: 4
+     deep_supervision: true
+     num_classes: 8
+     in_channels: 1
+     dropout: 0.2
+
+
+ loss:
+   name: SupervisorLoss
+   params:
+     segmentation_loss: DiceBCELoss
+     scale_factors: [0.25, 0.25]
+     loss_weights: [1.0, 0.25, 0.25]
+     loss_params:
+       dice_loss_params:
+         mode: multilabel
+         exponent: 2
+         smooth: 1.0
+       bce_loss_params:
+         smooth_factor: 0.01
+         pos_weight: 1.0
+       dice_loss_weight: 1.0
+       bce_loss_weight: 0.2
+
+
+ optimizer:
+   name: AdamW
+   params:
+     lr: 3.0e-4
+     weight_decay: 5.0e-4
+
+
+ scheduler:
+   name: CosineAnnealingLR
+   params:
+     final_lr: 0.0
+
+
+ train:
+   batch_size: 4
+   num_epochs: 10
+
+
+ evaluate:
+   batch_size: 1
+   metrics: [DSC]
+   monitor: dsc_ignore_mean
+   mode: max
examples/1.2.826.0.1.3680043.15773.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a316d4cdb9534c662a209dea2b50fd57168398b1a658d14937ce285d3b792917
+ size 65868417
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ omegaconf
+ mediapy
+ nibabel
+ opencv-python
+ timm
+ torch
+ transformers
seg.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa6ee0036af98df68621b5cacbed9b4cd290eb1b59c6af7785a7e9c81ed74afa
+ size 21569386
seq.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:653637b500e3ae5ffab8b07f34d36662396ab3eacec8024e5ecea952d7c2c07e
+ size 18011334
skp/.DS_Store ADDED
Binary file (6.15 kB).
 
skp/__init__.py ADDED
File without changes
skp/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (169 Bytes).
 
skp/__pycache__/builder.cpython-39.pyc ADDED
Binary file (5.29 kB).
 
skp/builder.py ADDED
@@ -0,0 +1,187 @@
+ import numpy as np
+ import torch
+
+ from . import models
+
+
+ def get_name_and_params(base):
+     name = getattr(base, 'name')
+     params = getattr(base, 'params') or {}
+     return name, params
+
+
+ def get_transform(base, transform, mode=None):
+     if not base: return None
+     transform = getattr(base, transform)
+     if not transform: return None
+     name, params = get_name_and_params(transform)
+     if mode:
+         params.update({'mode': mode})
+     return getattr(data.transforms, name)(**params)
+
+
+ def build_transforms(cfg, mode):
+     # 1-Resize
+     resizer = get_transform(cfg.transform, 'resize')
+     # 2-(Optional) Data augmentation
+     augmenter = None
+     if mode == "train":
+         augmenter = get_transform(cfg.transform, 'augment')
+     # 3-(Optional) Crop
+     cropper = get_transform(cfg.transform, 'crop', mode=mode)
+     # 4-Preprocess
+     preprocessor = get_transform(cfg.transform, 'preprocess')
+     return {
+         'resize': resizer,
+         'augment': augmenter,
+         'crop': cropper,
+         'preprocess': preprocessor
+     }
+
+
+ def build_dataset(cfg, data_info, mode):
+     dataset_class = getattr(data.datasets, cfg.data.dataset.name)
+     dataset_params = cfg.data.dataset.params
+     dataset_params.test_mode = mode != 'train'
+     dataset_params = dict(dataset_params)
+     if "FeatureDataset" not in cfg.data.dataset.name:
+         transforms = build_transforms(cfg, mode)
+         dataset_params.update(transforms)
+     dataset_params.update(data_info)
+     return dataset_class(**dataset_params)
+
+
+ def build_dataloader(cfg, dataset, mode):
+
+     def worker_init_fn(worker_id):
+         np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+     dataloader_params = {}
+     dataloader_params['num_workers'] = cfg.data.num_workers
+     dataloader_params['drop_last'] = mode == 'train'
+     dataloader_params['shuffle'] = mode == 'train'
+     dataloader_params["pin_memory"] = cfg.data.get("pin_memory", True)
+     if mode in ('train', 'valid'):
+         if mode == "train":
+             dataloader_params['batch_size'] = cfg.train.batch_size
+         elif mode == "valid":
+             dataloader_params["batch_size"] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
+         sampler = None
+         if cfg.data.get("sampler") and mode == 'train':
+             name, params = get_name_and_params(cfg.data.sampler)
+             sampler = getattr(data.samplers, name)(dataset, **params)
+         if sampler:
+             dataloader_params['shuffle'] = False
+             if cfg.strategy == 'ddp':
+                 sampler = data.samplers.DistributedSamplerWrapper(sampler)
+             dataloader_params['sampler'] = sampler
+             print(f'Using sampler {sampler} for training ...')
+         elif cfg.strategy == 'ddp':
+             dataloader_params["shuffle"] = False
+             dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=mode=="train")
+     else:
+         assert cfg.strategy != "ddp", "DDP currently not supported for inference"
+         dataloader_params['batch_size'] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
+
+     loader = DataLoader(dataset,
+                         **dataloader_params,
+                         worker_init_fn=worker_init_fn)
+     return loader
+
+
+ def build_model(cfg):
+     name, params = get_name_and_params(cfg.model)
+     if cfg.model.params.get("cnn_params", None):
+         cnn_params = cfg.model.params.cnn_params
+         if cnn_params.get("load_pretrained_backbone", None):
+             if "foldx" in cnn_params.load_pretrained_backbone:
+                 cfg.model.params.cnn_params.load_pretrained_backbone = cnn_params.load_pretrained_backbone.\
+                     replace("foldx", f"fold{cfg.data.outer_fold}")
+     print(f'Creating model <{name}> ...')
+     model = getattr(models.engine, name)(**params)
+     if 'backbone' in cfg.model.params:
+         print(f'  Using backbone <{cfg.model.params.backbone}> ...')
+     if 'pretrained' in cfg.model.params:
+         print(f'  Pretrained : {cfg.model.params.pretrained}')
+     if "load_pretrained" in cfg.model:
+         import re
+         if "foldx" in cfg.model.load_pretrained:
+             cfg.model.load_pretrained = cfg.model.load_pretrained.replace("foldx", f"fold{cfg.data.outer_fold}")
+         print(f"  Loading pretrained checkpoint from {cfg.model.load_pretrained}")
+         weights = torch.load(cfg.model.load_pretrained, map_location=lambda storage, loc: storage)['state_dict']
+         weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items() if "loss_fn" not in k}
+         model.load_state_dict(weights)
+     return model
+
+
+ def build_loss(cfg):
+     name, params = get_name_and_params(cfg.loss)
+     print(f'Using loss function <{name}> ...')
+     params = dict(params)
+     if "pos_weight" in params:
+         params["pos_weight"] = torch.tensor(params["pos_weight"])
+     criterion = getattr(losses, name)(**params)
+     return criterion
+
+
+ def build_scheduler(cfg, optimizer):
+     # Some schedulers will require manipulation of config params
+     # My specifications were to make it more intuitive for me
+     name, params = get_name_and_params(cfg.scheduler)
+     print(f'Using learning rate schedule <{name}> ...')
+
+     if name == 'CosineAnnealingLR':
+         # eta_min <-> final_lr
+         # Set T_max as 100000 ... this is changed in on_train_start() method
+         # of the LightningModule task
+
+         params = {
+             'T_max': 100000,
+             'eta_min': max(params.final_lr, 1.0e-8)
+         }
+
+     if name in ('OneCycleLR', 'CustomOneCycleLR'):
+         # Use learning rate from optimizer parameters as initial learning rate
+         lr_0 = cfg.optimizer.params.lr
+         lr_1 = params.max_lr
+         lr_2 = params.final_lr
+         # lr_0 -> lr_1 -> lr_2
+         pct_start = params.pct_start
+         params = {}
+         params['steps_per_epoch'] = 100000 # see above- will fix in task
+         params['epochs'] = cfg.train.num_epochs
+         params['max_lr'] = lr_1
+         params['pct_start'] = pct_start
+         params['div_factor'] = lr_1 / lr_0 # max/init
+         params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8) # init/final
+
+     scheduler = getattr(optim, name)(optimizer=optimizer, **params)
+
+     # Some schedulers might need more manipulation after instantiation
+     if name in ('OneCycleLR', 'CustomOneCycleLR'):
+         scheduler.pct_start = params['pct_start']
+
+     # Set update frequency
+     if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
+         scheduler.update_frequency = 'on_batch'
+     elif name in ('ReduceLROnPlateau'):
+         scheduler.update_frequency = 'on_valid'
+     else:
+         scheduler.update_frequency = 'on_epoch'
+
+     return scheduler
+
+
+ def build_optimizer(cfg, parameters):
+     name, params = get_name_and_params(cfg.optimizer)
+     print(f'Using optimizer <{name}> ...')
+     optimizer = getattr(optim, name)(parameters, **params)
+     return optimizer
+
+
+ def build_task(cfg, model):
+     name, params = get_name_and_params(cfg.task)
+     print(f'Building task <{name}> ...')
+     return getattr(tasks, name)(cfg, model, **params)
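The key detail in build_model above is the checkpoint remapping: Lightning-style checkpoints store parameters under a "model." prefix (and may carry loss-function buffers), so keys are rewritten before load_state_dict. A small illustration of that regex step (not part of the commit):

import re

state_dict = {"model.backbone.conv1.weight": "w", "loss_fn.pos_weight": "pw"}
weights = {re.sub(r"^model.", "", k): v for k, v in state_dict.items() if "loss_fn" not in k}
print(weights)  # {'backbone.conv1.weight': 'w'}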
skp/models/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import engine
skp/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (207 Bytes).
 
skp/models/__pycache__/backbones.cpython-39.pyc ADDED
Binary file (3.9 kB).
 
skp/models/__pycache__/engine.cpython-39.pyc ADDED
Binary file (10.1 kB).
 
skp/models/__pycache__/sequence.cpython-39.pyc ADDED
Binary file (5.67 kB).
 
skp/models/__pycache__/tools.cpython-39.pyc ADDED
Binary file (932 Bytes).
 
skp/models/backbones.py ADDED
@@ -0,0 +1,114 @@
+ import re
+ import timm
+ import torch
+
+ from functools import partial
+ from timm.models.vision_transformer import VisionTransformer
+ from timm.models.swin_transformer_v2 import SwinTransformerV2
+
+ from .vmz.backbones import *
+
+
+ def check_name(name, s):
+     return bool(re.search(s, name))
+
+
+ def create_backbone(name, pretrained, features_only=False, **kwargs):
+     try:
+         model = timm.create_model(name, pretrained=pretrained,
+                                   features_only=features_only,
+                                   num_classes=0, global_pool="")
+     except Exception as e:
+         assert name in BACKBONES, f"{name} is not a valid backbone"
+         model = BACKBONES[name](pretrained=pretrained, features_only=features_only, **kwargs)
+     with torch.no_grad():
+         if check_name(name, r"x3d|csn|r2plus1d|i3d"):
+             dim_feats = model(torch.randn((2, 3, 64, 64, 64))).size(1)
+         elif isinstance(model, (VisionTransformer, SwinTransformerV2)):
+             dim_feats = model.norm.normalized_shape[0]
+         else:
+             dim_feats = model(torch.randn((2, 3, 128, 128))).size(1)
+     return model, dim_feats
+
+
+ def create_csn(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs):
+     if features_only:
+         raise Exception("features_only is currently not supported")
+     if not pretrained:
+         from pytorchvideo.models import hub
+         model = getattr(hub, name)(pretrained=False)
+     else:
+         model = torch.hub.load("facebookresearch/pytorchvideo:main", model=name, pretrained=pretrained)
+     model.blocks[5] = nn.Identity()
+     return model
+
+
+ def create_x3d(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs):
+     if not pretrained:
+         from pytorchvideo.models import hub
+         model = getattr(hub, name)(pretrained=False)
+     else:
+         model = torch.hub.load("facebookresearch/pytorchvideo", model=name, pretrained=pretrained)
+     for idx, z in enumerate(z_strides):
+         assert z in [1, 2], "Only z-strides of 1 or 2 are supported"
+         if z == 2:
+             if idx == 0:
+                 stem_layer = model.blocks[0].conv.conv_t
+                 w = stem_layer.weight
+                 w = w.repeat(1, 1, 3, 1, 1)
+                 in_channels, out_channels = stem_layer.in_channels, stem_layer.out_channels
+                 model.blocks[0].conv.conv_t = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
+             else:
+                 model.blocks[idx].res_blocks[0].branch1_conv.stride = (2, 2, 2)
+                 model.blocks[idx].res_blocks[0].branch2.conv_b.stride = (2, 2, 2)
+
+     if features_only:
+         model.blocks[-1] = nn.Identity()
+         model = X3D_Features(model)
+     else:
+         model.blocks[-1] = nn.Sequential(
+             model.blocks[-1].pool.pre_conv,
+             model.blocks[-1].pool.pre_norm,
+             model.blocks[-1].pool.pre_act,
+         )
+
+     return model
+
+
+ def create_i3d(name, pretrained, features_only=False, **kwargs):
+     from pytorchvideo.models import hub
+     model = getattr(hub, name)(pretrained=pretrained)
+     model.blocks[-1] = nn.Identity()
+     return model
+
+
+ class X3D_Features(nn.Module):
+
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+         self.out_channels = [24, 24, 48, 96, 192]
+
+     def forward(self, x):
+         features = []
+         for idx in range(len(self.model.blocks) - 1):
+             x = self.model.blocks[idx](x)
+             features.append(x)
+         return features
+
+
+ BACKBONES = {
+     "x3d_xs": partial(create_x3d, name="x3d_xs"),
+     "x3d_s": partial(create_x3d, name="x3d_s"),
+     "x3d_m": partial(create_x3d, name="x3d_m"),
+     "x3d_l": partial(create_x3d, name="x3d_l"),
+     "i3d_r50": partial(create_i3d, name="i3d_r50"),
+     "csn_r101": partial(create_csn, name="csn_r101"),
+     "ir_csn_50": ir_csn_50,
+     "ir_csn_101": ir_csn_101,
+     "ir_csn_152": ir_csn_152,
+     "ip_csn_50": ip_csn_50,
+     "ip_csn_101": ip_csn_101,
+     "ip_csn_152": ip_csn_152,
+     "r2plus1d_34": r2plus1d_34
+ }
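A usage sketch for create_backbone (assumes timm and pytorchvideo are installed and the full skp package, including the vmz submodule imported at the top of this file, is available; weights are downloaded on first use when pretrained=True):

from skp.models.backbones import create_backbone

# For 3D names such as "x3d_l", the feature dimension is inferred by pushing a
# dummy (2, 3, 64, 64, 64) tensor through the network, as in the code above.
backbone, dim_feats = create_backbone("x3d_l", pretrained=False)
print(dim_feats)  # used by Net2D/Net3D in engine.py to size the classifier head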
skp/models/engine.py ADDED
@@ -0,0 +1,257 @@
+ import math
+ import numpy as np
+ import re
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from pytorchvideo.models.x3d import create_x3d_stem
+ from timm.models.vision_transformer import VisionTransformer
+ from timm.models.swin_transformer_v2 import SwinTransformerV2
+ from . import backbones
+ from . import segmentation
+ from .pooling import create_pool2d_layer, create_pool3d_layer
+ from .sequence import Transformer, DualTransformer, DualTransformerV2
+ from .tools import change_initial_stride, change_num_input_channels
+
+
+ class Net2D(nn.Module):
+
+     def __init__(self,
+                  backbone,
+                  pretrained,
+                  num_classes,
+                  dropout,
+                  pool,
+                  in_channels=3,
+                  change_stride=None,
+                  feature_reduction=None,
+                  multisample_dropout=False,
+                  load_pretrained_backbone=None,
+                  freeze_backbone=False,
+                  backbone_params={},
+                  pool_layer_params={}):
+
+         super().__init__()
+         self.backbone, dim_feats = backbones.create_backbone(name=backbone, pretrained=pretrained, **backbone_params)
+         if isinstance(pool, str):
+             self.pool_layer = create_pool2d_layer(name=pool, **pool_layer_params)
+         else:
+             self.pool_layer = nn.Identity()
+         if pool == "catavgmax":
+             dim_feats *= 2
+         self.msdo = multisample_dropout
+         if in_channels != 3:
+             self.backbone = change_num_input_channels(self.backbone, in_channels)
+         if change_stride:
+             self.backbone = change_initial_stride(self.backbone, tuple(change_stride), in_channels)
+         self.dropout = nn.Dropout(p=dropout)
+         if isinstance(feature_reduction, int):
+             # Use 1D grouped convolution to reduce # of parameters
+             groups = math.gcd(dim_feats, feature_reduction)
+             self.feature_reduction = nn.Conv1d(dim_feats, feature_reduction, groups=groups, kernel_size=1,
+                                                stride=1, bias=False)
+             dim_feats = feature_reduction
+         self.classifier = nn.Linear(dim_feats, num_classes)
+
+         if load_pretrained_backbone:
+             # Assumes that model has a `backbone` attribute
+             # Note: if you want to load the entire pretrained model, this is done via the
+             # builder.build_model function
+             print(f"Loading pretrained backbone from {load_pretrained_backbone} ...")
+             weights = torch.load(load_pretrained_backbone, map_location=lambda storage, loc: storage)['state_dict']
+             weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items()}
+             # Get feature_reduction, if present
+             feat_reduce_weight = {re.sub(r"^feature_reduction.", "", k): v
+                                   for k, v in weights.items() if "feature_reduction" in k}
+             # Get backbone only
+             weights = {re.sub(r'^backbone.', '', k) : v for k,v in weights.items() if 'backbone' in k}
+             self.backbone.load_state_dict(weights)
+             if len(feat_reduce_weight) > 0:
+                 print("Also loading feature reduction layer ...")
+                 self.feature_reduction.load_state_dict(feat_reduce_weight)
+
+         if freeze_backbone:
+             print("Freezing backbone ...")
+             for param in self.backbone.parameters():
+                 param.requires_grad = False
+
+     def extract_features(self, x):
+         features = self.backbone(x)
+         features = self.pool_layer(features)
+         if isinstance(self.backbone, VisionTransformer):
+             features = features[:, self.backbone.num_prefix_tokens:].mean(dim=1)
+         if isinstance(self.backbone, SwinTransformerV2):
+             features = features.mean(dim=1)
+         if hasattr(self, "feature_reduction"):
+             features = self.feature_reduction(features.unsqueeze(-1)).squeeze(-1)
+         return features
+
+     def forward(self, x):
+         features = self.extract_features(x)
+         if self.msdo:
+             x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
+         else:
+             x = self.classifier(self.dropout(features))
+         # Important nuance:
+         # For binary classification, the model returns a tensor of shape (N,)
+         # Otherwise, (N,C)
+         return x[:, 0] if self.classifier.out_features == 1 else x
+
+
+ class SeqNet2D(Net2D):
+
+     def forward(self, x):
+         # x.shape = (N, C, Z, H, W)
+         features = torch.stack([self.extract_features(x[:, :, _]) for _ in range(x.size(2))], dim=2)
+         features = features.max(2)[0]
+
+         if self.msdo:
+             x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
+         else:
+             x = self.classifier(self.dropout(features))
+         # Important nuance:
+         # For binary classification, the model returns a tensor of shape (N,)
+         # Otherwise, (N,C)
+         return x[:, 0] if self.classifier.out_features == 1 else x
+
+
+ class TDCNN(nn.Module):
+
+     def __init__(self, cnn_params, transformer_params, freeze_cnn=False, freeze_transformer=False):
+         super().__init__()
+         self.cnn = Net2D(**cnn_params)
+         del self.cnn.dropout
+         del self.cnn.classifier
+         self.transformer = Transformer(**transformer_params)
+
+         if freeze_cnn:
+             for param in self.cnn.parameters():
+                 param.requires_grad = False
+
+         if freeze_transformer:
+             for param in self.transformer.parameters():
+                 param.requires_grad = False
+
+     def extract_features(self, x):
+         N, C, Z, H, W = x.size()
+         assert N == 1, "For feature extraction, batch size must be 1"
+         features = self.cnn.extract_features(x.squeeze(0).transpose(0, 1)).unsqueeze(0)
+         # features.shape = (1, Z, dim_feats)
+         return self.transformer.extract_features((features, torch.ones((features.size(0), features.size(1))).to(features.device)))
+
+     def forward(self, x):
+         # BCZHW
+         features = torch.stack([self.cnn.extract_features(x[:, :, i]) for i in range(x.size(2))], dim=1)
+         # B, seq_len, dim_feat
+         return self.transformer((features, torch.ones((features.size(0), features.size(1))).to(features.device)))
+
+
+ class Net2DWith3DStem(Net2D):
+
+     def __init__(self, *args, **kwargs):
+         stem_out_channels = kwargs.pop("stem_out_channels", 24)
+         load_pretrained_stem = kwargs.pop("load_pretrained_stem", None)
+         conv_kernel_size = tuple(kwargs.pop("conv_kernel_size", (5, 3, 3)))
+         conv_stride = tuple(kwargs.pop("conv_stride", (1, 2, 2)))
+         in_channels = kwargs.pop("in_channels", 3)
+         kwargs["in_channels"] = stem_out_channels
+         super().__init__(*args, **kwargs)
+         self.stem_layer = create_x3d_stem(in_channels=in_channels,
+                                           out_channels=stem_out_channels,
+                                           conv_kernel_size=conv_kernel_size,
+                                           conv_stride=conv_stride)
+         if kwargs["pretrained"]:
+             from pytorchvideo.models.hub import x3d_l
+             self.stem_layer.load_state_dict(x3d_l(pretrained=True).blocks[0].state_dict())
+
+         if load_pretrained_stem:
+             import re
+             print(f"  Loading pretrained stem from {load_pretrained_stem} ...")
+             weights = torch.load(load_pretrained_stem, map_location=lambda storage, loc: storage)['state_dict']
+             stem_weights = {k.replace("model.backbone.blocks.0.", ""): v for k, v in weights.items() if "backbone.blocks.0" in k}
+             self.stem_layer.load_state_dict(stem_weights)
+
+     def forward(self, x):
+         x = self.stem_layer(x)
+         x = x.mean(3)
+         features = self.extract_features(x)
+         if self.msdo:
+             x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
+         else:
+             x = self.classifier(self.dropout(features))
+         # Important nuance:
+         # For binary classification, the model returns a tensor of shape (N,)
+         # Otherwise, (N,C)
+         return x[:, 0] if self.classifier.out_features == 1 else x
+
+
+ class Net3D(Net2D):
+
+     def __init__(self, *args, **kwargs):
+         z_strides = kwargs.pop("z_strides", [1,1,1,1,1])
+         super().__init__(*args, **kwargs)
+         self.pool_layer = create_pool3d_layer(name=kwargs["pool"], **kwargs.pop("pool_layer_params", {}))
+
+
+ class NetSegment2D(nn.Module):
+     """ For now, this class essentially serves as a wrapper for the
+     segmentation model, which is mostly defined in the segmentation submodule,
+     similar to the original segmentation_models.pytorch.
+
+     It may be worth refactoring it in the future, such that you define this as
+     a general class, then select your choice of encoder and decoder. The encoder
+     is pretty much the same across all the segmentation models currently
+     implemented (DeepLabV3+, FPN, Unet).
+     """
+     def __init__(self,
+                  architecture,
+                  encoder_name,
+                  encoder_params,
+                  decoder_params,
+                  num_classes,
+                  dropout,
+                  in_channels,
+                  load_pretrained_encoder=None,
+                  freeze_encoder=False,
+                  deep_supervision=False,
+                  pool_layer_params={},
+                  aux_head_params={}):
+
+         super().__init__()
+
+         self.segmentation_model = getattr(segmentation, architecture)(
+             encoder_name=encoder_name,
+             encoder_params=encoder_params,
+             dropout=dropout,
+             classes=num_classes,
+             deep_supervision=deep_supervision,
+             in_channels=in_channels,
+             **decoder_params
+         )
+
+
+         if load_pretrained_encoder:
+             # Assumes that model has an `encoder` attribute
+             # Note: if you want to load the entire pretrained model, this is done via the
+             # builder.build_model function
+             print(f"Loading pretrained encoder from {load_pretrained_encoder} ...")
+             weights = torch.load(load_pretrained_encoder, map_location=lambda storage, loc: storage)['state_dict']
+             weights = {re.sub(r'^model.segmentation_model', '', k) : v for k,v in weights.items()}
+             # Get encoder only
+             weights = {re.sub(r'^encoder.', '', k) : v for k,v in weights.items() if 'backbone' in k}
+             self.segmentation_model.encoder.load_state_dict(weights)
+
+         if freeze_encoder:
+             print("Freezing encoder ...")
+             for param in self.segmentation_model.encoder.parameters():
+                 param.requires_grad = False
+
+
+     def forward(self, x):
+         return self.segmentation_model(x)
+
+
+ class NetSegment3D(NetSegment2D):
+
+     pass
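To tie this back to configs/chunk000.yaml above: Net3D is constructed with exactly the keyword arguments listed under model.params in that config. A minimal sketch (assumes the full skp package plus timm and pytorchvideo are installed; pretrained is switched to False here only to avoid a weight download):

import torch
from skp.models.engine import Net3D

model = Net3D(backbone="x3d_l",
              pretrained=False,   # the config uses true
              num_classes=1,
              dropout=0.2,
              pool="avg",
              in_channels=1,
              multisample_dropout=True,
              backbone_params={"z_strides": [1, 1, 1, 1, 1]})
with torch.no_grad():
    logits = model(torch.randn(1, 1, 64, 288, 288))  # one windowed vertebra chunk
print(logits.shape)  # torch.Size([1]) -- binary output is squeezed to (N,)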
skp/models/pooling/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .pool3d import create_pool3d_layer
+ from .pool2d import create_pool2d_layer
+ from .pool1d import create_pool1d_layer
skp/models/pooling/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (334 Bytes).
 
skp/models/pooling/__pycache__/gem.cpython-39.pyc ADDED
Binary file (1.67 kB).
 
skp/models/pooling/__pycache__/pool1d.cpython-39.pyc ADDED
Binary file (4.28 kB).
 
skp/models/pooling/__pycache__/pool2d.cpython-39.pyc ADDED
Binary file (678 Bytes).
 
skp/models/pooling/__pycache__/pool3d.cpython-39.pyc ADDED
Binary file (4.29 kB).
 
skp/models/pooling/gem.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # From: https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/layers/pooling.py
+ def gem_1d(x, p=3, eps=1e-6):
+     return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1),)).pow(1./p)
+
+
+ def gem_2d(x, p=3, eps=1e-6):
+     return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
+
+
+ def gem_3d(x, p=3, eps=1e-6):
+     return F.avg_pool3d(x.clamp(min=eps).pow(p), (x.size(-3), x.size(-2), x.size(-1))).pow(1./p)
+
+
+ _GEM_FN = {
+     1: gem_1d, 2: gem_2d, 3: gem_3d
+ }
+
+
+ class GeM(nn.Module):
+
+     def __init__(self, p=3, eps=1e-6, dim=2):
+         super().__init__()
+         self.p = nn.Parameter(torch.ones(1)*p)
+         self.eps = eps
+         self.dim = dim
+         self.flatten = nn.Flatten(1)
+
+     def forward(self, x):
+         pooled = _GEM_FN[self.dim](x, p=self.p, eps=self.eps)
+         return self.flatten(pooled)
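GeM here is generalized-mean pooling: activations are clamped, raised to the power p, average-pooled over the spatial extent, and the p-th root is taken, so p = 1 recovers average pooling and large p approaches max pooling; because p is an nn.Parameter it is learned with the rest of the network. A short usage sketch:

import torch
from skp.models.pooling.gem import GeM

pool = GeM(p=3, dim=2)                    # 2D variant, as returned by create_pool2d_layer(name="gem")
feats = pool(torch.randn(4, 192, 7, 7))   # (N, C, H, W) -> (N, C)
print(feats.shape)                        # torch.Size([4, 192])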
skp/models/pooling/pool1d.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .gem import GeM
+
+
+ def adaptive_avgmax_pool1d(x, output_size=1):
+     x_avg = F.adaptive_avg_pool1d(x, output_size)
+     x_max = F.adaptive_max_pool1d(x, output_size)
+     return 0.5 * (x_avg + x_max)
+
+
+ def adaptive_catavgmax_pool1d(x, output_size=1):
+     x_avg = F.adaptive_avg_pool1d(x, output_size)
+     x_max = F.adaptive_max_pool1d(x, output_size)
+     return torch.cat((x_avg, x_max), 1)
+
+
+ def select_adaptive_pool1d(x, pool_type='avg', output_size=1):
+     """Selectable global pooling function with dynamic input kernel size
+     """
+     if pool_type == 'avg':
+         x = F.adaptive_avg_pool1d(x, output_size)
+     elif pool_type == 'avgmax':
+         x = adaptive_avgmax_pool1d(x, output_size)
+     elif pool_type == 'catavgmax':
+         x = adaptive_catavgmax_pool1d(x, output_size)
+     elif pool_type == 'max':
+         x = F.adaptive_max_pool1d(x, output_size)
+     else:
+         assert False, 'Invalid pool type: %s' % pool_type
+     return x
+
+
+ class FastAdaptiveAvgPool1d(nn.Module):
+     def __init__(self, flatten=False):
+         super(FastAdaptiveAvgPool1d, self).__init__()
+         self.flatten = flatten
+
+     def forward(self, x):
+         return x.mean(2, keepdim=not self.flatten)
+
+
+ class AdaptiveAvgMaxPool1d(nn.Module):
+     def __init__(self, output_size=1):
+         super(AdaptiveAvgMaxPool1d, self).__init__()
+         self.output_size = output_size
+
+     def forward(self, x):
+         return adaptive_avgmax_pool1d(x, self.output_size)
+
+
+ class AdaptiveCatAvgMaxPool1d(nn.Module):
+     def __init__(self, output_size=1):
+         super(AdaptiveCatAvgMaxPool1d, self).__init__()
+         self.output_size = output_size
+
+     def forward(self, x):
+         return adaptive_catavgmax_pool1d(x, self.output_size)
+
+
+ class SelectAdaptivePool1d(nn.Module):
+     """Selectable global pooling layer with dynamic input kernel size
+     """
+     def __init__(self, output_size=1, pool_type='fast', flatten=False):
+         super(SelectAdaptivePool1d, self).__init__()
+         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
+         self.flatten = nn.Flatten(1) if flatten else nn.Identity()
+         if pool_type == '':
+             self.pool = nn.Identity()  # pass through
+         elif pool_type == 'fast':
+             assert output_size == 1
+             self.pool = FastAdaptiveAvgPool1d(flatten)
+             self.flatten = nn.Identity()
+         elif pool_type == 'avg':
+             self.pool = nn.AdaptiveAvgPool1d(output_size)
+         elif pool_type == 'avgmax':
+             self.pool = AdaptiveAvgMaxPool1d(output_size)
+         elif pool_type == 'catavgmax':
+             self.pool = AdaptiveCatAvgMaxPool1d(output_size)
+         elif pool_type == 'max':
+             self.pool = nn.AdaptiveMaxPool1d(output_size)
+         else:
+             assert False, 'Invalid pool type: %s' % pool_type
+
+     def is_identity(self):
+         return not self.pool_type
+
+     def forward(self, x):
+         x = self.pool(x)
+         x = self.flatten(x)
+         return x
+
+     def __repr__(self):
+         return self.__class__.__name__ + ' (' \
+             + 'pool_type=' + self.pool_type \
+             + ', flatten=' + str(self.flatten) + ')'
+
+
+ def create_pool1d_layer(name, **kwargs):
+     assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"]
+     if name != "gem":
+         pool1d_layer = SelectAdaptivePool1d(pool_type=name, flatten=True)
+     elif name == "gem":
+         pool1d_layer = GeM(dim=1, **kwargs)
+     return pool1d_layer
skp/models/pooling/pool2d.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from timm.models.layers import SelectAdaptivePool2d
+
+ from .gem import GeM
+
+
+ def create_pool2d_layer(name, **kwargs):
+     assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"]
+     if name != "gem":
+         pool2d_layer = SelectAdaptivePool2d(pool_type=name, flatten=True)
+     elif name == "gem":
+         pool2d_layer = GeM(dim=2, **kwargs)
+     return pool2d_layer
skp/models/pooling/pool3d.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .gem import GeM
+
+
+ def adaptive_avgmax_pool3d(x, output_size=1):
+     x_avg = F.adaptive_avg_pool3d(x, output_size)
+     x_max = F.adaptive_max_pool3d(x, output_size)
+     return 0.5 * (x_avg + x_max)
+
+
+ def adaptive_catavgmax_pool3d(x, output_size=1):
+     x_avg = F.adaptive_avg_pool3d(x, output_size)
+     x_max = F.adaptive_max_pool3d(x, output_size)
+     return torch.cat((x_avg, x_max), 1)
+
+
+ def select_adaptive_pool3d(x, pool_type='avg', output_size=1):
+     """Selectable global pooling function with dynamic input kernel size
+     """
+     if pool_type == 'avg':
+         x = F.adaptive_avg_pool3d(x, output_size)
+     elif pool_type == 'avgmax':
+         x = adaptive_avgmax_pool3d(x, output_size)
+     elif pool_type == 'catavgmax':
+         x = adaptive_catavgmax_pool3d(x, output_size)
+     elif pool_type == 'max':
+         x = F.adaptive_max_pool3d(x, output_size)
+     else:
+         assert False, 'Invalid pool type: %s' % pool_type
+     return x
+
+
+ class FastAdaptiveAvgPool3d(nn.Module):
+     def __init__(self, flatten=False):
+         super(FastAdaptiveAvgPool3d, self).__init__()
+         self.flatten = flatten
+
+     def forward(self, x):
+         return x.mean((2,3,4), keepdim=not self.flatten)
+
+
+ class AdaptiveAvgMaxPool3d(nn.Module):
+     def __init__(self, output_size=1):
+         super(AdaptiveAvgMaxPool3d, self).__init__()
+         self.output_size = output_size
+
+     def forward(self, x):
+         return adaptive_avgmax_pool3d(x, self.output_size)
+
+
+ class AdaptiveCatAvgMaxPool3d(nn.Module):
+     def __init__(self, output_size=1):
+         super(AdaptiveCatAvgMaxPool3d, self).__init__()
+         self.output_size = output_size
+
+     def forward(self, x):
+         return adaptive_catavgmax_pool3d(x, self.output_size)
+
+
+ class SelectAdaptivePool3d(nn.Module):
+     """Selectable global pooling layer with dynamic input kernel size
+     """
+     def __init__(self, output_size=1, pool_type='fast', flatten=False):
+         super(SelectAdaptivePool3d, self).__init__()
+         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
+         self.flatten = nn.Flatten(1) if flatten else nn.Identity()
+         if pool_type == '':
+             self.pool = nn.Identity()  # pass through
+         elif pool_type == 'fast':
+             assert output_size == 1
+             self.pool = FastAdaptiveAvgPool3d(flatten)
+             self.flatten = nn.Identity()
+         elif pool_type == 'avg':
+             self.pool = nn.AdaptiveAvgPool3d(output_size)
+         elif pool_type == 'avgmax':
+             self.pool = AdaptiveAvgMaxPool3d(output_size)
+         elif pool_type == 'catavgmax':
+             self.pool = AdaptiveCatAvgMaxPool3d(output_size)
+         elif pool_type == 'max':
+             self.pool = nn.AdaptiveMaxPool3d(output_size)
+         else:
+             assert False, 'Invalid pool type: %s' % pool_type
+
+     def is_identity(self):
+         return not self.pool_type
+
+     def forward(self, x):
+         x = self.pool(x)
+         x = self.flatten(x)
+         return x
+
+     def __repr__(self):
+         return self.__class__.__name__ + ' (' \
+             + 'pool_type=' + self.pool_type \
+             + ', flatten=' + str(self.flatten) + ')'
+
+
+ def create_pool3d_layer(name, **kwargs):
+     assert name in ["avg", "max", "fast", "avgmax", "catavgmax", "gem"]
+     if name != "gem":
+         pool3d_layer = SelectAdaptivePool3d(pool_type=name, flatten=True)
+     elif name == "gem":
+         pool3d_layer = GeM(dim=3, **kwargs)
+     return pool3d_layer
skp/models/rev_mvit/REV_MVIT_B_16_CONV.yaml ADDED
@@ -0,0 +1,109 @@
+ TRAIN:
+   ENABLE: True
+   DATASET: imagenet
+   BATCH_SIZE: 256
+   EVAL_PERIOD: 10
+   CHECKPOINT_PERIOD: 1
+   AUTO_RESUME: True
+
+ DATA:
+   # PATH_TO_DATA_DIR: path-to-imagenet-dir
+   MEAN: [0.485, 0.456, 0.406]
+   STD: [0.229, 0.224, 0.225]
+   NUM_FRAMES: 64
+   TRAIN_CROP_SIZE: 224
+   TEST_CROP_SIZE: 224
+   INPUT_CHANNEL_NUM: [3]
+ MVIT:
+   PATCH_2D: False
+   ZERO_DECAY_POS_CLS: False
+   MODE: "conv"
+   CLS_EMBED_ON: False
+   PATCH_KERNEL: [3, 7, 7]
+   PATCH_STRIDE: [2, 4, 4]
+   PATCH_PADDING: [1, 3, 3]
+   EMBED_DIM: 96
+   NUM_HEADS: 1
+   MLP_RATIO: 4.0
+   QKV_BIAS: True
+   DROPPATH_RATE: 0.1
+   DROPOUT_RATE: 0.0
+   DEPTH: 16
+   LAYER_SCALE_INIT_VALUE: 0.0
+   HEAD_INIT_SCALE: 1.0
+   USE_MEAN_POOLING: False
+   USE_ABS_POS: True
+   USE_FIXED_SINCOS_POS: False
+   SEP_POS_EMBED: False
+   REL_POS_SPATIAL: False
+   REL_POS_TEMPORAL: False
+   REL_POS_ZERO_INIT: False
+   RESIDUAL_POOLING: False
+   NORM: "layernorm"
+   NORM_STEM: False
+   DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]]
+   HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]]
+   POOL_FIRST: null
+   POOL_KVQ_KERNEL: [1, 3, 3]
+   POOL_KV_STRIDE_ADAPTIVE: [1, 4, 4]
+   POOL_Q_STRIDE: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]]
+   SEPARATE_QKV: True
+   REV:
+     ENABLE: True
+     RESPATH_FUSE: "concat"
+     BUFFER_LAYERS: [1, 3, 14]
+     RES_PATH: "conv"
+     PRE_Q_FUSION: "concat_linear_2"
+ DETECTION:
+   ENABLE: False
+ AUG:
+   ENABLE: True
+   COLOR_JITTER: 0.4
+   AA_TYPE: rand-m9-n6-mstd0.5-inc1
+   INTERPOLATION: bicubic
+   RE_PROB: 0.25
+   RE_MODE: pixel
+   RE_COUNT: 1
+   RE_SPLIT: False
+ MIXUP:
+   ENABLE: True
+   ALPHA: 0.8
+   CUTMIX_ALPHA: 1.0
+   PROB: 1.0
+   SWITCH_PROB: 0.5
+   LABEL_SMOOTH_VALUE: 0.1
+ SOLVER:
+   BASE_LR_SCALE_NUM_SHARDS: True
+   BASE_LR: 0.00025
+   LR_POLICY: cosine
+   MAX_EPOCH: 300
+   MOMENTUM: 0.9
+   WEIGHT_DECAY: 0.05
+   WARMUP_EPOCHS: 70.0
+   WARMUP_START_LR: 1e-8
+   OPTIMIZING_METHOD: adamw
+   COSINE_AFTER_WARMUP: True
+   COSINE_END_LR: 1e-6
+   ZERO_WD_1D_PARAM: True
+   CLIP_GRAD_L2NORM: 1.0
+ MODEL:
+   NUM_CLASSES: 1000
+   ARCH: mvit
+   MODEL_NAME: MViT
+   LOSS_FUNC: soft_cross_entropy
+   DROPOUT_RATE: 0.0
+   HEAD_ACT: "softmax"
+   DETACH_FINAL_FC: False
+ CONTRASTIVE:
+   NUM_MLP_LAYERS: 1
+ TEST:
+   ENABLE: False
+   DATASET: imagenet
+   BATCH_SIZE: 256
+ DATA_LOADER:
+   NUM_WORKERS: 8
+   PIN_MEMORY: True
+ NUM_GPUS: 2
+ NUM_SHARDS: 1
+ RNG_SEED: 0
+ OUTPUT_DIR: .
skp/models/rev_mvit/__init__.py ADDED
File without changes
skp/models/rev_mvit/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (161 Bytes).
 
skp/models/rev_mvit/__pycache__/attention.cpython-39.pyc ADDED
Binary file (10.4 kB).
 
skp/models/rev_mvit/__pycache__/batchnorm_helper.cpython-39.pyc ADDED
Binary file (3.65 kB).
 
skp/models/rev_mvit/__pycache__/common.cpython-39.pyc ADDED
Binary file (4.94 kB).
 
skp/models/rev_mvit/__pycache__/head_helper.cpython-39.pyc ADDED
Binary file (3.46 kB).
 
skp/models/rev_mvit/__pycache__/reversible_mvit.cpython-39.pyc ADDED
Binary file (13.8 kB).
 
skp/models/rev_mvit/__pycache__/stem_helper.cpython-39.pyc ADDED
Binary file (8.37 kB).
 
skp/models/rev_mvit/__pycache__/utils.cpython-39.pyc ADDED
Binary file (5.42 kB).
 
skp/models/rev_mvit/__pycache__/video_model_builder.cpython-39.pyc ADDED
Binary file (10.8 kB).
 
skp/models/rev_mvit/attention.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+
4
+
5
+ import numpy
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.init import trunc_normal_
10
+
11
+ from .common import DropPath, Mlp
12
+
13
+
14
+ def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
15
+ if pool is None:
16
+ return tensor, thw_shape
17
+ tensor_dim = tensor.ndim
18
+ if tensor_dim == 4:
19
+ pass
20
+ elif tensor_dim == 3:
21
+ tensor = tensor.unsqueeze(1)
22
+ else:
23
+ raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
24
+
25
+ if has_cls_embed:
26
+ cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
27
+
28
+ B, N, L, C = tensor.shape
29
+ T, H, W = thw_shape
30
+ tensor = (
31
+ tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
32
+ )
33
+
34
+ tensor = pool(tensor)
35
+
36
+ thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
37
+ L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
38
+ tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
39
+ if has_cls_embed:
40
+ tensor = torch.cat((cls_tok, tensor), dim=2)
41
+ if norm is not None:
42
+ tensor = norm(tensor)
43
+ # Assert tensor_dim in [3, 4]
44
+ if tensor_dim == 4:
45
+ pass
46
+ else: # tensor_dim == 3:
47
+ tensor = tensor.squeeze(1)
48
+ return tensor, thw_shape
49
+
50
+
51
+ def get_rel_pos(rel_pos, d):
52
+ if isinstance(d, int):
53
+ ori_d = rel_pos.shape[0]
54
+ if ori_d == d:
55
+ return rel_pos
56
+ else:
57
+ # Interpolate rel pos.
58
+ new_pos_embed = F.interpolate(
59
+ rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
60
+ size=d,
61
+ mode="linear",
62
+ )
63
+
64
+ return new_pos_embed.reshape(-1, d).permute(1, 0)
65
+
66
+
67
+ def cal_rel_pos_spatial(
68
+ attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w
69
+ ):
70
+ """
71
+ Decomposed Spatial Relative Positional Embeddings.
72
+ """
73
+ sp_idx = 1 if has_cls_embed else 0
74
+ q_t, q_h, q_w = q_shape
75
+ k_t, k_h, k_w = k_shape
76
+ dh = int(2 * max(q_h, k_h) - 1)
77
+ dw = int(2 * max(q_w, k_w) - 1)
78
+
79
+ # Scale up rel pos if shapes for q and k are different.
80
+ q_h_ratio = max(k_h / q_h, 1.0)
81
+ k_h_ratio = max(q_h / k_h, 1.0)
82
+ dist_h = (
83
+ torch.arange(q_h)[:, None] * q_h_ratio
84
+ - torch.arange(k_h)[None, :] * k_h_ratio
85
+ )
86
+ dist_h += (k_h - 1) * k_h_ratio
87
+ q_w_ratio = max(k_w / q_w, 1.0)
88
+ k_w_ratio = max(q_w / k_w, 1.0)
89
+ dist_w = (
90
+ torch.arange(q_w)[:, None] * q_w_ratio
91
+ - torch.arange(k_w)[None, :] * k_w_ratio
92
+ )
93
+ dist_w += (k_w - 1) * k_w_ratio
94
+
95
+ # Intepolate rel pos if needed.
96
+ rel_pos_h = get_rel_pos(rel_pos_h, dh)
97
+ rel_pos_w = get_rel_pos(rel_pos_w, dw)
98
+ Rh = rel_pos_h[dist_h.long()]
99
+ Rw = rel_pos_w[dist_w.long()]
100
+
101
+ B, n_head, q_N, dim = q.shape
102
+
103
+ r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
104
+ rel_h_q = torch.einsum(
105
+ "bythwc,hkc->bythwk", r_q, Rh
106
+ ) # [B, H, q_t, qh, qw, k_h]
107
+ rel_w_q = torch.einsum(
108
+ "bythwc,wkc->bythwk", r_q, Rw
109
+ ) # [B, H, q_t, qh, qw, k_w]
110
+
111
+ attn[:, :, sp_idx:, sp_idx:] = (
112
+ attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
113
+ + rel_h_q[:, :, :, :, :, None, :, None]
114
+ + rel_w_q[:, :, :, :, :, None, None, :]
115
+ ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)
116
+
117
+ return attn
118
+
119
+
120
+ def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
121
+ """
122
+ Temporal Relative Positional Embeddings.
123
+ """
124
+ sp_idx = 1 if has_cls_embed else 0
125
+ q_t, q_h, q_w = q_shape
126
+ k_t, k_h, k_w = k_shape
127
+ dt = int(2 * max(q_t, k_t) - 1)
128
+ # Intepolate rel pos if needed.
129
+ rel_pos_t = get_rel_pos(rel_pos_t, dt)
130
+
131
+ # Scale up rel pos if shapes for q and k are different.
132
+ q_t_ratio = max(k_t / q_t, 1.0)
133
+ k_t_ratio = max(q_t / k_t, 1.0)
134
+ dist_t = (
135
+ torch.arange(q_t)[:, None] * q_t_ratio
136
+ - torch.arange(k_t)[None, :] * k_t_ratio
137
+ )
138
+ dist_t += (k_t - 1) * k_t_ratio
139
+ Rt = rel_pos_t[dist_t.long()]
140
+
141
+ B, n_head, q_N, dim = q.shape
142
+
143
+ r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
144
+ # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
145
+ r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(
146
+ q_t, B * n_head * q_h * q_w, dim
147
+ )
148
+
149
+ # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
150
+ rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
151
+ # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
152
+ rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
153
+
154
+ attn[:, :, sp_idx:, sp_idx:] = (
155
+ attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
156
+ + rel[:, :, :, :, :, :, None, None]
157
+ ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)
158
+
159
+ return attn
160
+
161
+
162
+ class MultiScaleAttention(nn.Module):
163
+ def __init__(
164
+ self,
165
+ dim,
166
+ dim_out,
167
+ input_size,
168
+ num_heads=8,
169
+ qkv_bias=False,
170
+ drop_rate=0.0,
171
+ kernel_q=(1, 1, 1),
172
+ kernel_kv=(1, 1, 1),
173
+ stride_q=(1, 1, 1),
174
+ stride_kv=(1, 1, 1),
175
+ norm_layer=nn.LayerNorm,
176
+ has_cls_embed=True,
177
+ # Options include `conv`, `avg`, and `max`.
178
+ mode="conv",
179
+ # If True, perform pool before projection.
180
+ pool_first=False,
181
+ rel_pos_spatial=False,
182
+ rel_pos_temporal=False,
183
+ rel_pos_zero_init=False,
184
+ residual_pooling=False,
185
+ separate_qkv=False,
186
+ ):
187
+ super().__init__()
188
+ self.pool_first = pool_first
189
+ self.separate_qkv = separate_qkv
190
+ self.drop_rate = drop_rate
191
+ self.num_heads = num_heads
192
+ self.dim_out = dim_out
193
+ head_dim = dim_out // num_heads
194
+ self.scale = head_dim**-0.5
195
+ self.has_cls_embed = has_cls_embed
196
+ self.mode = mode
197
+ padding_q = [int(q // 2) for q in kernel_q]
198
+ padding_kv = [int(kv // 2) for kv in kernel_kv]
199
+
200
+ if pool_first or separate_qkv:
201
+ self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
202
+ self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
203
+ self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
204
+ else:
205
+ self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
206
+
207
+ self.proj = nn.Linear(dim_out, dim_out)
208
+ if drop_rate > 0.0:
209
+ self.proj_drop = nn.Dropout(drop_rate)
210
+
211
+ # Skip pooling with kernel and stride size of (1, 1, 1).
212
+ if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
213
+ kernel_q = ()
214
+ if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
215
+ kernel_kv = ()
216
+
217
+ if mode in ("avg", "max"):
218
+ pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
219
+ self.pool_q = (
220
+ pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
221
+ if len(kernel_q) > 0
222
+ else None
223
+ )
224
+ self.pool_k = (
225
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
226
+ if len(kernel_kv) > 0
227
+ else None
228
+ )
229
+ self.pool_v = (
230
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
231
+ if len(kernel_kv) > 0
232
+ else None
233
+ )
234
+ elif mode == "conv" or mode == "conv_unshared":
235
+ if pool_first:
236
+ dim_conv = dim // num_heads if mode == "conv" else dim
237
+ else:
238
+ dim_conv = dim_out // num_heads if mode == "conv" else dim_out
239
+ self.pool_q = (
240
+ nn.Conv3d(
241
+ dim_conv,
242
+ dim_conv,
243
+ kernel_q,
244
+ stride=stride_q,
245
+ padding=padding_q,
246
+ groups=dim_conv,
247
+ bias=False,
248
+ )
249
+ if len(kernel_q) > 0
250
+ else None
251
+ )
252
+ self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
253
+ self.pool_k = (
254
+ nn.Conv3d(
255
+ dim_conv,
256
+ dim_conv,
257
+ kernel_kv,
258
+ stride=stride_kv,
259
+ padding=padding_kv,
260
+ groups=dim_conv,
261
+ bias=False,
262
+ )
263
+ if len(kernel_kv) > 0
264
+ else None
265
+ )
266
+ self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
267
+ self.pool_v = (
268
+ nn.Conv3d(
269
+ dim_conv,
270
+ dim_conv,
271
+ kernel_kv,
272
+ stride=stride_kv,
273
+ padding=padding_kv,
274
+ groups=dim_conv,
275
+ bias=False,
276
+ )
277
+ if len(kernel_kv) > 0
278
+ else None
279
+ )
280
+ self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
281
+ else:
282
+ raise NotImplementedError(f"Unsupported model {mode}")
283
+
284
+ self.rel_pos_spatial = rel_pos_spatial
285
+ self.rel_pos_temporal = rel_pos_temporal
286
+ if self.rel_pos_spatial:
287
+ assert input_size[1] == input_size[2]
288
+ size = input_size[1]
289
+ q_size = size // stride_q[1] if len(stride_q) > 0 else size
290
+ kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
291
+ rel_sp_dim = 2 * max(q_size, kv_size) - 1
292
+
293
+ self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
294
+ self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
295
+ if not rel_pos_zero_init:
296
+ trunc_normal_(self.rel_pos_h, std=0.02)
297
+ trunc_normal_(self.rel_pos_w, std=0.02)
298
+ if self.rel_pos_temporal:
299
+ self.rel_pos_t = nn.Parameter(
300
+ torch.zeros(2 * input_size[0] - 1, head_dim)
301
+ )
302
+ if not rel_pos_zero_init:
303
+ trunc_normal_(self.rel_pos_t, std=0.02)
304
+
305
+ self.residual_pooling = residual_pooling
306
+
307
+ def forward(self, x, thw_shape):
308
+ B, N, _ = x.shape
309
+
310
+ if self.pool_first:
311
+ if self.mode == "conv_unshared":
312
+ fold_dim = 1
313
+ else:
314
+ fold_dim = self.num_heads
315
+ x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
316
+ q = k = v = x
317
+ else:
318
+ assert self.mode != "conv_unshared"
319
+ if not self.separate_qkv:
320
+ qkv = (
321
+ self.qkv(x)
322
+ .reshape(B, N, 3, self.num_heads, -1)
323
+ .permute(2, 0, 3, 1, 4)
324
+ )
325
+ q, k, v = qkv[0], qkv[1], qkv[2]
326
+ else:
327
+ q = k = v = x
328
+ q = (
329
+ self.q(q)
330
+ .reshape(B, N, self.num_heads, -1)
331
+ .permute(0, 2, 1, 3)
332
+ )
333
+ k = (
334
+ self.k(k)
335
+ .reshape(B, N, self.num_heads, -1)
336
+ .permute(0, 2, 1, 3)
337
+ )
338
+ v = (
339
+ self.v(v)
340
+ .reshape(B, N, self.num_heads, -1)
341
+ .permute(0, 2, 1, 3)
342
+ )
343
+
344
+ q, q_shape = attention_pool(
345
+ q,
346
+ self.pool_q,
347
+ thw_shape,
348
+ has_cls_embed=self.has_cls_embed,
349
+ norm=self.norm_q if hasattr(self, "norm_q") else None,
350
+ )
351
+ k, k_shape = attention_pool(
352
+ k,
353
+ self.pool_k,
354
+ thw_shape,
355
+ has_cls_embed=self.has_cls_embed,
356
+ norm=self.norm_k if hasattr(self, "norm_k") else None,
357
+ )
358
+ v, v_shape = attention_pool(
359
+ v,
360
+ self.pool_v,
361
+ thw_shape,
362
+ has_cls_embed=self.has_cls_embed,
363
+ norm=self.norm_v if hasattr(self, "norm_v") else None,
364
+ )
365
+
366
+ if self.pool_first:
367
+ q_N = (
368
+ numpy.prod(q_shape) + 1
369
+ if self.has_cls_embed
370
+ else numpy.prod(q_shape)
371
+ )
372
+ k_N = (
373
+ numpy.prod(k_shape) + 1
374
+ if self.has_cls_embed
375
+ else numpy.prod(k_shape)
376
+ )
377
+ v_N = (
378
+ numpy.prod(v_shape) + 1
379
+ if self.has_cls_embed
380
+ else numpy.prod(v_shape)
381
+ )
382
+
383
+ q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
384
+ q = (
385
+ self.q(q)
386
+ .reshape(B, q_N, self.num_heads, -1)
387
+ .permute(0, 2, 1, 3)
388
+ )
389
+
390
+ v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
391
+ v = (
392
+ self.v(v)
393
+ .reshape(B, v_N, self.num_heads, -1)
394
+ .permute(0, 2, 1, 3)
395
+ )
396
+
397
+ k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
398
+ k = (
399
+ self.k(k)
400
+ .reshape(B, k_N, self.num_heads, -1)
401
+ .permute(0, 2, 1, 3)
402
+ )
403
+
404
+ N = q.shape[2]
405
+ attn = (q * self.scale) @ k.transpose(-2, -1)
406
+ if self.rel_pos_spatial:
407
+ attn = cal_rel_pos_spatial(
408
+ attn,
409
+ q,
410
+ k,
411
+ self.has_cls_embed,
412
+ q_shape,
413
+ k_shape,
414
+ self.rel_pos_h,
415
+ self.rel_pos_w,
416
+ )
417
+
418
+ if self.rel_pos_temporal:
419
+ attn = cal_rel_pos_temporal(
420
+ attn,
421
+ q,
422
+ self.has_cls_embed,
423
+ q_shape,
424
+ k_shape,
425
+ self.rel_pos_t,
426
+ )
427
+ attn = attn.softmax(dim=-1)
428
+
429
+ x = attn @ v
430
+
431
+ if self.residual_pooling:
432
+ if self.has_cls_embed:
433
+ x[:, :, 1:, :] += q[:, :, 1:, :]
434
+ else:
435
+ x = x + q
436
+
437
+ x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
438
+ x = self.proj(x)
439
+
440
+ if self.drop_rate > 0.0:
441
+ x = self.proj_drop(x)
442
+ return x, q_shape
443
+
444
+
445
+ class MultiScaleBlock(nn.Module):
446
+ def __init__(
447
+ self,
448
+ dim,
449
+ dim_out,
450
+ num_heads,
451
+ input_size,
452
+ mlp_ratio=4.0,
453
+ qkv_bias=False,
454
+ qk_scale=None,
455
+ drop_rate=0.0,
456
+ drop_path=0.0,
457
+ layer_scale_init_value=0.0,
458
+ act_layer=nn.GELU,
459
+ norm_layer=nn.LayerNorm,
460
+ up_rate=None,
461
+ kernel_q=(1, 1, 1),
462
+ kernel_kv=(1, 1, 1),
463
+ stride_q=(1, 1, 1),
464
+ stride_kv=(1, 1, 1),
465
+ mode="conv",
466
+ has_cls_embed=True,
467
+ pool_first=False,
468
+ rel_pos_spatial=False,
469
+ rel_pos_temporal=False,
470
+ rel_pos_zero_init=False,
471
+ residual_pooling=False,
472
+ dim_mul_in_att=False,
473
+ separate_qkv=False,
474
+ ):
475
+ super().__init__()
476
+ self.dim = dim
477
+ self.dim_out = dim_out
478
+ self.norm1 = norm_layer(dim)
479
+ self.dim_mul_in_att = dim_mul_in_att
480
+ kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
481
+ stride_skip = stride_q
482
+ padding_skip = [int(skip // 2) for skip in kernel_skip]
483
+ att_dim = dim_out if dim_mul_in_att else dim
484
+ self.attn = MultiScaleAttention(
485
+ dim,
486
+ att_dim,
487
+ num_heads=num_heads,
488
+ input_size=input_size,
489
+ qkv_bias=qkv_bias,
490
+ drop_rate=drop_rate,
491
+ kernel_q=kernel_q,
492
+ kernel_kv=kernel_kv,
493
+ stride_q=stride_q,
494
+ stride_kv=stride_kv,
495
+ norm_layer=norm_layer,
496
+ has_cls_embed=has_cls_embed,
497
+ mode=mode,
498
+ pool_first=pool_first,
499
+ rel_pos_spatial=rel_pos_spatial,
500
+ rel_pos_temporal=rel_pos_temporal,
501
+ rel_pos_zero_init=rel_pos_zero_init,
502
+ residual_pooling=residual_pooling,
503
+ separate_qkv=separate_qkv,
504
+ )
505
+ self.drop_path = (
506
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
507
+ )
508
+ self.norm2 = norm_layer(att_dim)
509
+ mlp_hidden_dim = int(att_dim * mlp_ratio)
510
+ self.has_cls_embed = has_cls_embed
511
+ # TODO: check the use case for up_rate, and merge the following lines
512
+ if up_rate is not None and up_rate > 1:
513
+ mlp_dim_out = dim * up_rate
514
+ else:
515
+ mlp_dim_out = dim_out
516
+ self.mlp = Mlp(
517
+ in_features=att_dim,
518
+ hidden_features=mlp_hidden_dim,
519
+ out_features=mlp_dim_out,
520
+ act_layer=act_layer,
521
+ drop_rate=drop_rate,
522
+ )
523
+ if layer_scale_init_value > 0:
524
+ self.gamma_1 = nn.Parameter(
525
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
526
+ )
527
+ self.gamma_2 = nn.Parameter(
528
+ layer_scale_init_value * torch.ones((dim_out)),
529
+ requires_grad=True,
530
+ )
531
+ else:
532
+ self.gamma_1, self.gamma_2 = None, None
533
+
534
+ if dim != dim_out:
535
+ self.proj = nn.Linear(dim, dim_out)
536
+
537
+ self.pool_skip = (
538
+ nn.MaxPool3d(
539
+ kernel_skip, stride_skip, padding_skip, ceil_mode=False
540
+ )
541
+ if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
542
+ else None
543
+ )
544
+
545
+ def forward(self, x, thw_shape=None):
546
+ x_norm = self.norm1(x)
547
+ x_block, thw_shape_new = self.attn(x_norm, thw_shape)
548
+ if self.dim_mul_in_att and self.dim != self.dim_out:
549
+ x = self.proj(x_norm)
550
+ x_res, _ = attention_pool(
551
+ x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed
552
+ )
553
+ if self.gamma_1 is not None:
554
+ x = x_res + self.drop_path(self.gamma_1 * x_block)
555
+ else:
556
+ x = x_res + self.drop_path(x_block)
557
+ x_norm = self.norm2(x)
558
+ x_mlp = self.mlp(x_norm)
559
+ if not self.dim_mul_in_att and self.dim != self.dim_out:
560
+ x = self.proj(x_norm)
561
+ if self.gamma_2 is not None:
562
+ x = x + self.drop_path(self.gamma_2 * x_mlp)
563
+ else:
564
+ x = x + self.drop_path(x_mlp)
565
+ if thw_shape:
566
+ return x, thw_shape_new
567
+ else:
568
+ return x
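
Not part of the diff above, but as a quick sanity check: a minimal sketch of driving MultiScaleBlock on dummy tokens. The import path mirrors the repo layout; the batch size, token grid (T, H, W) and channel width are made-up values, and the pooling / relative-position options are left at their defaults.

import torch
from skp.models.rev_mvit.attention import MultiScaleBlock

B, (T, H, W), C = 2, (4, 8, 8), 96
block = MultiScaleBlock(
    dim=C,
    dim_out=C,
    num_heads=4,
    input_size=[T, H, W],
    has_cls_embed=False,  # the dummy input below carries no [CLS] token
)
x = torch.randn(B, T * H * W, C)   # tokens flattened over (T, H, W)
y, thw = block(x, [T, H, W])
print(y.shape, thw)                # torch.Size([2, 256, 96]) [4, 8, 8]
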
skp/models/rev_mvit/batchnorm_helper.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+
4
+ """BatchNorm (BN) utility functions and custom batch-size BN implementations"""
5
+
6
+ from functools import partial
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from pytorchvideo.layers.batch_norm import (
11
+ NaiveSyncBatchNorm1d,
12
+ NaiveSyncBatchNorm3d,
13
+ ) # noqa
14
+
15
+
16
+ def get_norm(cfg):
17
+ """
18
+ Args:
19
+ cfg (CfgNode): model building configs, details are in the comments of
20
+ the config file.
21
+ Returns:
22
+ nn.Module: the normalization layer.
23
+ """
24
+ if cfg.BN.NORM_TYPE in {"batchnorm", "sync_batchnorm_apex"}:
25
+ return nn.BatchNorm3d
26
+ elif cfg.BN.NORM_TYPE == "sub_batchnorm":
27
+ return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
28
+ elif cfg.BN.NORM_TYPE == "sync_batchnorm":
29
+ return partial(
30
+ NaiveSyncBatchNorm3d,
31
+ num_sync_devices=cfg.BN.NUM_SYNC_DEVICES,
32
+ global_sync=cfg.BN.GLOBAL_SYNC,
33
+ )
34
+ else:
35
+ raise NotImplementedError(
36
+ "Norm type {} is not supported".format(cfg.BN.NORM_TYPE)
37
+ )
38
+
39
+
40
+ class SubBatchNorm3d(nn.Module):
41
+ """
42
+ The standard BN layer computes stats across all examples in a GPU. In some
43
+ cases it is desirable to compute stats across only a subset of examples
44
+ (e.g., in multigrid training https://arxiv.org/abs/1912.00998).
45
+ SubBatchNorm3d splits the batch dimension into N splits, and run BN on
46
+ each of them separately (so that the stats are computed on each subset of
47
+ examples (1/N of batch) independently. During evaluation, it aggregates
48
+ the stats from all splits into one BN.
49
+ """
50
+
51
+ def __init__(self, num_splits, **args):
52
+ """
53
+ Args:
54
+ num_splits (int): number of splits.
55
+ args (list): other arguments.
56
+ """
57
+ super(SubBatchNorm3d, self).__init__()
58
+ self.num_splits = num_splits
59
+ num_features = args["num_features"]
60
+ # Keep only one set of weight and bias.
61
+ if args.get("affine", True):
62
+ self.affine = True
63
+ args["affine"] = False
64
+ self.weight = torch.nn.Parameter(torch.ones(num_features))
65
+ self.bias = torch.nn.Parameter(torch.zeros(num_features))
66
+ else:
67
+ self.affine = False
68
+ self.bn = nn.BatchNorm3d(**args)
69
+ args["num_features"] = num_features * num_splits
70
+ self.split_bn = nn.BatchNorm3d(**args)
71
+
72
+ def _get_aggregated_mean_std(self, means, stds, n):
73
+ """
74
+ Calculate the aggregated mean and stds.
75
+ Args:
76
+ means (tensor): mean values.
77
+ stds (tensor): standard deviations.
78
+ n (int): number of sets of means and stds.
79
+ """
80
+ mean = means.view(n, -1).sum(0) / n
81
+ std = (
82
+ stds.view(n, -1).sum(0) / n
83
+ + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n
84
+ )
85
+ return mean.detach(), std.detach()
86
+
87
+ def aggregate_stats(self):
88
+ """
89
+ Synchronize running_mean, and running_var. Call this before eval.
90
+ """
91
+ if self.split_bn.track_running_stats:
92
+ (
93
+ self.bn.running_mean.data,
94
+ self.bn.running_var.data,
95
+ ) = self._get_aggregated_mean_std(
96
+ self.split_bn.running_mean,
97
+ self.split_bn.running_var,
98
+ self.num_splits,
99
+ )
100
+
101
+ def forward(self, x):
102
+ if self.training:
103
+ n, c, t, h, w = x.shape
104
+ x = x.view(n // self.num_splits, c * self.num_splits, t, h, w)
105
+ x = self.split_bn(x)
106
+ x = x.view(n, c, t, h, w)
107
+ else:
108
+ x = self.bn(x)
109
+ if self.affine:
110
+ x = x * self.weight.view((-1, 1, 1, 1))
111
+ x = x + self.bias.view((-1, 1, 1, 1))
112
+ return x
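
A hedged usage sketch (not in the commit): SubBatchNorm3d tracks statistics per split during training, and aggregate_stats() is expected to be called before switching to eval so the merged statistics are used. The shapes and num_splits below are illustrative only.

import torch
from skp.models.rev_mvit.batchnorm_helper import SubBatchNorm3d

bn = SubBatchNorm3d(num_splits=2, num_features=8)
x = torch.randn(4, 8, 2, 6, 6)   # (N, C, T, H, W); N must be divisible by num_splits
bn.train()
_ = bn(x)                        # stats accumulated per split in bn.split_bn
bn.aggregate_stats()             # merge split stats into bn.bn before eval
bn.eval()
y = bn(x)                        # eval path uses the aggregated running stats
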
skp/models/rev_mvit/common.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class Mlp(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_features,
11
+ hidden_features=None,
12
+ out_features=None,
13
+ act_layer=nn.GELU,
14
+ drop_rate=0.0,
15
+ ):
16
+ super().__init__()
17
+ self.drop_rate = drop_rate
18
+ out_features = out_features or in_features
19
+ hidden_features = hidden_features or in_features
20
+ self.fc1 = nn.Linear(in_features, hidden_features)
21
+ self.act = act_layer()
22
+ self.fc2 = nn.Linear(hidden_features, out_features)
23
+ if self.drop_rate > 0.0:
24
+ self.drop = nn.Dropout(drop_rate)
25
+
26
+ def forward(self, x):
27
+ x = self.fc1(x)
28
+ x = self.act(x)
29
+ if self.drop_rate > 0.0:
30
+ x = self.drop(x)
31
+ x = self.fc2(x)
32
+ if self.drop_rate > 0.0:
33
+ x = self.drop(x)
34
+ return x
35
+
36
+
37
+ class Permute(nn.Module):
38
+ def __init__(self, dims):
39
+ super().__init__()
40
+ self.dims = dims
41
+
42
+ def forward(self, x):
43
+ return x.permute(*self.dims)
44
+
45
+
46
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
47
+ """
48
+ Stochastic Depth per sample.
49
+ """
50
+ if drop_prob == 0.0 or not training:
51
+ return x
52
+ keep_prob = 1 - drop_prob
53
+ shape = (x.shape[0],) + (1,) * (
54
+ x.ndim - 1
55
+ ) # work with diff dim tensors, not just 2D ConvNets
56
+ mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
57
+ mask.floor_() # binarize
58
+ output = x.div(keep_prob) * mask
59
+ return output
60
+
61
+
62
+ class DropPath(nn.Module):
63
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
64
+
65
+ def __init__(self, drop_prob=None):
66
+ super(DropPath, self).__init__()
67
+ self.drop_prob = drop_prob
68
+
69
+ def forward(self, x):
70
+ return drop_path(x, self.drop_prob, self.training)
71
+
72
+
73
+ class TwoStreamFusion(nn.Module):
74
+ def __init__(self, mode, dim=None, kernel=3, padding=1):
75
+ """
76
+ A general constructor for neural modules fusing two equal-sized tensors
77
+ in forward. The following options are supported:
78
+
79
+ "add" / "max" / "min" / "avg" : respective operations on the two halves.
80
+ "concat" : NOOP.
81
+ "concat_linear_{dim_mult}_{drop_rate}" : MLP to fuse with hidden dim "dim_mult"
82
+ (optional, def 1.) higher than input dim
83
+ with optional dropout "drop_rate" (def: 0.)
84
+ "ln+concat_linear_{dim_mult}_{drop_rate}" : perform MLP after layernorm on the input.
85
+
86
+ """
87
+ super().__init__()
88
+ self.mode = mode
89
+ if mode == "add":
90
+ self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).sum(
91
+ dim=0
92
+ )
93
+ elif mode == "max":
94
+ self.fuse_fn = (
95
+ lambda x: torch.stack(torch.chunk(x, 2, dim=2))
96
+ .max(dim=0)
97
+ .values
98
+ )
99
+ elif mode == "min":
100
+ self.fuse_fn = (
101
+ lambda x: torch.stack(torch.chunk(x, 2, dim=2))
102
+ .min(dim=0)
103
+ .values
104
+ )
105
+ elif mode == "avg":
106
+ self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).mean(
107
+ dim=0
108
+ )
109
+ elif mode == "concat":
110
+ # x itself is the channel concat version
111
+ self.fuse_fn = lambda x: x
112
+ elif "concat_linear" in mode:
113
+ if len(mode.split("_")) == 2:
114
+ dim_mult = 1.0
115
+ drop_rate = 0.0
116
+ elif len(mode.split("_")) == 3:
117
+ dim_mult = float(mode.split("_")[-1])
118
+ drop_rate = 0.0
119
+
120
+ elif len(mode.split("_")) == 4:
121
+ dim_mult = float(mode.split("_")[-2])
122
+ drop_rate = float(mode.split("_")[-1])
123
+ else:
124
+ raise NotImplementedError
125
+
126
+ if mode.split("+")[0] == "ln":
127
+ self.fuse_fn = nn.Sequential(
128
+ nn.LayerNorm(dim),
129
+ Mlp(
130
+ in_features=dim,
131
+ hidden_features=int(dim * dim_mult),
132
+ act_layer=nn.GELU,
133
+ out_features=dim,
134
+ drop_rate=drop_rate,
135
+ ),
136
+ )
137
+ else:
138
+ self.fuse_fn = Mlp(
139
+ in_features=dim,
140
+ hidden_features=int(dim * dim_mult),
141
+ act_layer=nn.GELU,
142
+ out_features=dim,
143
+ drop_rate=drop_rate,
144
+ )
145
+
146
+ else:
147
+ raise NotImplementedError
148
+
149
+ def forward(self, x):
150
+ if "concat_linear" in self.mode:
151
+ return self.fuse_fn(x) + x
152
+
153
+ else:
154
+ return self.fuse_fn(x)
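
Two small sketches of the helpers above, using toy shapes (not in the commit): TwoStreamFusion expects the two streams concatenated along the last dim (the channel dim for (B, N, C) inputs) and reduces them back to one, while DropPath is an identity outside of training.

import torch
from skp.models.rev_mvit.common import TwoStreamFusion, DropPath

a, b = torch.randn(2, 10, 64), torch.randn(2, 10, 64)
fuse = TwoStreamFusion("avg")
out = fuse(torch.cat([a, b], dim=-1))   # chunked in half on dim=2, then averaged
assert torch.allclose(out, (a + b) / 2)

dp = DropPath(drop_prob=0.1)
dp.eval()
x = torch.randn(2, 10, 64)
assert torch.equal(dp(x), x)            # no-op when not training
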
skp/models/rev_mvit/head_helper.py ADDED
@@ -0,0 +1,140 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+
4
+ """ResNe(X)t Head helper."""
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .batchnorm_helper import (
10
+ NaiveSyncBatchNorm1d as NaiveSyncBatchNorm1d,
11
+ )
12
+
13
+
14
+ class MLPHead(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim_in,
18
+ dim_out,
19
+ mlp_dim,
20
+ num_layers,
21
+ bn_on=False,
22
+ bias=True,
23
+ flatten=False,
24
+ xavier_init=True,
25
+ bn_sync_num=1,
26
+ global_sync=False,
27
+ ):
28
+ super(MLPHead, self).__init__()
29
+ self.flatten = flatten
30
+ b = False if bn_on else bias
31
+ # assert bn_on or bn_sync_num=1
32
+ mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)]
33
+ mlp_layers[-1].xavier_init = xavier_init
34
+ for i in range(1, num_layers):
35
+ if bn_on:
36
+ if global_sync or bn_sync_num > 1:
37
+ mlp_layers.append(
38
+ NaiveSyncBatchNorm1d(
39
+ num_sync_devices=bn_sync_num,
40
+ global_sync=global_sync,
41
+ num_features=mlp_dim,
42
+ )
43
+ )
44
+ else:
45
+ mlp_layers.append(nn.BatchNorm1d(num_features=mlp_dim))
46
+ mlp_layers.append(nn.ReLU(inplace=True))
47
+ if i == num_layers - 1:
48
+ d = dim_out
49
+ b = bias
50
+ else:
51
+ d = mlp_dim
52
+ mlp_layers.append(nn.Linear(mlp_dim, d, bias=b))
53
+ mlp_layers[-1].xavier_init = xavier_init
54
+ self.projection = nn.Sequential(*mlp_layers)
55
+
56
+ def forward(self, x):
57
+ if x.ndim == 5:
58
+ x = x.permute((0, 2, 3, 4, 1))
59
+ if self.flatten:
60
+ x = x.reshape(-1, x.shape[-1])
61
+
62
+ return self.projection(x)
63
+
64
+
65
+ class TransformerBasicHead(nn.Module):
66
+ """
67
+ BasicHead. No pool.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ dim_in,
73
+ num_classes,
74
+ dropout_rate=0.0,
75
+ act_func="softmax",
76
+ cfg=None,
77
+ ):
78
+ """
79
+ Perform linear projection and activation as head for tranformers.
80
+ Args:
81
+ dim_in (int): the channel dimension of the input to the head.
82
+ num_classes (int): the channel dimensions of the output to the head.
83
+ dropout_rate (float): dropout rate. If equal to 0.0, perform no
84
+ dropout.
85
+ act_func (string): activation function to use. 'softmax': applies
86
+ softmax on the output. 'sigmoid': applies sigmoid on the output.
87
+ """
88
+ super(TransformerBasicHead, self).__init__()
89
+ if dropout_rate > 0.0:
90
+ self.dropout = nn.Dropout(dropout_rate)
91
+ self.projection = nn.Linear(dim_in, num_classes, bias=True)
92
+
93
+ if cfg.CONTRASTIVE.NUM_MLP_LAYERS == 1:
94
+ self.projection = nn.Linear(dim_in, num_classes, bias=True)
95
+ else:
96
+ self.projection = MLPHead(
97
+ dim_in,
98
+ num_classes,
99
+ cfg.CONTRASTIVE.MLP_DIM,
100
+ cfg.CONTRASTIVE.NUM_MLP_LAYERS,
101
+ bn_on=cfg.CONTRASTIVE.BN_MLP,
102
+ bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
103
+ if cfg.CONTRASTIVE.BN_SYNC_MLP
104
+ else 1,
105
+ global_sync=(
106
+ cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC
107
+ ),
108
+ )
109
+ self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC
110
+
111
+ # Softmax for evaluation and testing.
112
+ if act_func == "softmax":
113
+ self.act = nn.Softmax(dim=1)
114
+ elif act_func == "sigmoid":
115
+ self.act = nn.Sigmoid()
116
+ elif act_func == "none":
117
+ self.act = None
118
+ else:
119
+ raise NotImplementedError(
120
+ "{} is not supported as an activation"
121
+ "function.".format(act_func)
122
+ )
123
+
124
+ def forward(self, x):
125
+ if hasattr(self, "dropout"):
126
+ x = self.dropout(x)
127
+ if self.detach_final_fc:
128
+ x = x.detach()
129
+ x = self.projection(x)
130
+
131
+ if not self.training:
132
+ if self.act is not None:
133
+ x = self.act(x)
134
+ # Performs fully convolutional inference.
135
+ if x.ndim == 5 and x.shape[1:4] > torch.Size([1, 1, 1]):
136
+ x = x.mean([1, 2, 3])
137
+
138
+ x = x.view(x.shape[0], -1)
139
+
140
+ return x
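
A minimal sketch of MLPHead on pooled features (not in the commit); dim_in, dim_out and mlp_dim are placeholder values. TransformerBasicHead is omitted here since it needs a full cfg object.

import torch
from skp.models.rev_mvit.head_helper import MLPHead

head = MLPHead(dim_in=768, dim_out=400, mlp_dim=2048, num_layers=2, bn_on=False)
feats = torch.randn(8, 768)   # (batch, dim_in), e.g. pooled transformer features
logits = head(feats)
print(logits.shape)           # torch.Size([8, 400])
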
skp/models/rev_mvit/reversible_mvit.py ADDED
@@ -0,0 +1,696 @@
1
+ import sys
2
+ from functools import partial
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function as Function
6
+
7
+ from .attention import MultiScaleAttention, attention_pool
8
+ from .common import Mlp, TwoStreamFusion, drop_path
9
+ from .utils import round_width
10
+
11
+
12
+ class ReversibleMViT(nn.Module):
13
+ """
14
+ Reversible model builder. This builds the reversible transformer encoder
15
+ and allows reversible training.
16
+
17
+ Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong,
18
+ Christoph Feichtenhofer, Jitendra Malik
19
+ "Reversible Vision Transformers"
20
+
21
+ https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf
22
+ """
23
+
24
+ def __init__(self, config, model):
25
+ """
26
+ The `__init__` method of any subclass should also contain these
27
+ arguments.
28
+ Args:
29
+ cfg (CfgNode): model building configs, details are in the
30
+ comments of the config file.
31
+ model (nn.Module): parent MViT module this module forms
32
+ a reversible encoder in.
33
+ """
34
+
35
+ super().__init__()
36
+ self.cfg = config
37
+
38
+ embed_dim = self.cfg.MVIT.EMBED_DIM
39
+ depth = self.cfg.MVIT.DEPTH
40
+ num_heads = self.cfg.MVIT.NUM_HEADS
41
+ mlp_ratio = self.cfg.MVIT.MLP_RATIO
42
+ qkv_bias = self.cfg.MVIT.QKV_BIAS
43
+
44
+ drop_path_rate = self.cfg.MVIT.DROPPATH_RATE
45
+ self.dropout = config.MVIT.DROPOUT_RATE
46
+ self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION
47
+ dpr = [
48
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
49
+ ] # stochastic depth decay rule
50
+
51
+ input_size = model.patch_dims
52
+
53
+ self.layers = nn.ModuleList([])
54
+ self.no_custom_backward = False
55
+
56
+ if self.cfg.MVIT.NORM == "layernorm":
57
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
58
+ else:
59
+ raise NotImplementedError("Only supports layernorm.")
60
+
61
+ dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
62
+ for i in range(len(self.cfg.MVIT.DIM_MUL)):
63
+ dim_mul[self.cfg.MVIT.DIM_MUL[i][0]] = self.cfg.MVIT.DIM_MUL[i][1]
64
+ for i in range(len(self.cfg.MVIT.HEAD_MUL)):
65
+ head_mul[self.cfg.MVIT.HEAD_MUL[i][0]] = self.cfg.MVIT.HEAD_MUL[i][
66
+ 1
67
+ ]
68
+
69
+ pool_q = model.pool_q
70
+ pool_kv = model.pool_kv
71
+ stride_q = model.stride_q
72
+ stride_kv = model.stride_kv
73
+
74
+ for i in range(depth):
75
+
76
+ num_heads = round_width(num_heads, head_mul[i])
77
+
78
+ # Upsampling inside the MHPA, input to the Q-pooling block is lower C dimension
79
+ # This localizes the feature changes in a single block, making more computation reversible.
80
+ embed_dim = round_width(
81
+ embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads
82
+ )
83
+ dim_out = round_width(
84
+ embed_dim,
85
+ dim_mul[i],
86
+ divisor=round_width(num_heads, head_mul[i + 1]),
87
+ )
88
+
89
+ if i in self.cfg.MVIT.REV.BUFFER_LAYERS:
90
+ layer_type = StageTransitionBlock
91
+ input_mult = 2 if "concat" in self.pre_q_fusion else 1
92
+ else:
93
+ layer_type = ReversibleBlock
94
+ input_mult = 1
95
+
96
+ dimout_correction = (
97
+ 2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1
98
+ )
99
+
100
+ self.layers.append(
101
+ layer_type(
102
+ dim=embed_dim
103
+ * input_mult, # added only for concat fusion before Qpooling layers
104
+ input_size=input_size,
105
+ dim_out=dim_out * input_mult // dimout_correction,
106
+ num_heads=num_heads,
107
+ cfg=self.cfg,
108
+ mlp_ratio=mlp_ratio,
109
+ qkv_bias=qkv_bias,
110
+ drop_path=dpr[i],
111
+ norm_layer=norm_layer,
112
+ kernel_q=pool_q[i] if len(pool_q) > i else [],
113
+ kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
114
+ stride_q=stride_q[i] if len(stride_q) > i else [],
115
+ stride_kv=stride_kv[i] if len(stride_kv) > i else [],
116
+ layer_id=i,
117
+ pre_q_fusion=self.pre_q_fusion,
118
+ )
119
+ )
120
+ # F is the attention block
121
+ self.layers[-1].F.thw = input_size
122
+
123
+ if len(stride_q[i]) > 0:
124
+ input_size = [
125
+ size // stride
126
+ for size, stride in zip(input_size, stride_q[i])
127
+ ]
128
+
129
+ embed_dim = dim_out
130
+
131
+ @staticmethod
132
+ def vanilla_backward(h, layers, buffer):
133
+ """
134
+ Using rev layers without rev backpropagation. Debugging purposes only.
135
+ Activated with self.no_custom_backward.
136
+ """
137
+
138
+ # split into hidden states (h) and attention_output (a)
139
+ h, a = torch.chunk(h, 2, dim=-1)
140
+ for _, layer in enumerate(layers):
141
+ a, h = layer(a, h)
142
+
143
+ return torch.cat([a, h], dim=-1)
144
+
145
+ def forward(self, x):
146
+
147
+ # process the layers in a reversible stack and an irreversible stack.
148
+ stack = []
149
+ for l_i in range(len(self.layers)):
150
+ if isinstance(self.layers[l_i], StageTransitionBlock):
151
+ stack.append(("StageTransition", l_i))
152
+ else:
153
+ if len(stack) == 0 or stack[-1][0] == "StageTransition":
154
+ stack.append(("Reversible", []))
155
+ stack[-1][1].append(l_i)
156
+
157
+ for layer_seq in stack:
158
+
159
+ if layer_seq[0] == "StageTransition":
160
+ x = self.layers[layer_seq[1]](x)
161
+
162
+ else:
163
+ x = torch.cat([x, x], dim=-1)
164
+
165
+ # no need for custom backprop in eval/model stat log
166
+ if not self.training or self.no_custom_backward:
167
+ executing_fn = ReversibleMViT.vanilla_backward
168
+ else:
169
+ executing_fn = RevBackProp.apply
170
+
171
+ x = executing_fn(
172
+ x,
173
+ self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1],
174
+ [], # buffer activations
175
+ )
176
+
177
+ # Apply dropout
178
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
179
+
180
+ return x
181
+
182
+
183
+ class RevBackProp(Function):
184
+ """
185
+ Custom backpropagation function to allow (A) flushing memory in forward
186
+ and (B) activation recomputation reversibly in backward for gradient calculation.
187
+
188
+ Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
189
+ """
190
+
191
+ @staticmethod
192
+ def forward(
193
+ ctx,
194
+ x,
195
+ layers,
196
+ buffer_layers, # List of layer ids for int activation to buffer
197
+ ):
198
+ """
199
+ Reversible Forward pass. Any intermediate activations from `buffer_layers` are
200
+ cached in ctx for the backward pass. This is not necessary for standard use cases.
201
+ Each reversible layer implements its own forward pass logic.
202
+ """
203
+ buffer_layers.sort()
204
+
205
+ X_1, X_2 = torch.chunk(x, 2, dim=-1)
206
+
207
+ intermediate = []
208
+
209
+ for layer in layers:
210
+
211
+ X_1, X_2 = layer(X_1, X_2)
212
+
213
+ if layer.layer_id in buffer_layers:
214
+ intermediate.extend([X_1.detach(), X_2.detach()])
215
+
216
+ if len(buffer_layers) == 0:
217
+ all_tensors = [X_1.detach(), X_2.detach()]
218
+ else:
219
+ intermediate = [torch.LongTensor(buffer_layers), *intermediate]
220
+ all_tensors = [X_1.detach(), X_2.detach(), *intermediate]
221
+
222
+ ctx.save_for_backward(*all_tensors)
223
+ ctx.layers = layers
224
+
225
+ return torch.cat([X_1, X_2], dim=-1)
226
+
227
+ @staticmethod
228
+ def backward(ctx, dx):
229
+ """
230
+ Reversible Backward pass. Any intermediate activations from `buffer_layers` are
231
+ recovered from ctx. Each layer implements its own logic for the backward pass (both
232
+ activation recomputation and grad calculation).
233
+ """
234
+ dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)
235
+
236
+ # retrieve params from ctx for backward
237
+ X_1, X_2, *int_tensors = ctx.saved_tensors
238
+
239
+ # no buffering
240
+ if len(int_tensors) != 0:
241
+ buffer_layers = int_tensors[0].tolist()
242
+
243
+ else:
244
+ buffer_layers = []
245
+
246
+ layers = ctx.layers
247
+
248
+ for _, layer in enumerate(layers[::-1]):
249
+
250
+ if layer.layer_id in buffer_layers:
251
+
252
+ X_1, X_2, dX_1, dX_2 = layer.backward_pass(
253
+ Y_1=int_tensors[
254
+ buffer_layers.index(layer.layer_id) * 2 + 1
255
+ ],
256
+ Y_2=int_tensors[
257
+ buffer_layers.index(layer.layer_id) * 2 + 2
258
+ ],
259
+ dY_1=dX_1,
260
+ dY_2=dX_2,
261
+ )
262
+
263
+ else:
264
+
265
+ X_1, X_2, dX_1, dX_2 = layer.backward_pass(
266
+ Y_1=X_1,
267
+ Y_2=X_2,
268
+ dY_1=dX_1,
269
+ dY_2=dX_2,
270
+ )
271
+
272
+ dx = torch.cat([dX_1, dX_2], dim=-1)
273
+
274
+ del int_tensors
275
+ del dX_1, dX_2, X_1, X_2
276
+
277
+ return dx, None, None
278
+
279
+
280
+ class StageTransitionBlock(nn.Module):
281
+ """
282
+ Blocks for changing the feature dimensions in MViT (using Q-pooling).
283
+ See Section 3.3.1 in paper for details.
284
+ """
285
+
286
+ def __init__(
287
+ self,
288
+ dim,
289
+ input_size,
290
+ dim_out,
291
+ num_heads,
292
+ mlp_ratio,
293
+ qkv_bias,
294
+ drop_path,
295
+ kernel_q,
296
+ kernel_kv,
297
+ stride_q,
298
+ stride_kv,
299
+ cfg,
300
+ norm_layer=nn.LayerNorm,
301
+ pre_q_fusion=None,
302
+ layer_id=0,
303
+ ):
304
+ """
305
+ Uses the same structure of F and G functions as Reversible Block except
306
+ without using reversible forward (and backward) pass.
307
+ """
308
+ super().__init__()
309
+
310
+ self.drop_path_rate = drop_path
311
+
312
+ embed_dim = dim
313
+
314
+ self.F = AttentionSubBlock(
315
+ dim=embed_dim,
316
+ input_size=input_size,
317
+ num_heads=num_heads,
318
+ cfg=cfg,
319
+ dim_out=dim_out,
320
+ kernel_q=kernel_q,
321
+ kernel_kv=kernel_kv,
322
+ stride_q=stride_q,
323
+ stride_kv=stride_kv,
324
+ norm_layer=norm_layer,
325
+ )
326
+
327
+ self.G = MLPSubblock(
328
+ dim=dim_out,
329
+ mlp_ratio=mlp_ratio,
330
+ norm_layer=norm_layer,
331
+ )
332
+
333
+ self.layer_id = layer_id
334
+
335
+ self.is_proj = False
336
+ self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON
337
+
338
+ self.is_conv = False
339
+ self.pool_first = cfg.MVIT.POOL_FIRST
340
+ self.mode = cfg.MVIT.MODE
341
+ self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim)
342
+
343
+ if cfg.MVIT.REV.RES_PATH == "max":
344
+ self.res_conv = False
345
+ self.pool_skip = nn.MaxPool3d(
346
+ # self.attention.attn.pool_q.kernel_size,
347
+ [s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride],
348
+ self.F.attn.pool_q.stride,
349
+ [int(k // 2) for k in self.F.attn.pool_q.stride],
350
+ # self.attention.attn.pool_q.padding,
351
+ ceil_mode=False,
352
+ )
353
+
354
+ elif cfg.MVIT.REV.RES_PATH == "conv":
355
+ self.res_conv = True
356
+ else:
357
+ raise NotImplementedError
358
+
359
+ # Add a linear projection in residual branch
360
+ if embed_dim != dim_out:
361
+ self.is_proj = True
362
+ self.res_proj = nn.Linear(embed_dim, dim_out, bias=True)
363
+
364
+ def forward(
365
+ self,
366
+ x,
367
+ ):
368
+ """
369
+ Forward logic is similar to MultiScaleBlock with Q-pooling.
370
+ """
371
+ x = self.pre_q_fuse(x)
372
+
373
+ # fork tensor for residual connections
374
+ x_res = x
375
+
376
+ # This uses conv to pool the residual hidden features
377
+ # but done before pooling only if not pool_first
378
+ if self.is_proj and not self.pool_first:
379
+ x_res = self.res_proj(x_res)
380
+
381
+ if self.res_conv:
382
+
383
+ # Pooling the hidden features with the same conv as Q
384
+ N, L, C = x_res.shape
385
+
386
+ # This handling is the same as that of q in MultiScaleAttention
387
+ if self.mode == "conv_unshared":
388
+ fold_dim = 1
389
+ else:
390
+ fold_dim = self.F.attn.num_heads
391
+
392
+ # Output is (B, N, L, C)
393
+ x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute(
394
+ 0, 2, 1, 3
395
+ )
396
+
397
+ x_res, _ = attention_pool(
398
+ x_res,
399
+ self.F.attn.pool_q,
400
+ # thw_shape = self.attention.attn.thw,
401
+ thw_shape=self.F.thw,
402
+ has_cls_embed=self.has_cls_embed,
403
+ norm=self.F.attn.norm_q
404
+ if hasattr(self.F.attn, "norm_q")
405
+ else None,
406
+ )
407
+ x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C)
408
+
409
+ else:
410
+ # Pooling the hidden features with max op
411
+ x_res, _ = attention_pool(
412
+ x_res,
413
+ self.pool_skip,
414
+ thw_shape=self.F.attn.thw,
415
+ has_cls_embed=self.has_cls_embed,
416
+ )
417
+
418
+ # If pool_first then project to higher dim now
419
+ if self.is_proj and self.pool_first:
420
+ x_res = self.res_proj(x_res)
421
+
422
+ x = self.F(x)
423
+ x = x_res + x
424
+ x = x + self.G(x)
425
+
426
+ x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training)
427
+
428
+ return x
429
+
430
+
431
+ class ReversibleBlock(nn.Module):
432
+ """
433
+ Reversible Blocks for Reversible Vision Transformer and also
434
+ for state-preserving blocks in Reversible MViT. See Section
435
+ 3.3.2 in paper for details.
436
+ """
437
+
438
+ def __init__(
439
+ self,
440
+ dim,
441
+ input_size,
442
+ dim_out,
443
+ num_heads,
444
+ mlp_ratio,
445
+ qkv_bias,
446
+ drop_path,
447
+ kernel_q,
448
+ kernel_kv,
449
+ stride_q,
450
+ stride_kv,
451
+ cfg,
452
+ norm_layer=nn.LayerNorm,
453
+ layer_id=0,
454
+ **kwargs
455
+ ):
456
+ """
457
+ Block is composed entirely of function F (Attention
458
+ sub-block) and G (MLP sub-block) including layernorm.
459
+ """
460
+ super().__init__()
461
+
462
+ self.drop_path_rate = drop_path
463
+
464
+ self.F = AttentionSubBlock(
465
+ dim=dim,
466
+ input_size=input_size,
467
+ num_heads=num_heads,
468
+ cfg=cfg,
469
+ dim_out=dim_out,
470
+ kernel_q=kernel_q,
471
+ kernel_kv=kernel_kv,
472
+ stride_q=stride_q,
473
+ stride_kv=stride_kv,
474
+ norm_layer=norm_layer,
475
+ )
476
+
477
+ self.G = MLPSubblock(
478
+ dim=dim,
479
+ mlp_ratio=mlp_ratio,
480
+ norm_layer=norm_layer,
481
+ )
482
+
483
+ self.layer_id = layer_id
484
+
485
+ self.seeds = {}
486
+
487
+ def seed_cuda(self, key):
488
+ """
489
+ Fix seeds to allow for stochastic elements such as
490
+ dropout to be reproduced exactly in activation
491
+ recomputation in the backward pass.
492
+ """
493
+
494
+ # randomize seeds
495
+ # use cuda generator if available
496
+ if (
497
+ hasattr(torch.cuda, "default_generators")
498
+ and len(torch.cuda.default_generators) > 0
499
+ ):
500
+ # GPU
501
+ device_idx = torch.cuda.current_device()
502
+ seed = torch.cuda.default_generators[device_idx].seed()
503
+ else:
504
+ # CPU
505
+ seed = int(torch.seed() % sys.maxsize)
506
+
507
+ self.seeds[key] = seed
508
+ torch.manual_seed(self.seeds[key])
509
+
510
+ def forward(self, X_1, X_2):
511
+ """
512
+ forward pass equations:
513
+ Y_1 = X_1 + Attention(X_2), F = Attention
514
+ Y_2 = X_2 + MLP(Y_1), G = MLP
515
+ """
516
+
517
+ self.seed_cuda("attn")
518
+ # Y_1 : attn_output
519
+ f_X_2 = self.F(X_2)
520
+
521
+ self.seed_cuda("droppath")
522
+ f_X_2_dropped = drop_path(
523
+ f_X_2, drop_prob=self.drop_path_rate, training=self.training
524
+ )
525
+
526
+ # Y_1 = X_1 + f(X_2)
527
+ Y_1 = X_1 + f_X_2_dropped
528
+
529
+ # free memory
530
+ del X_1
531
+
532
+ self.seed_cuda("FFN")
533
+ g_Y_1 = self.G(Y_1)
534
+
535
+ torch.manual_seed(self.seeds["droppath"])
536
+ g_Y_1_dropped = drop_path(
537
+ g_Y_1, drop_prob=self.drop_path_rate, training=self.training
538
+ )
539
+
540
+ # Y_2 = X_2 + g(Y_1)
541
+ Y_2 = X_2 + g_Y_1_dropped
542
+
543
+ del X_2
544
+
545
+ return Y_1, Y_2
546
+
547
+ def backward_pass(
548
+ self,
549
+ Y_1,
550
+ Y_2,
551
+ dY_1,
552
+ dY_2,
553
+ ):
554
+ """
555
+ equation for activation recomputation:
556
+ X_2 = Y_2 - G(Y_1), G = MLP
557
+ X_1 = Y_1 - F(X_2), F = Attention
558
+ """
559
+
560
+ # temporarily record intermediate activation for G
561
+ # and use them for gradient calculation of G
562
+ with torch.enable_grad():
563
+
564
+ Y_1.requires_grad = True
565
+
566
+ torch.manual_seed(self.seeds["FFN"])
567
+ g_Y_1 = self.G(Y_1)
568
+
569
+ torch.manual_seed(self.seeds["droppath"])
570
+ g_Y_1 = drop_path(
571
+ g_Y_1, drop_prob=self.drop_path_rate, training=self.training
572
+ )
573
+
574
+ g_Y_1.backward(dY_2, retain_graph=True)
575
+
576
+ # activation recomputation is by design and not part of
577
+ # the computation graph in forward pass.
578
+ with torch.no_grad():
579
+
580
+ X_2 = Y_2 - g_Y_1
581
+ del g_Y_1
582
+
583
+ dY_1 = dY_1 + Y_1.grad
584
+ Y_1.grad = None
585
+
586
+ # record F activations and calc gradients on F
587
+ with torch.enable_grad():
588
+ X_2.requires_grad = True
589
+
590
+ torch.manual_seed(self.seeds["attn"])
591
+ f_X_2 = self.F(X_2)
592
+
593
+ torch.manual_seed(self.seeds["droppath"])
594
+ f_X_2 = drop_path(
595
+ f_X_2, drop_prob=self.drop_path_rate, training=self.training
596
+ )
597
+
598
+ f_X_2.backward(dY_1, retain_graph=True)
599
+
600
+ # propagate reverse-computed activations at the start of
601
+ # the previous block for backprop.
602
+ with torch.no_grad():
603
+
604
+ X_1 = Y_1 - f_X_2
605
+
606
+ del f_X_2, Y_1
607
+ dY_2 = dY_2 + X_2.grad
608
+
609
+ X_2.grad = None
610
+ X_2 = X_2.detach()
611
+
612
+ return X_1, X_2, dY_1, dY_2
613
+
614
+
615
+ class MLPSubblock(nn.Module):
616
+ """
617
+ This creates the function G such that the entire block can be
618
+ expressed as F(G(X)). Includes pre-LayerNorm.
619
+ """
620
+
621
+ def __init__(
622
+ self,
623
+ dim,
624
+ mlp_ratio,
625
+ norm_layer=nn.LayerNorm,
626
+ ):
627
+
628
+ super().__init__()
629
+ self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)
630
+
631
+ mlp_hidden_dim = int(dim * mlp_ratio)
632
+
633
+ self.mlp = Mlp(
634
+ in_features=dim,
635
+ hidden_features=mlp_hidden_dim,
636
+ act_layer=nn.GELU,
637
+ )
638
+
639
+ def forward(self, x):
640
+ return self.mlp(self.norm(x))
641
+
642
+
643
+ class AttentionSubBlock(nn.Module):
644
+ """
645
+ This creates the function F such that the entire block can be
646
+ expressed as F(G(X)). Includes pre-LayerNorm.
647
+ """
648
+
649
+ def __init__(
650
+ self,
651
+ dim,
652
+ input_size,
653
+ num_heads,
654
+ cfg,
655
+ dim_out=None,
656
+ kernel_q=(1, 1, 1),
657
+ kernel_kv=(1, 1, 1),
658
+ stride_q=(1, 1, 1),
659
+ stride_kv=(1, 1, 1),
660
+ norm_layer=nn.LayerNorm,
661
+ ):
662
+
663
+ super().__init__()
664
+ self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)
665
+
666
+ # This will be set externally during init
667
+ self.thw = None
668
+
669
+ # the actual attention details are the same as Multiscale
670
+ # attention for MViTv2 (with channel up=projection inside block)
671
+ # can also implement no upprojection attention for vanilla ViT
672
+ self.attn = MultiScaleAttention(
673
+ dim,
674
+ dim_out,
675
+ input_size=input_size,
676
+ num_heads=num_heads,
677
+ kernel_q=kernel_q,
678
+ kernel_kv=kernel_kv,
679
+ stride_q=stride_q,
680
+ stride_kv=stride_kv,
681
+ norm_layer=norm_layer,
682
+ drop_rate=cfg.MVIT.DROPOUT_RATE,
683
+ qkv_bias=cfg.MVIT.QKV_BIAS,
684
+ has_cls_embed=cfg.MVIT.CLS_EMBED_ON,
685
+ mode=cfg.MVIT.MODE,
686
+ pool_first=cfg.MVIT.POOL_FIRST,
687
+ rel_pos_spatial=cfg.MVIT.REL_POS_SPATIAL,
688
+ rel_pos_temporal=cfg.MVIT.REL_POS_TEMPORAL,
689
+ rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
690
+ residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
691
+ separate_qkv=cfg.MVIT.SEPARATE_QKV,
692
+ )
693
+
694
+ def forward(self, x):
695
+ out, _ = self.attn(self.norm(x), self.thw)
696
+ return out
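
To make the reversible coupling above concrete, here is a toy sketch (not in the commit) of the same two-stream equations used by ReversibleBlock.forward and backward_pass, with stand-in linear layers F_fn/G_fn instead of the attention and MLP sub-blocks:

import torch
import torch.nn as nn

# Forward:  Y1 = X1 + F(X2);  Y2 = X2 + G(Y1)
# Inverse:  X2 = Y2 - G(Y1);  X1 = Y1 - F(X2)
F_fn, G_fn = nn.Linear(16, 16), nn.Linear(16, 16)
X1, X2 = torch.randn(4, 16), torch.randn(4, 16)

with torch.no_grad():
    Y1 = X1 + F_fn(X2)
    Y2 = X2 + G_fn(Y1)

    # activation recomputation, as in ReversibleBlock.backward_pass
    X2_rec = Y2 - G_fn(Y1)
    X1_rec = Y1 - F_fn(X2_rec)

assert torch.allclose(X1_rec, X1, atol=1e-5)
assert torch.allclose(X2_rec, X2, atol=1e-5)

This inversion is what lets RevBackProp avoid storing per-block activations during the forward pass.
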
skp/models/rev_mvit/stem_helper.py ADDED
@@ -0,0 +1,325 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+
4
+ """ResNe(X)t 3D stem helper."""
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ def get_stem_func(name):
11
+ """
12
+ Retrieves the stem module by name.
13
+ """
14
+ trans_funcs = {"x3d_stem": X3DStem, "basic_stem": ResNetBasicStem}
15
+ assert (
16
+ name in trans_funcs.keys()
17
+ ), "Transformation function '{}' not supported".format(name)
18
+ return trans_funcs[name]
19
+
20
+
21
+ class VideoModelStem(nn.Module):
22
+ """
23
+ Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool
24
+ on input data tensor for one or multiple pathways.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ dim_in,
30
+ dim_out,
31
+ kernel,
32
+ stride,
33
+ padding,
34
+ inplace_relu=True,
35
+ eps=1e-5,
36
+ bn_mmt=0.1,
37
+ norm_module=nn.BatchNorm3d,
38
+ stem_func_name="basic_stem",
39
+ ):
40
+ """
41
+ The `__init__` method of any subclass should also contain these
42
+ arguments. List size of 1 for single pathway models (C2D, I3D, Slow
43
+ and etc), list size of 2 for two pathway models (SlowFast).
44
+
45
+ Args:
46
+ dim_in (list): the list of channel dimensions of the inputs.
47
+ dim_out (list): the output dimension of the convolution in the stem
48
+ layer.
49
+ kernel (list): the kernels' size of the convolutions in the stem
50
+ layers. Temporal kernel size, height kernel size, width kernel
51
+ size in order.
52
+ stride (list): the stride sizes of the convolutions in the stem
53
+ layer. Temporal kernel stride, height kernel size, width kernel
54
+ size in order.
55
+ padding (list): the paddings' sizes of the convolutions in the stem
56
+ layer. Temporal padding size, height padding size, width padding
57
+ size in order.
58
+ inplace_relu (bool): calculate the relu on the original input
59
+ without allocating new memory.
60
+ eps (float): epsilon for batch norm.
61
+ bn_mmt (float): momentum for batch norm. Noted that BN momentum in
62
+ PyTorch = 1 - BN momentum in Caffe2.
63
+ norm_module (nn.Module): nn.Module for the normalization layer. The
64
+ default is nn.BatchNorm3d.
65
+ stem_func_name (string): name of the stem function applied on
66
+ input to the network.
67
+ """
68
+ super(VideoModelStem, self).__init__()
69
+
70
+ assert (
71
+ len(
72
+ {
73
+ len(dim_in),
74
+ len(dim_out),
75
+ len(kernel),
76
+ len(stride),
77
+ len(padding),
78
+ }
79
+ )
80
+ == 1
81
+ ), "Input pathway dimensions are not consistent. {} {} {} {} {}".format(
82
+ len(dim_in),
83
+ len(dim_out),
84
+ len(kernel),
85
+ len(stride),
86
+ len(padding),
87
+ )
88
+
89
+ self.num_pathways = len(dim_in)
90
+ self.kernel = kernel
91
+ self.stride = stride
92
+ self.padding = padding
93
+ self.inplace_relu = inplace_relu
94
+ self.eps = eps
95
+ self.bn_mmt = bn_mmt
96
+ # Construct the stem layer.
97
+ self._construct_stem(dim_in, dim_out, norm_module, stem_func_name)
98
+
99
+ def _construct_stem(self, dim_in, dim_out, norm_module, stem_func_name):
100
+ trans_func = get_stem_func(stem_func_name)
101
+
102
+ for pathway in range(len(dim_in)):
103
+ stem = trans_func(
104
+ dim_in[pathway],
105
+ dim_out[pathway],
106
+ self.kernel[pathway],
107
+ self.stride[pathway],
108
+ self.padding[pathway],
109
+ self.inplace_relu,
110
+ self.eps,
111
+ self.bn_mmt,
112
+ norm_module,
113
+ )
114
+ self.add_module("pathway{}_stem".format(pathway), stem)
115
+
116
+ def forward(self, x):
117
+ assert (
118
+ len(x) == self.num_pathways
119
+ ), "Input tensor does not contain {} pathway".format(self.num_pathways)
120
+ # use a new list, don't modify in-place the x list, which is bad for activation checkpointing.
121
+ y = []
122
+ for pathway in range(len(x)):
123
+ m = getattr(self, "pathway{}_stem".format(pathway))
124
+ y.append(m(x[pathway]))
125
+ return y
126
+
127
+
128
+ class ResNetBasicStem(nn.Module):
129
+ """
130
+ ResNe(X)t 3D stem module.
131
+ Performs spatiotemporal Convolution, BN, and ReLU, followed by a
132
+ spatiotemporal pooling.
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ dim_in,
138
+ dim_out,
139
+ kernel,
140
+ stride,
141
+ padding,
142
+ inplace_relu=True,
143
+ eps=1e-5,
144
+ bn_mmt=0.1,
145
+ norm_module=nn.BatchNorm3d,
146
+ ):
147
+ """
148
+ The `__init__` method of any subclass should also contain these arguments.
149
+
150
+ Args:
151
+ dim_in (int): the channel dimension of the input. Normally 3 is used
152
+ for rgb input, and 2 or 3 is used for optical flow input.
153
+ dim_out (int): the output dimension of the convolution in the stem
154
+ layer.
155
+ kernel (list): the kernel size of the convolution in the stem layer.
156
+ temporal kernel size, height kernel size, width kernel size in
157
+ order.
158
+ stride (list): the stride size of the convolution in the stem layer.
159
+ temporal kernel stride, height kernel size, width kernel size in
160
+ order.
161
+ padding (int): the padding size of the convolution in the stem
162
+ layer, temporal padding size, height padding size, width
163
+ padding size in order.
164
+ inplace_relu (bool): calculate the relu on the original input
165
+ without allocating new memory.
166
+ eps (float): epsilon for batch norm.
167
+ bn_mmt (float): momentum for batch norm. Noted that BN momentum in
168
+ PyTorch = 1 - BN momentum in Caffe2.
169
+ norm_module (nn.Module): nn.Module for the normalization layer. The
170
+ default is nn.BatchNorm3d.
171
+ """
172
+ super(ResNetBasicStem, self).__init__()
173
+ self.kernel = kernel
174
+ self.stride = stride
175
+ self.padding = padding
176
+ self.inplace_relu = inplace_relu
177
+ self.eps = eps
178
+ self.bn_mmt = bn_mmt
179
+ # Construct the stem layer.
180
+ self._construct_stem(dim_in, dim_out, norm_module)
181
+
182
+ def _construct_stem(self, dim_in, dim_out, norm_module):
183
+ self.conv = nn.Conv3d(
184
+ dim_in,
185
+ dim_out,
186
+ self.kernel,
187
+ stride=self.stride,
188
+ padding=self.padding,
189
+ bias=False,
190
+ )
191
+ self.bn = norm_module(
192
+ num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
193
+ )
194
+ self.relu = nn.ReLU(self.inplace_relu)
195
+ self.pool_layer = nn.MaxPool3d(
196
+ kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
197
+ )
198
+
199
+ def forward(self, x):
200
+ x = self.conv(x)
201
+ x = self.bn(x)
202
+ x = self.relu(x)
203
+ x = self.pool_layer(x)
204
+ return x
205
+
206
+
207
+ class X3DStem(nn.Module):
208
+ """
209
+ X3D's 3D stem module.
210
+ Performs a spatial followed by a depthwise temporal Convolution, BN, and ReLU, followed by a
211
+ spatiotemporal pooling.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ dim_in,
217
+ dim_out,
218
+ kernel,
219
+ stride,
220
+ padding,
221
+ inplace_relu=True,
222
+ eps=1e-5,
223
+ bn_mmt=0.1,
224
+ norm_module=nn.BatchNorm3d,
225
+ ):
226
+ """
227
+ The `__init__` method of any subclass should also contain these arguments.
228
+
229
+ Args:
230
+ dim_in (int): the channel dimension of the input. Normally 3 is used
231
+ for rgb input, and 2 or 3 is used for optical flow input.
232
+ dim_out (int): the output dimension of the convolution in the stem
233
+ layer.
234
+ kernel (list): the kernel size of the convolution in the stem layer.
235
+ temporal kernel size, height kernel size, width kernel size in
236
+ order.
237
+ stride (list): the stride size of the convolution in the stem layer.
238
+ temporal kernel stride, height kernel size, width kernel size in
239
+ order.
240
+ padding (int): the padding size of the convolution in the stem
241
+ layer, temporal padding size, height padding size, width
242
+ padding size in order.
243
+ inplace_relu (bool): calculate the relu on the original input
244
+ without allocating new memory.
245
+ eps (float): epsilon for batch norm.
246
+ bn_mmt (float): momentum for batch norm. Noted that BN momentum in
247
+ PyTorch = 1 - BN momentum in Caffe2.
248
+ norm_module (nn.Module): nn.Module for the normalization layer. The
249
+ default is nn.BatchNorm3d.
250
+ """
251
+ super(X3DStem, self).__init__()
252
+ self.kernel = kernel
253
+ self.stride = stride
254
+ self.padding = padding
255
+ self.inplace_relu = inplace_relu
256
+ self.eps = eps
257
+ self.bn_mmt = bn_mmt
258
+ # Construct the stem layer.
259
+ self._construct_stem(dim_in, dim_out, norm_module)
260
+
261
+ def _construct_stem(self, dim_in, dim_out, norm_module):
262
+ self.conv_xy = nn.Conv3d(
263
+ dim_in,
264
+ dim_out,
265
+ kernel_size=(1, self.kernel[1], self.kernel[2]),
266
+ stride=(1, self.stride[1], self.stride[2]),
267
+ padding=(0, self.padding[1], self.padding[2]),
268
+ bias=False,
269
+ )
270
+ self.conv = nn.Conv3d(
271
+ dim_out,
272
+ dim_out,
273
+ kernel_size=(self.kernel[0], 1, 1),
274
+ stride=(self.stride[0], 1, 1),
275
+ padding=(self.padding[0], 0, 0),
276
+ bias=False,
277
+ groups=dim_out,
278
+ )
279
+
280
+ self.bn = norm_module(
281
+ num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
282
+ )
283
+ self.relu = nn.ReLU(self.inplace_relu)
284
+
285
+ def forward(self, x):
286
+ x = self.conv_xy(x)
287
+ x = self.conv(x)
288
+ x = self.bn(x)
289
+ x = self.relu(x)
290
+ return x
291
+
292
+
293
+ class PatchEmbed(nn.Module):
294
+ """
295
+ PatchEmbed.
296
+ """
297
+
298
+ def __init__(
299
+ self,
300
+ dim_in=3,
301
+ dim_out=768,
302
+ kernel=(1, 16, 16),
303
+ stride=(1, 4, 4),
304
+ padding=(1, 7, 7),
305
+ conv_2d=False,
306
+ ):
307
+ super().__init__()
308
+ if conv_2d:
309
+ conv = nn.Conv2d
310
+ else:
311
+ conv = nn.Conv3d
312
+ self.proj = conv(
313
+ dim_in,
314
+ dim_out,
315
+ kernel_size=kernel,
316
+ stride=stride,
317
+ padding=padding,
318
+ )
319
+
320
+ def forward(self, x, keep_spatial=False):
321
+ x = self.proj(x)
322
+ if keep_spatial:
323
+ return x, x.shape
324
+ # B C (T) H W -> B (T)HW C
325
+ return x.flatten(2).transpose(1, 2), x.shape
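
A hedged sketch of PatchEmbed (not in the commit), using an MViT-style 3x7x7 patchifying conv; the clip size and channel width are placeholders.

import torch
from skp.models.rev_mvit.stem_helper import PatchEmbed

patchify = PatchEmbed(dim_in=3, dim_out=96, kernel=(3, 7, 7),
                      stride=(2, 4, 4), padding=(1, 3, 3))
clip = torch.randn(1, 3, 8, 56, 56)   # (B, C, T, H, W)
tokens, shape = patchify(clip)
print(tokens.shape)                   # torch.Size([1, 784, 96]) -> 4*14*14 tokens
print(shape)                          # torch.Size([1, 96, 4, 14, 14])
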
skp/models/rev_mvit/utils.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def round_width(width, multiplier, min_width=1, divisor=1, verbose=False):
8
+ if not multiplier:
9
+ return width
10
+ width *= multiplier
11
+ min_width = min_width or divisor
12
+ if verbose:
13
+ print(f"min width {min_width}")
14
+ print(f"width {width} divisor {divisor}")
15
+ print(f"other {int(width + divisor / 2) // divisor * divisor}")
16
+
17
+ width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
18
+ if width_out < 0.9 * width:
19
+ width_out += divisor
20
+ return int(width_out)
21
+
22
+
23
+ def validate_checkpoint_wrapper_import(checkpoint_wrapper):
24
+ """
25
+ Check if checkpoint_wrapper is imported.
26
+ """
27
+ if checkpoint_wrapper is None:
28
+ raise ImportError("Please install fairscale.")
29
+
30
+
31
+ def get_gkern(kernlen, std):
32
+ """Returns a 2D Gaussian kernel array."""
33
+
34
+ def _gaussian_fn(kernlen, std):
35
+ n = torch.arange(0, kernlen).float()
36
+ n -= n.mean()
37
+ n /= std
38
+ w = torch.exp(-0.5 * n**2)
39
+ return w
40
+
41
+ gkern1d = _gaussian_fn(kernlen, std)
42
+ gkern2d = torch.outer(gkern1d, gkern1d)
43
+ return gkern2d / gkern2d.sum()
44
+
45
+
46
+ # --------------------------------------------------------
47
+ # 2D sine-cosine position embedding
48
+ # References:
49
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
50
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
51
+ # --------------------------------------------------------
52
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
53
+ """
54
+ grid_size: int of the grid height and width
55
+ t_size: int of the temporal size
56
+ return:
57
+ pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
58
+ """
59
+ assert embed_dim % 4 == 0
60
+ embed_dim_spatial = embed_dim // 4 * 3
61
+ embed_dim_temporal = embed_dim // 4
62
+
63
+ # spatial
64
+ grid_h = np.arange(grid_size, dtype=np.float32)
65
+ grid_w = np.arange(grid_size, dtype=np.float32)
66
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
67
+ grid = np.stack(grid, axis=0)
68
+
69
+ grid = grid.reshape([2, 1, grid_size, grid_size])
70
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
71
+ embed_dim_spatial, grid
72
+ )
73
+
74
+ # temporal
75
+ grid_t = np.arange(t_size, dtype=np.float32)
76
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
77
+ embed_dim_temporal, grid_t
78
+ )
79
+
80
+ # concatenate: [T, H, W] order
81
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
82
+ pos_embed_temporal = np.repeat(
83
+ pos_embed_temporal, grid_size**2, axis=1
84
+ ) # [T, H*W, D // 4]
85
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
86
+ pos_embed_spatial = np.repeat(
87
+ pos_embed_spatial, t_size, axis=0
88
+ ) # [T, H*W, D // 4 * 3]
89
+
90
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
91
+ pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
92
+
93
+ if cls_token:
94
+ pos_embed = np.concatenate(
95
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
96
+ )
97
+ return pos_embed
+
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate(
+             [np.zeros([1, embed_dim]), pos_embed], axis=0
+         )
+     return pos_embed
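The 2D variant behaves the same way without the temporal axis (editor's sketch with illustrative numbers):

    pe2d = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
    print(pe2d.shape)  # (196, 768) == (14 * 14, 768)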
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(
+         embed_dim // 2, grid[0]
+     )  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(
+         embed_dim // 2, grid[1]
+     )  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     # np.float was removed in NumPy >= 1.24; use the explicit float64 dtype instead
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
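A tiny numeric check of the 1D builder (editor's sketch, not part of the commit): each position m is mapped to [sin(m * omega_k), cos(m * omega_k)] with omega_k = 1 / 10000**(2k / D).

    import numpy as np

    emb = get_1d_sincos_pos_embed_from_grid(8, np.arange(4, dtype=np.float64))
    print(emb.shape)  # (4, 8): 4 sine columns followed by 4 cosine columns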
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if "pos_embed" in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model["pos_embed"]
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int(
+             (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
+         )
+         # height (== width) for the new position embedding
+         new_size = int(num_patches**0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print(
+                 "Position interpolate from %dx%d to %dx%d"
+                 % (orig_size, orig_size, new_size, new_size)
+             )
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(
+                 -1, orig_size, orig_size, embedding_size
+             ).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens,
+                 size=(new_size, new_size),
+                 mode="bicubic",
+                 align_corners=False,
+             )
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model["pos_embed"] = new_pos_embed
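A hedged sketch of the intended call pattern (editor's illustration; `vit_model`, the checkpoint path, and the "model" key are placeholders, not names from this repo): the helper rewrites `checkpoint_model["pos_embed"]` in place so that a checkpoint trained at one input resolution can be loaded into a model with a different patch grid.

    import torch

    checkpoint = torch.load("pretrained_vit.pth", map_location="cpu")  # hypothetical file
    checkpoint_model = checkpoint["model"]                             # assumes weights live under "model"
    interpolate_pos_embed(vit_model, checkpoint_model)                 # resizes pos_embed only if grid sizes differ
    vit_model.load_state_dict(checkpoint_model, strict=False)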
+
+
+ def calc_mvit_feature_geometry(cfg):
+     feat_size = [
+         [
+             cfg.DATA.NUM_FRAMES // cfg.MVIT.PATCH_STRIDE[0]
+             if len(cfg.MVIT.PATCH_STRIDE) > 2
+             else 1,
+             cfg.DATA.TRAIN_CROP_SIZE // cfg.MVIT.PATCH_STRIDE[-2],
+             cfg.DATA.TRAIN_CROP_SIZE // cfg.MVIT.PATCH_STRIDE[-1],
+         ]
+         for i in range(cfg.MVIT.DEPTH)
+     ]
+     feat_stride = [
+         [
+             cfg.MVIT.PATCH_STRIDE[0] if len(cfg.MVIT.PATCH_STRIDE) > 2 else 1,
+             cfg.MVIT.PATCH_STRIDE[-2],
+             cfg.MVIT.PATCH_STRIDE[-1],
+         ]
+         for i in range(cfg.MVIT.DEPTH)
+     ]
+     for _, x in enumerate(cfg.MVIT.POOL_Q_STRIDE):
+         for i in range(cfg.MVIT.DEPTH):
+             if i >= x[0]:
+                 for j in range(len(feat_size[i])):
+                     feat_size[i][j] = feat_size[i][j] // x[j + 1]
+                     feat_stride[i][j] = feat_stride[i][j] * x[j + 1]
+     return feat_size, feat_stride
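A small, self-contained sketch of the geometry computation (editor's illustration with a toy config; the values are not the ones used in this repo, and only the attributes read by the function above are populated): per-block feature sizes shrink and strides grow wherever a Q-pooling stage applies.

    from types import SimpleNamespace

    toy_cfg = SimpleNamespace(
        DATA=SimpleNamespace(NUM_FRAMES=16, TRAIN_CROP_SIZE=224),
        MVIT=SimpleNamespace(
            PATCH_STRIDE=[2, 4, 4],        # temporal, height, width patchification stride
            DEPTH=2,
            POOL_Q_STRIDE=[[1, 1, 2, 2]],  # from block 1 onward, pool H and W by 2
        ),
    )
    feat_size, feat_stride = calc_mvit_feature_geometry(toy_cfg)
    print(feat_size)    # [[8, 56, 56], [8, 28, 28]]
    print(feat_stride)  # [[2, 4, 4], [2, 8, 8]]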