rahulvenkk committed
Commit 6dfcb0f
1 Parent(s): 0ba04fa

app.py updated

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +4 -0
- .gitignore +3 -0
- README.md +47 -14
- assets/color_wheel.png +0 -0
- assets/cwm_teaser.gif +3 -0
- assets/desk_1.jpg +3 -0
- assets/flow_test_videos/libby.mp4 +3 -0
- assets/flow_test_videos/weight_lifter.mp4 +3 -0
- cwm/__init__.py +0 -0
- cwm/data/__init__.py +0 -0
- cwm/data/dataset.py +453 -0
- cwm/data/dataset_utils.py +73 -0
- cwm/data/masking_generator.py +86 -0
- cwm/data/transforms.py +206 -0
- cwm/data/video_file_lists/kinetics_400_train_list.txt +3 -0
- cwm/data/video_file_lists/kinetics_400_train_list_sing.txt +3 -0
- cwm/engine_for_pretraining.py +92 -0
- cwm/eval/Action_recognition/__init__.py +0 -0
- cwm/eval/Flow/__init__.py +0 -0
- cwm/eval/Flow/create_spring_submission_parallel.sh +36 -0
- cwm/eval/Flow/create_spring_submission_unified.py +111 -0
- cwm/eval/Flow/flow_extraction_classes.py +122 -0
- cwm/eval/Flow/flow_utils.py +569 -0
- cwm/eval/Flow/flow_utils_legacy.py +152 -0
- cwm/eval/Flow/generator.py +579 -0
- cwm/eval/Flow/losses.py +60 -0
- cwm/eval/Flow/masking_flow.py +375 -0
- cwm/eval/Flow/vis_utils.py +150 -0
- cwm/eval/IntPhys/__init__.py +0 -0
- cwm/eval/Physion/__init__.py +0 -0
- cwm/eval/Physion/feature_extractor.py +317 -0
- cwm/eval/Physion/flow_utils.py +279 -0
- cwm/eval/Physion/run_eval.sh +17 -0
- cwm/eval/Physion/run_eval_kfflow.sh +18 -0
- cwm/eval/Physion/run_eval_mp4s.sh +19 -0
- cwm/eval/Physion/run_eval_mp4s_keyp.sh +17 -0
- cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh +17 -0
- cwm/eval/Segmentation/__init__.py +0 -0
- cwm/eval/Segmentation/archive/__init__.py +0 -0
- cwm/eval/Segmentation/archive/common/__init__.py +0 -0
- cwm/eval/Segmentation/archive/common/coco_loader_lsj.py +222 -0
- cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py +19 -0
- cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py +54 -0
- cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py +26 -0
- cwm/eval/Segmentation/archive/competition.py +673 -0
- cwm/eval/Segmentation/archive/configs/__init__.py +0 -0
- cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py +56 -0
- cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py +59 -0
- cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py +56 -0
- cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py +59 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.txt filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
```
.gitignore
ADDED
```diff
@@ -0,0 +1,3 @@
+.idea
+
+.pth
```
README.md
CHANGED
@@ -1,14 +1,47 @@
<div align="center">
<h2>Understanding Physical Dynamics with Counterfactual World Modeling</h2>

[**Rahul Venkatesh***](https://rahulvenkk.github.io/)<sup>1</sup> · [**Honglin Chen***](https://web.stanford.edu/~honglinc/)<sup>1</sup> · [**Kevin Feigelis***](https://neuroscience.stanford.edu/people/kevin-t-feigelis)<sup>1</sup> · [**Daniel M. Bear**](https://twitter.com/recursus?lang=en)<sup>1</sup> · [**Khaled Jedoui**](https://web.stanford.edu/~thekej/)<sup>1</sup> · [**Klemen Kotar**](https://klemenkotar.github.io/)<sup>1</sup> · [**Felix Binder**](https://ac.felixbinder.net/)<sup>2</sup> · [**Wanhee Lee**](https://www.linkedin.com/in/wanhee-lee-31102820b/)<sup>1</sup> · [**Sherry Liu**](https://neuroailab.github.io/cwm-physics/)<sup>1</sup> · [**Kevin A. Smith**](https://www.mit.edu/~k2smith/)<sup>3</sup> · [**Judith E. Fan**](https://cogtoolslab.github.io/)<sup>1</sup> · [**Daniel L. K. Yamins**](https://stanford.edu/~yamins/)<sup>1</sup>

(* equal contribution)

<sup>1</sup>Stanford &nbsp;&nbsp; <sup>2</sup>UCSD &nbsp;&nbsp; <sup>3</sup>MIT

<a href="https://arxiv.org/abs/2312.06721"><img src='https://img.shields.io/badge/arXiv-CWM-red' alt='Paper PDF'></a>
<a href='https://neuroailab.github.io/cwm-physics/'><img src='https://img.shields.io/badge/Project_Page-CWM-green' alt='Project Page'></a>
<a href='https://neuroailab.github.io/cwm-physics/'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
<a href='https://neuroailab.github.io/cwm-physics/'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Colab-yellow'></a>
</div>

This work presents the Counterfactual World Modeling (CWM) framework. CWM is capable of counterfactual prediction and of extracting vision structures that are useful for understanding physical dynamics.

![](assets/cwm_teaser.gif)

## 📣 News

- 2024-06-01: Released the [project page](https://neuroailab.github.io) and [code](https://github.com/rahulvenkk/cwm_release.git).

## 🔨 Installation

```
git clone https://github.com/rahulvenkk/cwm_release.git
pip install -e .
```

## ✨ Usage

To download and use a pre-trained model, run the following:

```
from cwm.model.model_factory import model_factory
model = model_factory.load_model('vitbase_8x8patch_3frames_1tube')
```

This will automatically initialize the appropriate model class and download the specified weights to your `$CACHE` directory.
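For context (this is not part of the original README), a minimal sketch of calling the loaded predictor: the `model(videos, bool_masked_pos)` call and the `[B, C, T, H, W]` layout follow the pre-training loop in `cwm/engine_for_pretraining.py`, while the frame count, resolution, and masking pattern below are illustrative assumptions for the 3-frame, 8x8-patch configuration.

```python
import torch
from cwm.model.model_factory import model_factory

# Hypothetical smoke test; shapes are assumptions, not part of the documented API.
model = model_factory.load_model('vitbase_8x8patch_3frames_1tube').eval()

videos = torch.randn(1, 3, 3, 224, 224)              # [B, C, T, H, W], ImageNet-normalized frames
num_patches = 3 * (224 // 8) * (224 // 8)            # tokens per clip for an 8x8 patch, tubelet size 1
bool_masked_pos = torch.zeros(1, num_patches, dtype=torch.bool)
bool_masked_pos[:, -(224 // 8) ** 2:] = True         # mask (i.e. predict) all patches of the last frame

with torch.no_grad():
    pred = model(videos, bool_masked_pos)            # predictions for the masked patches
print(pred.shape)
```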
## 🔄 Pre-training

To train the model, run the following script:

```
./scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh
```
assets/color_wheel.png
ADDED
assets/cwm_teaser.gif
ADDED
Git LFS Details
assets/desk_1.jpg
ADDED
Git LFS Details
assets/flow_test_videos/libby.mp4
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:1d1887bc796d1883e8a63405398125476257572c0c8d7f1862bf309e422b4828
size 671950
```
assets/flow_test_videos/weight_lifter.mp4
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:525988b314936079f904236c614e2f987ba64fd5c69f4f12dd5b6b9076311854
size 1176790
```
cwm/__init__.py
ADDED
File without changes
cwm/data/__init__.py
ADDED
File without changes
cwm/data/dataset.py
ADDED
@@ -0,0 +1,453 @@
```python
import os
import decord
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class VideoMAE(torch.utils.data.Dataset):
    """Load your own video classification dataset.
    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset.
    setting : str, required.
        A text file describing the dataset, each line per video sample.
        There are three items in each line: (1) video path; (2) video length and (3) video label.
    train : bool, default True.
        Whether to load the training or validation set.
    test_mode : bool, default False.
        Whether to perform evaluation on the test set.
        Usually there is a three-crop or ten-crop evaluation strategy involved.
    name_pattern : str, default None.
        The naming pattern of the decoded video frames.
        For example, img_00012.jpg.
    video_ext : str, default 'mp4'.
        If video_loader is set to True, please specify the video format accordingly.
    is_color : bool, default True.
        Whether the loaded image is color or grayscale.
    modality : str, default 'rgb'.
        Input modalities, we support only rgb video frames for now.
        Will add support for rgb difference image and optical flow image later.
    num_segments : int, default 1.
        Number of segments to evenly divide the video into clips.
        A useful technique to obtain global video-level information.
        Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
    num_crop : int, default 1.
        Number of crops for each image. Default is 1.
        Common choices are three crops and ten crops during evaluation.
    new_length : int, default 1.
        The length of input video clip. Default is a single image, but it can be multiple video frames.
        For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
    new_step : int, default 1.
        Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
        new_step=2 means we will extract a video clip of every other frame.
    temporal_jitter : bool, default False.
        Whether to temporally jitter if new_step > 1.
    video_loader : bool, default False.
        Whether to use video loader to load data.
    use_decord : bool, default True.
        Whether to use Decord video loader to load data. Otherwise use mmcv video loader.
    transform : function, default None.
        A function that takes data and label and transforms them.
    data_aug : str, default 'v1'.
        Different types of data augmentation. Supports v1, v2, v3 and v4.
    lazy_init : bool, default False.
        If set to True, build a dataset instance without loading any dataset.
    """

    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 name_pattern='img_%05d.jpg',
                 video_ext='mp4',
                 is_color=True,
                 modality='rgb',
                 num_segments=1,
                 num_crop=1,
                 new_length=1,
                 new_step=1,
                 randomize_interframes=False,
                 transform=None,
                 temporal_jitter=False,
                 video_loader=False,
                 use_decord=False,
                 lazy_init=False,
                 is_video_dataset=True):

        super(VideoMAE, self).__init__()
        self.root = root
        self.setting = setting
        self.train = train
        self.test_mode = test_mode
        self.is_color = is_color
        self.modality = modality
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.new_length = new_length

        self.randomize_interframes = randomize_interframes
        self._new_step = new_step  # If randomize_interframes is True, then this is the max, otherwise it's just the skip
        # self._skip_length = self.new_length * self.new_step  # If randomize_interframes is True, then this isn't used, otherwise it's used as calculated
        self.temporal_jitter = temporal_jitter
        self.name_pattern = name_pattern
        self.video_loader = video_loader
        self.video_ext = video_ext
        self.use_decord = use_decord
        self.transform = transform
        self.lazy_init = lazy_init

        if (not self.lazy_init) and is_video_dataset:
            self.clips = self._make_dataset(root, setting)
            if len(self.clips) == 0:
                raise (RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
                                    "Check your data directory (opt.data-dir)."))

    def __getitem__(self, index):

        directory, target = self.clips[index]

        if self.video_loader:
            if '.' in directory.split('/')[-1]:
                # data in the "setting" file already have extension, e.g., demo.mp4
                video_name = directory
            else:
                # data in the "setting" file do not have extension, e.g., demo
                # So we need to provide extension (i.e., .mp4) to complete the file name.
                video_name = '{}.{}'.format(directory, self.video_ext)

            try:
                decord_vr = decord.VideoReader(video_name, num_threads=1)
            except:
                # return video_name
                return (self.__getitem__(index + 1))
            duration = len(decord_vr)

        segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration)

        images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets,
                                                     new_step, skip_length)

        process_data, mask = self.transform((images, None))  # T*C,H,W
        process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)  # T*C,H,W -> T,C,H,W -> C,T,H,W

        return (process_data, mask)

    def __len__(self):
        return len(self.clips)

    def _make_dataset(self, directory, setting):
        if not os.path.exists(setting):
            raise (RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
        clips = []
        with open(setting) as split_f:
            data = split_f.readlines()
            for line in data:
                line_info = line.split(' ')
                # line format: video_path, video_duration, video_label
                if len(line_info) < 2:
                    raise (RuntimeError('Video input format is not correct, missing one or more element. %s' % line))
                elif len(line_info) > 2:
                    line_info = (' '.join(line_info[:-1]), line_info[-1])  # filename has spaces
                clip_path = os.path.join(line_info[0])
                target = int(line_info[1])
                item = (clip_path, target)
                clips.append(item)
        # import torch_xla.core.xla_model as xm
        # print = xm.master_print
        # print("Dataset created. Number of clips: ", len(clips))
        return clips

    def _sample_train_indices(self, num_frames):
        if self.randomize_interframes is False:
            new_step = self._new_step
        else:
            new_step = np.random.randint(1, self._new_step + 1)

        skip_length = self.new_length * new_step

        average_duration = (num_frames - skip_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)),
                                  average_duration)
            offsets = offsets + np.random.randint(average_duration,
                                                  size=self.num_segments)
        elif num_frames > max(self.num_segments, skip_length):
            offsets = np.sort(np.random.randint(
                num_frames - skip_length + 1,
                size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))

        if self.temporal_jitter:
            skip_offsets = np.random.randint(
                new_step, size=skip_length // new_step)
        else:
            skip_offsets = np.zeros(
                skip_length // new_step, dtype=int)
        return offsets + 1, skip_offsets, new_step, skip_length

    def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets, new_step,
                                       skip_length):
        sampled_list = []
        frame_id_list = []
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, skip_length, new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + new_step < duration:
                    offset += new_step
        try:
            video_data = video_reader.get_batch(frame_id_list).asnumpy()
            sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in
                            enumerate(frame_id_list)]
        except:
            raise RuntimeError(
                'Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory,
                                                                                           duration))
        return sampled_list


class ContextAndTargetVideoDataset(VideoMAE):
    """
    A video dataset whose provided videos consist of (1) a "context" sequence of length Tc
    and (2) a "target" sequence Tt.

    These two sequences have the same frame rate (specifiable in real units) but are
    separated by a specified gap (which may vary for different examples.)

    The main use case is for training models to predict ahead by some variable amount,
    given the context.
    """

    standard_fps = [12, 24, 30, 48, 60, 100]

    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 transform=None,
                 step_units='ms',
                 new_step=150,
                 start_frame=0,
                 context_length=2,
                 target_length=1,
                 channels_first=True,
                 generate_masks=True,
                 mask_generator=None,
                 context_target_gap=[400, 600],
                 normalize_timestamps=True,
                 default_fps=30,
                 min_fps=0.1,
                 seed=0,
                 *args,
                 **kwargs):
        super(ContextAndTargetVideoDataset, self).__init__(
            root=root,
            setting=setting,
            train=train,
            test_mode=test_mode,
            transform=transform,
            new_length=context_length,
            use_decord=True,
            lazy_init=False,
            video_loader=True,
            *args, **kwargs)

        # breakpoint()

        self.context_length = self.new_length
        self.target_length = target_length

        ## convert from fps and step size to frames
        self._fps = None
        self._min_fps = min_fps
        self._default_fps = default_fps
        self._step_units = step_units
        self.new_step = new_step

        ## sampling for train and test
        self._start_frame = start_frame
        self.gap = context_target_gap
        self.seed = seed
        self.rng = np.random.RandomState(seed=seed)

        # breakpoint()

        ## output formatting
        self._channels_first = channels_first
        self._normalize_timestamps = normalize_timestamps
        self._generate_masks = generate_masks
        self.mask_generator = mask_generator

    def _get_frames_per_t(self, t):
        if self._step_units == 'frames' or (self._step_units is None):
            return int(t)

        assert self._fps is not None
        t_per_frame = 1 / self._fps
        if self._step_units in ['ms', 'milliseconds']:
            t_per_frame *= 1000.0

        return max(int(np.round(t / t_per_frame)), 1)

    @property
    def new_step(self):
        if self._fps is None:
            return None
        else:
            return self._get_frames_per_t(self._new_step)

    @new_step.setter
    def new_step(self, v):
        self._new_step = v

    @property
    def gap(self):
        if self._fps is None:
            return [1, 2]
        else:
            gap = [self._get_frames_per_t(self._gap[0]),
                   self._get_frames_per_t(self._gap[1])]
            gap[1] = max(gap[1], gap[0] + 1)
            return gap

    @gap.setter
    def gap(self, v):
        if v is None:
            v = self._new_step
        if not isinstance(v, (list, tuple)):
            v = [v, v]
        self._gap = v

    def _get_video_name(self, directory):
        if ''.join(['.', self.video_ext]) in directory.split('/')[-1]:
            # data in the "setting" file has extension, e.g. demo.mp4
            video_name = directory
        else:
            # data doesn't have an extension
            video_name = '{}.{}'.format(directory, self.video_ext)
        return video_name

    def _set_fps(self, reader):
        """click fps to a standard"""
        if self._step_units == 'frames' or self._step_units is None:
            self._fps = None
        else:
            self._fps = None
            fps = reader.get_avg_fps()
            for st in self.standard_fps:
                if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st):
                    self._fps = st
            if self._fps is None:
                self._fps = int(np.round(fps))

            if self._fps < self._min_fps:
                self._fps = self._default_fps

    def _get_step_and_gap(self):
        step = self.new_step
        if self.randomize_interframes and self.train:
            step = self.rng.randint(1, step + 1)

        if self.train:
            gap = self.rng.randint(*self.gap)
        else:
            gap = sum(self.gap) // 2
        return (step, gap)

    def _sample_frames(self):
        step, gap = self._get_step_and_gap()

        ## compute total length of sample
        ## e.g. if context_length = 2, step = 1, gap = 10, target_length = 2:
        ## total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13
        ## so len(video) must be >= 13
        self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
        if self._total_length > (self._num_frames - self._start_frame):
            if self.train:
                return None
            else:
                raise ValueError(
                    "movie of length %d starting at fr=%d is too long for video of %d frames" % \
                    (self._total_length, self._start_frame, self._num_frames))

        ## sample the frames randomly (if training) or from the start frame (if test)
        if self.train:
            self.start_frame_now = self.rng.randint(
                min(self._start_frame, self._num_frames - self._total_length),
                self._num_frames - self._total_length + 1)
        else:
            self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)

        frames = [self.start_frame_now + i * step for i in range(self.context_length)]
        frames += [frames[-1] + gap + i * step for i in range(self.target_length)]

        # breakpoint()

        return frames

    def _decode_frame_images(self, reader, frames):
        try:
            video_data = reader.get_batch(frames).asnumpy()
            video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
                          for t, _ in enumerate(frames)]
        except:
            raise RuntimeError(
                "Error occurred in reading frames {} from video {} of duration {}".format(
                    frames, self.index, self._num_frames))
        return video_data

    def __getitem__(self, index):

        self.index = index
        self.directory, target = self.clips[index]

        self.video_name = self._get_video_name(self.directory)

        ## build decord loader
        try:
            decord_vr = decord.VideoReader(self.video_name, num_threads=1)
            self._set_fps(decord_vr)
        except:
            # return self.video_name
            return (self.__getitem__(index + 1))

        ## sample the video
        self._num_frames = len(decord_vr)
        self.frames = self._sample_frames()
        if self.frames is None:
            print("no movie of length %d for video idx=%d" % (self._total_length, self.index))
            return self.__getitem__(index + 1)

        ## decode to PIL.Image
        image_list = self._decode_frame_images(decord_vr, self.frames)

        ## postproc to torch.Tensor and mask generation
        if self.transform is None:
            image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
        else:
            image_tensor = self.transform((image_list, None))

        image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])

        ## VMAE expects [B,C,T,H,W] rather than [B,T,C,H,W]
        if self._channels_first:
            image_tensor = image_tensor.transpose(0, 1)

        if self._generate_masks and self.mask_generator is not None:
            mask = self.mask_generator()
            return image_tensor, mask.bool()
        else:
            return image_tensor
```
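A minimal sketch of instantiating the dataset above (not part of the diff): the constructor arguments come from the class definition, the transform is the one defined in `cwm/data/dataset_utils.py`, and the file-list path and timing values are illustrative placeholders.

```python
from cwm.data.dataset import ContextAndTargetVideoDataset
from cwm.data.dataset_utils import DataAugmentationForVideoMAE

# Assumed settings: 224-px crops, multiscale augmentation scales of (1.0, 0.875).
transform = DataAugmentationForVideoMAE('multiscale', 224, (1.0, 0.875))

dataset = ContextAndTargetVideoDataset(
    root=None,
    setting='cwm/data/video_file_lists/kinetics_400_train_list_sing.txt',  # "<video path> <label>" per line
    train=True,
    transform=transform,
    step_units='ms',
    new_step=150,                   # spacing between context frames, in milliseconds
    context_length=2,               # two context frames ...
    target_length=1,                # ... plus one target frame to predict
    context_target_gap=[400, 600],  # gap between context and target, sampled per example
    channels_first=True,
    generate_masks=False,           # no mask_generator attached in this sketch
)

video = dataset[0]                  # [C, T, H, W] with T = context_length + target_length
```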
cwm/data/dataset_utils.py
ADDED
@@ -0,0 +1,73 @@
```python
from torchvision import transforms
from cwm.data.transforms import *
from cwm.data.dataset import ContextAndTargetVideoDataset
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from cwm.data.masking_generator import RotatedTableMaskingGenerator

class DataAugmentationForVideoMAE(object):
    def __init__(self, augmentation_type, input_size, augmentation_scales):

        transform_list = []

        self.scale = GroupScale(input_size)
        transform_list.append(self.scale)

        if augmentation_type == 'multiscale':
            self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
        elif augmentation_type == 'center':
            self.train_augmentation = GroupCenterCrop(input_size)

        transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)])

        # Normalize input images
        normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        transform_list.append(normalize)

        self.transform = transforms.Compose(transform_list)

    def __call__(self, images):
        process_data, _ = self.transform(images)
        return process_data

    def __repr__(self):
        repr = "(DataAugmentationForVideoMAE,\n"
        repr += "  transform = %s,\n" % str(self.transform)
        repr += ")"
        return repr


def build_pretraining_dataset(args):

    dataset_list = []
    data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales)

    mask_generator = RotatedTableMaskingGenerator(
        input_size=args.mask_input_size,
        mask_ratio=args.mask_ratio,
        tube_length=args.tubelet_size,
        batch_size=args.batch_size,
        mask_type=args.mask_type
    )

    for data_path in [args.data_path] if args.data_path_list is None else args.data_path_list:
        dataset = ContextAndTargetVideoDataset(
            root=None,
            setting=data_path,
            video_ext='mp4',
            is_color=True,
            modality='rgb',
            context_length=args.context_frames,
            target_length=args.target_frames,
            step_units=args.temporal_units,
            new_step=args.sampling_rate,
            context_target_gap=args.context_target_gap,
            transform=data_transform,
            randomize_interframes=False,
            channels_first=True,
            temporal_jitter=False,
            train=True,
            mask_generator=mask_generator,
        )
        dataset_list.append(dataset)
    dataset = torch.utils.data.ConcatDataset(dataset_list)
    return dataset
```
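`build_pretraining_dataset` reads its configuration off an `args` object (normally the training script's argparse namespace, which is not part of this diff). A hedged sketch of the fields it expects, with illustrative values:

```python
from types import SimpleNamespace
from cwm.data.dataset_utils import build_pretraining_dataset

# Field names come from build_pretraining_dataset above; the values are example choices only.
args = SimpleNamespace(
    augmentation_type='multiscale',
    input_size=224,
    augmentation_scales=(1.0, 0.875),
    mask_input_size=(3, 28, 28),     # (frames, patches_h, patches_w) for an 8x8 patch at 224 px
    mask_ratio=0.90,
    tubelet_size=1,
    batch_size=8,
    mask_type='rotated_table',
    data_path='cwm/data/video_file_lists/kinetics_400_train_list_sing.txt',
    data_path_list=None,
    context_frames=2,
    target_frames=1,
    temporal_units='ms',
    sampling_rate=150,
    context_target_gap=[400, 600],
)

dataset = build_pretraining_dataset(args)
```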
cwm/data/masking_generator.py
ADDED
@@ -0,0 +1,86 @@
```python
import numpy as np
import torch

def get_tubes(masks_per_frame, tube_length):
    rp = torch.randperm(len(masks_per_frame))
    masks_per_frame = masks_per_frame[rp]

    tubes = [masks_per_frame]
    for x in range(tube_length - 1):
        masks_per_frame = masks_per_frame.clone()
        rp = torch.randperm(len(masks_per_frame))
        masks_per_frame = masks_per_frame[rp]
        tubes.append(masks_per_frame)

    tubes = torch.vstack(tubes)

    return tubes

class RotatedTableMaskingGenerator:
    def __init__(self,
                 input_size,
                 mask_ratio,
                 tube_length,
                 batch_size,
                 mask_type='rotated_table',
                 seed=None,
                 randomize_num_visible=False):

        self.batch_size = batch_size

        self.mask_ratio = mask_ratio
        self.tube_length = tube_length

        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame

        self.seed = seed
        self.randomize_num_visible = randomize_num_visible

        self.mask_type = mask_type

    def __repr__(self):
        repr_str = "Inverted Table Mask: total patches {}, tube length {}, randomize num visible? {}, seed {}".format(
            self.total_patches, self.tube_length, self.randomize_num_visible, self.seed
        )
        return repr_str

    def __call__(self, m=None):

        if self.mask_type == 'rotated_table_magvit':
            self.mask_ratio = np.random.uniform(low=0.0, high=1)
            self.mask_ratio = np.cos(self.mask_ratio * np.pi / 2)
        elif self.mask_type == 'rotated_table_maskvit':
            self.mask_ratio = np.random.uniform(low=0.5, high=1)

        all_masks = []
        for b in range(self.batch_size):

            self.num_masks_per_frame = max(0, int(self.mask_ratio * self.num_patches_per_frame))
            self.total_masks = self.tube_length * self.num_masks_per_frame

            num_masks = self.num_masks_per_frame

            if self.randomize_num_visible:
                assert "Randomize num visible Not implemented"
                num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))

            if self.mask_ratio == 0:
                mask_per_frame = torch.hstack([
                    torch.zeros(self.num_patches_per_frame - num_masks),
                ])
            else:
                mask_per_frame = torch.hstack([
                    torch.zeros(self.num_patches_per_frame - num_masks),
                    torch.ones(num_masks),
                ])

            tubes = get_tubes(mask_per_frame, self.tube_length)
            top = torch.zeros(self.height * self.width).to(tubes.dtype)

            top = torch.tile(top, (self.frames - self.tube_length, 1))
            mask = torch.cat([top, tubes])
            mask = mask.flatten()
            all_masks.append(mask)
        return torch.stack(all_masks)
```
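A quick illustration of the generator above: the early frames are left fully visible and only the last `tube_length` frames receive the (shuffled) mask, which is the "rotated table" pattern. The shapes below are example choices for a 3-frame clip of 28x28 patch tokens.

```python
import torch
from cwm.data.masking_generator import RotatedTableMaskingGenerator

gen = RotatedTableMaskingGenerator(
    input_size=(3, 28, 28),   # (frames, patches_h, patches_w)
    mask_ratio=0.90,
    tube_length=1,
    batch_size=2,
)

mask = gen()                  # [batch_size, frames * patches_h * patches_w]
print(mask.shape)             # torch.Size([2, 2352])
print(mask.sum(-1))           # 705 masked patches per sample, all in the final frame
```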
cwm/data/transforms.py
ADDED
@@ -0,0 +1,206 @@
```python
import torch
import torchvision.transforms.functional as F
import warnings
import random
import numpy as np
import torchvision
from PIL import Image, ImageOps
import numbers


class GroupRandomCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_tuple):
        img_group, label = img_tuple

        w, h = img_group[0].size
        th, tw = self.size

        out_images = list()

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        for img in img_group:
            assert (img.size[0] == w and img.size[1] == h)
            if w == tw and h == th:
                out_images.append(img)
            else:
                out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return (out_images, label)


class GroupCenterCrop(object):
    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_tuple):
        img_group, label = img_tuple
        return ([self.worker(img) for img in img_group], label)


class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor_tuple):
        tensor, label = tensor_tuple
        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
        rep_std = self.std * (tensor.size()[0] // len(self.std))

        # TODO: make efficient
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)

        return (tensor, label)


class GroupGrayScale(object):
    def __init__(self, size):
        self.worker = torchvision.transforms.Grayscale(size)

    def __call__(self, img_tuple):
        img_group, label = img_tuple
        return ([self.worker(img) for img in img_group], label)


class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_tuple):
        img_group, label = img_tuple
        return ([self.worker(img) for img in img_group], label)


class GroupMultiScaleCrop(object):

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_tuple):
        img_group, label = img_tuple

        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
        return (ret_img_group, label)

    def _sample_crop_size(self, im_size):
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter
        return ret


class Stack(object):

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_tuple):
        img_group, label = img_tuple

        if img_group[0].mode == 'L':
            return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
        elif img_group[0].mode == 'RGB':
            if self.roll:
                return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
            else:
                return (np.concatenate(img_group, axis=2), label)


class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic_tuple):
        pic, label = pic_tuple

        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        return (img.float().div(255.) if self.div else img.float(), label)


class IdentityTransform(object):

    def __call__(self, data):
        return data
```
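A small sketch of how these group transforms are meant to be chained (not part of the diff): every stage consumes and returns an `(images, label)` tuple, matching the `__call__` signatures above; the sizes and dummy frames below are placeholders.

```python
from PIL import Image
from torchvision import transforms
from cwm.data.transforms import (GroupScale, GroupCenterCrop, Stack,
                                 ToTorchFormatTensor, GroupNormalize)

pipeline = transforms.Compose([
    GroupScale(224),                         # resize so the shorter edge is 224
    GroupCenterCrop(224),
    Stack(roll=False),                       # concatenate frames along the channel axis
    ToTorchFormatTensor(div=True),           # HWC uint8 -> CHW float in [0, 1]
    GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

frames = [Image.new('RGB', (320, 240)) for _ in range(3)]   # stand-in for decoded video frames
tensor, _ = pipeline((frames, None))
print(tensor.shape)                          # [T * C, H, W] = [9, 224, 224]
```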
cwm/data/video_file_lists/kinetics_400_train_list.txt
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:65e14c0735b4c90c57022add2407a8524d246cc09b3d5a7e83b963ac3b231032
size 19539143
```
cwm/data/video_file_lists/kinetics_400_train_list_sing.txt
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:7b18ccdce4616fb32a0aababc2342a640a11f7d73439f49358a16cc99e7eaed3
size 1943
```
cwm/engine_for_pretraining.py
ADDED
@@ -0,0 +1,92 @@
```python
import json
import os
import time
from typing import Iterable

import torch
import torch.nn as nn
from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

import utils

from datetime import datetime


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    start_steps=None,
                    lr_schedule_values=None,
                    wd_schedule_values=None,
                    global_rank=None,
                    args=None,
                    loss_func=nn.MSELoss(),
                    ):

    metric_logger = utils.MetricLogger(delimiter="  ")

    if args.eval:
        model.eval()
    else:
        model.train()
        metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = f'Epoch [{epoch}]'
    patch_size = model.module.encoder.patch_size[-2:]
    tubelet_size = model.module.encoder.patch_size[0]
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]

    for step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):

        # assign learning rate & weight decay for each iteration
        it = start_steps + step  # global training iteration
        if (lr_schedule_values is not None or wd_schedule_values is not None) and (step % args.accum_iter == 0):
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        # prepare input
        videos, bool_masked_pos = batch
        videos = videos.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1)

        # prepare target
        with torch.no_grad():
            unnorm_videos = videos * std + mean  # in [0, 1]
            videos_patch = utils.patchify(unnorm_videos, tubelet_size, patch_size)
            B, _, C = videos_patch.shape
            labels = videos_patch[bool_masked_pos].reshape(B, -1, C)

        # feedforward
        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(videos, bool_masked_pos)
            loss = loss_func(input=outputs, target=labels)

        loss_value = loss.item()

        # backward
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss /= args.accum_iter
        loss_scaler(loss, optimizer, clip_grad=None,
                    parameters=model.parameters(), create_graph=is_second_order,
                    update_grad=(step + 1) % args.accum_iter == 0)

        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        if (step + 1) % args.accum_iter == 0:
            optimizer.zero_grad()

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
```
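`utils.patchify` is not included in this diff; a hedged sketch of what it presumably does (grouping each tubelet of pixels into a flat regression target), so the label construction above is easier to follow. The `einops` dependency and the exact factor ordering are assumptions for illustration only.

```python
from einops import rearrange

def patchify(videos, tubelet_size, patch_size):
    # Assumed behavior: [B, C, T, H, W] -> [B, num_patches, tubelet_size * ph * pw * C],
    # where num_patches = (T // tubelet_size) * (H // ph) * (W // pw).
    ph, pw = patch_size
    return rearrange(
        videos,
        'b c (t ts) (h ph) (w pw) -> b (t h w) (ts ph pw c)',
        ts=tubelet_size, ph=ph, pw=pw,
    )
```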
cwm/eval/Action_recognition/__init__.py
ADDED
File without changes
cwm/eval/Flow/__init__.py
ADDED
File without changes
cwm/eval/Flow/create_spring_submission_parallel.sh
ADDED
@@ -0,0 +1,36 @@
```bash
#!/bin/bash

# Define the path to the dataset and the Python script
DATASET_PATH="/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/"
SCRIPT_PATH="./create_spring_submission_unified.py"
SAVE_DATA_PATH=${1}
MODEL=${2}
# Counter for GPUs
GPU_COUNTER=0

# Number of GPUs available
NUM_GPUS=8

# kill session
tmux kill-session -t extraction

tmux new-session -d -s "extraction"

# Iterate through each folder in the dataset
for FOLDER in $(find $DATASET_PATH -mindepth 1 -maxdepth 1 -type d | sort); do
    # Extract the folder name for the tmux session name
    FOLDER_NAME=$(basename $FOLDER)

    # Create a new detached tmux window for each folder
    tmux new-window -t extraction -n "$FOLDER" "ulimit -n 65535; CUDA_VISIBLE_DEVICES=$GPU_COUNTER python $SCRIPT_PATH --folder $FOLDER --gpu $GPU_COUNTER --save_data_path $SAVE_DATA_PATH --model $MODEL; echo 'Press Enter to continue...'; read -p ''"
    # Increment the GPU counter and reset if it exceeds the number of GPUs
    GPU_COUNTER=$((GPU_COUNTER + 1))
    if [ $GPU_COUNTER -ge $NUM_GPUS ]; then
        GPU_COUNTER=0
    fi

    sleep 1
done

tmux attach-session -t extraction
```
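The script takes the output directory as its first argument and a dotted `module.Class` path as its second, which `create_spring_submission_unified.py` resolves with `importlib`. A hedged invocation, with the output path as a placeholder:

```bash
# One tmux window per SPRING test scene, cycling over 8 GPUs.
# The class path matches CWM_8x8 defined in flow_extraction_classes.py.
./create_spring_submission_parallel.sh \
    /path/to/spring_submission \
    cwm.eval.Flow.flow_extraction_classes.CWM_8x8
```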
cwm/eval/Flow/create_spring_submission_unified.py
ADDED
@@ -0,0 +1,111 @@
```python
import argparse

# Parse command-line arguments
import importlib
import time

parser = argparse.ArgumentParser(description='Process a folder with RAFT')
parser.add_argument('--folder', type=str, required=True, help='Folder to process')
parser.add_argument('--model', type=str, required=True, help='Model used to extract flow')
parser.add_argument('--save_data_path', type=str, required=True, help='where to save the data')
parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
args = parser.parse_args()
import os

os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
import torch
torch.cuda.set_device(0)

import h5py

def writeFlo5File(flow, filename):
    with h5py.File(filename, "w") as f:
        f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)

if __name__ == '__main__':
    module_name, class_name = args.model.rsplit(".", 1)
    module = importlib.import_module(module_name)

    model = getattr(module, class_name)
    model = model().cuda().eval()

    folder = args.folder.split('/')[-1]

    import os
    import matplotlib.pyplot as plt

    import torch
    import torchvision.transforms as transforms

    # import smurf  # Assuming this is your custom inference module

    # Path for the dataset
    dataset_path = '/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/'

    save_data_path = args.save_data_path

    if not os.path.exists(save_data_path):
        os.makedirs(save_data_path)

    resize_crop = transforms.Compose([
        transforms.ToTensor(),
    ])

    import numpy as np

    def l2norm(x):
        return np.sqrt((x ** 2).sum(-1))

    all_epe = []
    # Create a new HDF5 file

    TAG_FLOAT = 202021.25

    # Iterate over each folder in the dataset directory
    for dir in ['FW', 'BW']:
        for stereo in ['left', 'right']:
            files = sorted(os.listdir(os.path.join(dataset_path, folder, f'frame_{stereo}')))
            output_folder = os.path.join(save_data_path, folder)
            output_folder = os.path.join(output_folder, f'flow_{dir}_{stereo}')

            if not os.path.exists(output_folder):
                os.makedirs(output_folder)

            for ct_f in range(len(files) - 1):
                # Read images
                if dir == 'FW':
                    f1 = files[ct_f]
                    f2 = files[ct_f + 1]
                else:
                    f2 = files[ct_f]
                    f1 = files[ct_f + 1]
                t = time.time()
                image1_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f1)
                image2_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f2)

                idx = image1_path.split('/')[-1].split('.')[0].split('_')[-1]
                flow_save_path = os.path.join(output_folder, f'flow_{dir}_{stereo}_' + idx + '.flo5')

                # if os.path.exists(flow_save_path):
                #     try:
                #         with h5py.File(flow_save_path, 'r+') as f:
                #             if f['flow'][:].shape[0] == 2:
                #                 flow = f['flow'][:].transpose([1, 2, 0])
                #                 del f['flow']
                #                 f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
                #                 continue
                #             else:
                #                 continue
                #     except:
                #         pass

                image1_ = plt.imread(image1_path)
                image2_ = plt.imread(image2_path)

                image1 = resize_crop(image1_)
                image2 = resize_crop(image2_)

                forward_flow = model.forward(image1, image2)

                writeFlo5File(forward_flow, flow_save_path)
```
cwm/eval/Flow/flow_extraction_classes.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn as nn

import cwm.model.model_pretrain as vmae_tranformers
from . import flow_utils
from . import losses as bblosses


# Normal Resolution
def l2_norm(x):
    return x.square().sum(-3, True).sqrt()


def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    flow_diff_fwd = flow_fwd + fwd_bck_cycle

    bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
    flow_diff_bck = flow_bck + bck_fwd_cycle

    norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
    norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2

    occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
    occ_thresh_bck = occ_thresh * norm_bck + 0.5

    occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
    occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()

    return occ_mask_fwd, occ_mask_bck


class ExtractFlow(nn.Module):

    def __init__(self):
        super().__init__()
        return

    def forward(self, img1, img2):
        '''
        img1: first frame
        img2: second frame
        returns: flow map (h, w, 2)
        '''


from cwm.data.masking_generator import RotatedTableMaskingGenerator


class CWM(ExtractFlow):
    def __init__(self, model_name, patch_size, weights_path):
        super().__init__()

        self.patch_size = patch_size
        model = getattr(vmae_tranformers, model_name)
        vmae_8x8_full = model().cuda().eval().requires_grad_(False)

        VMAE_LOAD_PATH = weights_path
        did_load = vmae_8x8_full.load_state_dict(torch.load(VMAE_LOAD_PATH)['model'], strict=False)
        print(did_load, VMAE_LOAD_PATH)

        self.predictor = vmae_8x8_full

        self.mask_generator = RotatedTableMaskingGenerator(
            input_size=(vmae_8x8_full.num_frames, 28, 28),
            mask_ratio=0.0,
            tube_length=1,
            batch_size=1,
            mask_type='rotated_table'
        )

    def forward(self, img1, img2):
        '''
        img1: [3, 1024, 1024]
        img2: [3, 1024, 1024]
        both images are imagenet normalized
        '''

        with torch.no_grad():
            FF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self.predictor,
                                                                                         self.mask_generator, img1[None],
                                                                                         img2[None],
                                                                                         num_scales=2,
                                                                                         min_scale=224,
                                                                                         N_mask_samples=1)

            BF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self.predictor,
                                                                                         self.mask_generator,
                                                                                         img2[None],
                                                                                         img1[None],
                                                                                         num_scales=2,
                                                                                         min_scale=224,
                                                                                         N_mask_samples=1)

            # FF, _ = flow_utils.get_honglin_3frame_vmae_optical_flow_crop_batched(self.predictor,
            #                                                                      self.mask_generator, img1[None],
            #                                                                      img2[None], img2[None],
            #                                                                      neg_back_flow=True, num_scales=1,
            #                                                                      min_scale=224, N_mask_samples=1,
            #                                                                      mask_ratio=0.0)
            #
            # BF, _ = flow_utils.get_honglin_3frame_vmae_optical_flow_crop_batched(self.predictor,
            #                                                                      self.mask_generator, img2[None],
            #                                                                      img1[None], img1[None],
            #                                                                      neg_back_flow=True, num_scales=1,
            #                                                                      min_scale=224, N_mask_samples=1,
            #                                                                      mask_ratio=0.0)

            occ_mask = get_occ_masks(FF, BF)[0]

            FF = FF * occ_mask

            FF = FF[0]

            return FF  # .cpu().numpy().transpose([1, 2, 0])


class CWM_8x8(CWM):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames', 8,
                         '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth')
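A minimal usage sketch for the wrapper above (illustration only, not part of the commit). It assumes a locally available checkpoint path and two ImageNet-normalized frames already on the GPU; CWM_8x8 hard-codes a cluster-specific path, so CWM with user-supplied weights is shown instead.

# Hypothetical example: extract forward optical flow between two normalized frames.
import torch

flow_model = CWM('vitb_8x8patch_3frames', 8, '/path/to/checkpoint-399.pth')  # hypothetical path
img1 = torch.randn(3, 512, 512).cuda()  # stand-ins for ImageNet-normalized frames [3, H, W]
img2 = torch.randn(3, 512, 512).cuda()
flow = flow_model(img1, img2)           # [2, H, W] forward flow with occluded pixels zeroed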
cwm/eval/Flow/flow_utils.py
ADDED
@@ -0,0 +1,569 @@
import random
import math
import numpy as np
import torch
import torch.nn.functional as F
from . import losses as bblosses
import kornia

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
    # Unroll the mask tensor and find the indices of the masked and unmasked values in the second frame
    mask_unrolled = mask_tensor.view(-1)

    second_frame_unmask_indices = torch.where(mask_unrolled[frame_size ** 2:] == False)[0]

    # Divide the embedding tensor into two parts: corresponding to the first and the second frame
    first_frame_embeddings = embedding_tensor[0, :frame_size ** 2, :]
    second_frame_embeddings = embedding_tensor[0, frame_size ** 2:, :]

    # Compute the cosine similarity between the unmasked embeddings from the second frame and the embeddings from the first frame
    dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
    norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
    cos_sim_matrix = dot_product / norms

    # Find the indices of pixels in the first frame that are most similar to each unmasked pixel in the second frame
    first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)

    # Convert the 1D pixel indices into 2D coordinates
    second_frame_y = second_frame_unmask_indices // frame_size
    second_frame_x = second_frame_unmask_indices % frame_size
    first_frame_y = first_frame_most_similar_indices // frame_size
    first_frame_x = first_frame_most_similar_indices % frame_size

    # Compute the x and y displacements and convert them to float
    displacements_x = (second_frame_x - first_frame_x).float()
    displacements_y = (second_frame_y - first_frame_y).float()

    # Initialize optical flow tensor
    optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)

    # Assign the computed displacements to the corresponding pixels in the optical flow tensor
    optical_flow[0, second_frame_y, second_frame_x] = displacements_x
    optical_flow[1, second_frame_y, second_frame_x] = displacements_y

    return optical_flow


def get_minimal_224_crops_new_batched(video_tensor, N):
    B, T, C, H, W = video_tensor.shape

    # Calculate the number of crops needed in both the height and width dimensions
    num_crops_h = math.ceil(H / 224) if H > 224 else 1
    num_crops_w = math.ceil(W / 224) if W > 224 else 1

    # Calculate the step size for the height and width dimensions
    step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
    step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))

    # Create a list to store the cropped tensors and their start positions
    cropped_tensors = []
    crop_positions = []

    # Iterate over the height and width dimensions, extract the 224x224 crops, and append to the cropped_tensors list
    for i in range(num_crops_h):
        for j in range(num_crops_w):
            start_h = i * step_size_h
            start_w = j * step_size_w
            end_h = min(start_h + 224, H)
            end_w = min(start_w + 224, W)
            crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
            cropped_tensors.append(crop)
            crop_positions.append((start_h, start_w))

    D = len(cropped_tensors)

    # If N is greater than D, generate additional random crops
    if N > D and H > 224 and W > 224:  # check if H and W are greater than 224
        for _ in range(N - D):
            start_h = random.randint(0, H - 224)
            start_w = random.randint(0, W - 224)
            crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
            cropped_tensors.append(crop)
            crop_positions.append((start_h, start_w))

    # Reshape the cropped tensors to fit the required output shape (B, T, C, 224, 224)
    cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]

    return cropped_tensors, crop_positions


def create_weighted_mask_batched(h, w):
    y_mask = np.linspace(0, 1, h)
    y_mask = np.minimum(y_mask, 1 - y_mask)
    x_mask = np.linspace(0, 1, w)
    x_mask = np.minimum(x_mask, 1 - x_mask)
    weighted_mask = np.outer(y_mask, x_mask)
    return torch.from_numpy(weighted_mask).float()


def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
    B, T, C, H, W = original_shape

    # Initialize an empty tensor to store the reconstructed video
    reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)

    # Create a tensor to store the sum of weighted masks
    weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)

    # Create a weighted mask for the crops
    weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
    weighted_mask = weighted_mask[None, None, None, :, :]  # Extend dimensions to match the cropped tensor.

    for idx, crop in enumerate(cropped_tensors):
        start_h, start_w = crop_positions[idx]

        # Multiply the crop with the weighted mask
        weighted_crop = crop * weighted_mask

        # Add the weighted crop to the corresponding location in the reconstructed_video tensor
        reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop

        # Update the weighted_masks_sum tensor
        weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask

    # Add a small epsilon value to avoid division by zero
    epsilon = 1e-8

    # Normalize the reconstructed video by dividing each pixel by its corresponding weighted_masks_sum value plus epsilon
    reconstructed_video /= (weighted_masks_sum + epsilon)

    return reconstructed_video


def l2_norm(x):
    return x.square().sum(-3, True).sqrt()


resize = lambda x, a: F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear',
                                    align_corners=False)

upsample = lambda x, H, W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)


def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    flow_diff_fwd = flow_fwd + fwd_bck_cycle

    bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
    flow_diff_bck = flow_bck + bck_fwd_cycle

    norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
    norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2

    occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
    occ_thresh_bck = occ_thresh * norm_bck + 0.5

    occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
    occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()

    return occ_mask_fwd, occ_mask_bck


def forward_backward_cycle_consistency(flow_fwd, flow_bck, niters=10):
    # Make sure to be using axes-swapped, upsampled flows!
    bck_flow_clone = flow_bck.clone().detach()
    fwd_flow_clone = flow_fwd.clone().detach()

    for i in range(niters):
        fwd_bck_cycle_orig, _ = bblosses.backward_warp(img2=bck_flow_clone, flow=fwd_flow_clone)
        flow_diff_fwd_orig = fwd_flow_clone + fwd_bck_cycle_orig

        fwd_flow_clone = fwd_flow_clone - flow_diff_fwd_orig / 2

        bck_fwd_cycle_orig, _ = bblosses.backward_warp(img2=fwd_flow_clone, flow=bck_flow_clone)
        flow_diff_bck_orig = bck_flow_clone + bck_fwd_cycle_orig

        bck_flow_clone = bck_flow_clone - flow_diff_bck_orig / 2

    return fwd_flow_clone, bck_flow_clone


from PIL import Image


def resize_flow_map(flow_map, target_size):
    """
    Resize a flow map to a target size while adjusting the flow vectors.

    Parameters:
    flow_map (numpy.ndarray): Input flow map of shape (H, W, 2) where each pixel contains a (dx, dy) flow vector.
    target_size (tuple): Target size (height, width) for the resized flow map.

    Returns:
    numpy.ndarray: Resized and scaled flow map of shape (target_size[0], target_size[1], 2).
    """
    # Get the original size
    flow_map = flow_map[0].detach().cpu().numpy()
    flow_map = flow_map.transpose(1, 2, 0)
    original_size = flow_map.shape[:2]

    # Separate the flow map into two channels: dx and dy
    flow_map_x = flow_map[:, :, 0]
    flow_map_y = flow_map[:, :, 1]

    # Convert each flow channel to a PIL image for resizing
    flow_map_x_img = Image.fromarray(flow_map_x)
    flow_map_y_img = Image.fromarray(flow_map_y)

    # Resize both channels to the target size using bilinear interpolation
    flow_map_x_resized = flow_map_x_img.resize(target_size, Image.BILINEAR)
    flow_map_y_resized = flow_map_y_img.resize(target_size, Image.BILINEAR)

    # Convert resized PIL images back to NumPy arrays
    flow_map_x_resized = np.array(flow_map_x_resized)
    flow_map_y_resized = np.array(flow_map_y_resized)

    # Compute the scaling factor based on the size change
    scale_factor = target_size[0] / original_size[0]  # Scaling factor for both dx and dy

    # Scale the flow vectors (dx and dy) accordingly
    flow_map_x_resized *= scale_factor
    flow_map_y_resized *= scale_factor

    # Recombine the two channels into a resized flow map
    flow_map_resized = np.stack([flow_map_x_resized, flow_map_y_resized], axis=-1)

    flow_map_resized = torch.from_numpy(flow_map_resized)[None].permute(0, 3, 1, 2)

    return flow_map_resized


def get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                mask_generator,
                                                img1,
                                                img2,
                                                neg_back_flow=True,
                                                num_scales=1,
                                                min_scale=400,
                                                N_mask_samples=100,
                                                mask_ratio=0.8,
                                                smoothing_factor=1):
    ##### DEPRECATED
    print('Deprecated. Please use scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed')

    # neg_back_flow and mask_ratio are accepted for backward compatibility but are not
    # used by the fixed implementation, which does not take these keyword arguments.
    return scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                                     mask_generator,
                                                                     img1,
                                                                     img2,
                                                                     num_scales=num_scales,
                                                                     min_scale=min_scale,
                                                                     N_mask_samples=N_mask_samples,
                                                                     smoothing_factor=smoothing_factor)


def average_crops(tensor, D):
    C, H, W = tensor.shape

    # Create zero-filled tensors for the shifted crops
    down_shifted = torch.zeros_like(tensor)
    up_shifted = torch.zeros_like(tensor)
    right_shifted = torch.zeros_like(tensor)
    left_shifted = torch.zeros_like(tensor)

    # Shift the tensor and store the results in the zero-filled tensors
    down_shifted[:, :H - D, :] = tensor[:, D:, :]
    up_shifted[:, D:, :] = tensor[:, :H - D, :]
    right_shifted[:, :, :W - D] = tensor[:, :, D:]
    left_shifted[:, :, D:] = tensor[:, :, :W - D]

    # Average the tensor with its four crops
    result = (tensor + down_shifted + up_shifted + right_shifted + left_shifted) / 5.0

    return result


def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(predictor,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              conditioning_img=None,
                                                              num_scales=1,
                                                              min_scale=400,
                                                              N_mask_samples=100,
                                                              smoothing_factor=1):
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # For scaling
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]

    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    frame_size = 224 // predictor.patch_size[-1]

    patch_size = predictor.patch_size[-1]

    num_frames = predictor.num_frames

    all_fwd_flows_e2d = []

    s_hs = []
    s_ws = []

    for aidx in range(num_scales):
        # print('aidx: ', aidx)

        img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                    mode='bicubic', align_corners=True)

        if conditioning_img is not None:
            conditioning_img_scaled = F.interpolate(conditioning_img.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                                    mode='bilinear', align_corners=False)

        # print("img1_scaled", img1_scaled.shape, alpha, min_scale, num_scales)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]

        s_h = h1 / h2
        s_w = w1 / w2

        s_hs.append(s_h)
        s_ws.append(s_w)

        if conditioning_img is not None:
            video = torch.cat([conditioning_img_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img2_scaled.unsqueeze(1)] * (num_frames - 1) + [img1_scaled.unsqueeze(1)], 1)

        # Should work, even if the incoming video is already 224x224
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)

        num_crops = len(crops1)

        crop_flows_enc = []
        crop_flows_enc2dec = []
        N_samples = N_mask_samples

        crop = torch.cat(crops1, 0).cuda()

        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        i = 0
        while i < N_samples or (mask_counts == 0).any().item():
            if i % 100 == 0:
                pass  # print(i)

            # This would be that every sample has the same mask. For now that's okay I think
            mask = mask_generator().bool().cuda()
            mask_2f = ~mask[0, (frame_size * frame_size) * (num_frames - 1):]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):

                processed_x = crop.transpose(1, 2)

                encoder_out = predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
                encoder_to_decoder = predictor.encoder_to_decoder(encoder_out)

                encoder_to_decoder = encoder_to_decoder[:, (frame_size * frame_size) * (num_frames - 2):, :]
                flow_mask = mask[:, (frame_size * frame_size) * (num_frames - 2):]

                optical_flow_e2d = []
                # one per batch element for now
                for b in range(B * num_crops):
                    batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
                    # optical_flow_e2d.append(batch_flow.unsqueeze(0))

                    optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))

                optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
                optical_flows_enc2dec += optical_flow_e2d
                i += 1

        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # other function
        # scale_factor_y = video.shape[-2] / 224
        # scale_factor_x = video.shape[-1] / 224
        #
        # scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
        # scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
        # scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h
        #
        # # split the crops back up
        # crop_flows_enc2dec = scaled_optical_flow.split(B, 0)

        ###
        # Kevin's fn
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
        ###

        # Changed by Kevin
        T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
              crop_flows_enc2dec]
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
            B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        # other function
        # optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
        #     [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
        #      crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
        #
        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    # other function
    # all_fwd_flows_e2d_new = []
    #
    # for r in all_fwd_flows_e2d:
    #     new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
    #     all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
    # return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
    #
    # return_flow = -return_flow
    # all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
    #
    # return return_flow, all_fwd_flows_e2d_new

    # Kevin's method
    all_fwd_flows_e2d_new = []

    for ridx, r in enumerate(all_fwd_flows_e2d):
        # print('ridx', ridx)
        # print('sh', s_hs[ridx])
        # print('sw', s_ws[ridx])
        # print('scale_fac y', scale_ys[ridx])
        # print('scale_fac x', scale_xs[ridx])

        _sh = s_hs[ridx]
        _sw = s_ws[ridx]
        _sfy = predictor.patch_size[-1]
        _sfx = predictor.patch_size[-1]

        # plt.figure(figsize=(20, 20))

        # plt.subplot(1,3,1)
        # plt.imshow(f2rgb(-r).cpu().numpy()[0].transpose(1,2,0))

        # plt.subplot(1,3,2)
        new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])], mode='bicubic', align_corners=True)
        # plt.imshow(f2rgb(-new_r).cpu().numpy()[0].transpose(1,2,0))

        scaled_new_r = torch.zeros_like(new_r)
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh

        # plt.subplot(1,3,3)
        # plt.imshow(f2rgb(-scaled_new_r).cpu().numpy()[0].transpose(1,2,0))

        # plt.show()

        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    return_flow = -return_flow
    all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new


def extract_jacobians_and_flows(img1, img2,
                                flow_generator,
                                mask,
                                target_mask=None):
    IMAGE_SIZE = img1.shape[-2:]

    y = torch.cat([img2.unsqueeze(1), img1.unsqueeze(1)], 1)

    jacobians, flows, _ = flow_generator(y, mask, target_mask)

    # swap x,y flow dims
    flows = torch.cat([flows[0, 1].unsqueeze(0), flows[0, 0].unsqueeze(0)])

    # upsample to 224
    flows = flows.unsqueeze(0).repeat_interleave(IMAGE_SIZE[0] // flows.shape[-1], -1).repeat_interleave(
        IMAGE_SIZE[0] // flows.shape[-1], -2)

    return jacobians, flows


import matplotlib.pyplot as plt


class FlowToRgb(object):

    def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
        self.max_speed = max_speed
        self.from_image_coordinates = from_image_coordinates
        self.from_sampling_grid = from_sampling_grid

    def __call__(self, flow):
        assert flow.size(-3) == 2, flow.shape
        if self.from_sampling_grid:
            flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
            flow_y = -flow_y
        elif not self.from_image_coordinates:
            flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
        else:
            flow_h, flow_w = torch.split(flow, [1, 1], dim=-3)
            flow_x, flow_y = [flow_w, -flow_h]

        # print("flow_x", flow_x[0, :, 0, 0], flow_y[0, :, 0, 0])
        angle = torch.atan2(flow_y, flow_x)  # in radians from -pi to pi
        speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed

        # print("angle", angle[0, :, 0, 0] * 180 / np.pi)

        hue = torch.fmod(angle, torch.tensor(2 * np.pi))
        sat = torch.ones_like(hue)
        val = speed

        hsv = torch.cat([hue, sat, val], -3)
        rgb = kornia.color.hsv_to_rgb(hsv)
        return rgb

    def make_colorwheel(self):
        """
        Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        """
        RY = 15
        YG = 6
        GC = 4
        CB = 11
        BM = 13
        MR = 6

        ncols = RY + YG + GC + CB + BM + MR
        colorwheel = np.zeros((ncols, 3))
        col = 0

        # RY
        colorwheel[0:RY, 0] = 255
        colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
        col += RY
        # YG
        colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
        colorwheel[col:col + YG, 1] = 255
        col += YG
        # GC
        colorwheel[col:col + GC, 1] = 255
        colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
        col += GC
        # CB
        colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB)
        colorwheel[col:col + CB, 2] = 255
        col += CB
        # BM
        colorwheel[col:col + BM, 2] = 255
        colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
        col += BM
        # MR
        colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR)
        colorwheel[col:col + MR, 0] = 255
        return colorwheel
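The occlusion check in get_occ_masks composes naturally with the FlowToRgb visualizer above. A hedged sketch with placeholder tensors (the flow inputs stand in for outputs of the flow extractor; names are illustrative):

# Illustrative only: zero out pixels that fail the forward-backward cycle check,
# then convert the surviving forward flow to an RGB visualization.
import torch

fwd_flow = torch.randn(1, 2, 224, 224)   # placeholder forward flow [B, 2, H, W]
bck_flow = torch.randn(1, 2, 224, 224)   # placeholder backward flow
occ_fwd, occ_bck = get_occ_masks(fwd_flow, bck_flow, occ_thresh=0.5)
flow_rgb = FlowToRgb(max_speed=10.0)(fwd_flow * occ_fwd)  # [B, 3, H, W] visualization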
cwm/eval/Flow/flow_utils_legacy.py
ADDED
@@ -0,0 +1,152 @@
# NOTE: legacy version kept for reference. The imports below are added so the module is
# self-contained; the helper functions are assumed to come from flow_utils.py in this package.
import torch
import torch.nn.functional as F

from .flow_utils import (get_minimal_224_crops_new_batched,
                         compute_optical_flow,
                         average_crops,
                         reconstruct_video_new_2_batched)


def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              neg_back_flow=True,
                                                              num_scales=1,
                                                              min_scale=400,
                                                              N_mask_samples=100,
                                                              mask_ratio=0.8,
                                                              smoothing_factor=1):
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # For scaling
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1 and min_scale >= 360  # Below 360p, the flows look terrible

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    frame_size = 224 // generator.patch_size[-1]

    all_fwd_flows_e2d = []

    s_hs = []
    s_ws = []

    for aidx in range(num_scales):
        print(aidx)

        # print('aidx: ', aidx)

        img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                    mode='bicubic', align_corners=True)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]

        s_h = h1 / h2
        s_w = w1 / w2

        s_hs.append(s_h)
        s_ws.append(s_w)

        # Because technically the compute_optical_flow function returns neg back flow
        if neg_back_flow is True:
            video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)

        # Should work, even if the incoming video is already 224x224
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)

        num_crops = len(crops1)

        crop_flows_enc = []
        crop_flows_enc2dec = []
        N_samples = N_mask_samples

        crop = torch.cat(crops1, 0).cuda()

        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        i = 0
        while i < N_samples or (mask_counts == 0).any().item():
            if i % 100 == 0:
                pass  # print(i)
            mask_generator.mask_ratio = mask_ratio

            # This would be that every sample has the same mask. For now that's okay I think
            mask = mask_generator()[None]
            mask_2f = ~mask[0, frame_size * frame_size:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):

                processed_x = generator._preprocess(crop)

                encoder_out = generator.predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)

                optical_flow_e2d = []
                # one per batch element for now
                for b in range(B * num_crops):
                    batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
                    optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))

                optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
                optical_flows_enc2dec += optical_flow_e2d
                i += 1

        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # split the crops back up
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)

        T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
              crop_flows_enc2dec]

        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
            B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    all_fwd_flows_e2d_new = []

    for ridx, r in enumerate(all_fwd_flows_e2d):
        # print('ridx', ridx)
        # print('sh', s_hs[ridx])
        # print('sw', s_ws[ridx])
        # print('scale_fac y', scale_ys[ridx])
        # print('scale_fac x', scale_xs[ridx])

        _sh = s_hs[ridx]
        _sw = s_ws[ridx]
        _sfy = generator.patch_size[-1]
        _sfx = generator.patch_size[-1]

        # plt.figure(figsize=(20, 20))

        # plt.subplot(1,3,1)
        # plt.imshow(f2rgb(-r).cpu().numpy()[0].transpose(1,2,0))

        # plt.subplot(1,3,2)
        new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])],
                              mode='bicubic', align_corners=True)
        # plt.imshow(f2rgb(-new_r).cpu().numpy()[0].transpose(1,2,0))

        scaled_new_r = torch.zeros_like(new_r)
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh

        # plt.subplot(1,3,3)
        # plt.imshow(f2rgb(-scaled_new_r).cpu().numpy()[0].transpose(1,2,0))

        # plt.show()

        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new
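Both this legacy routine and the current one rescale the predicted flow the same way: each displacement channel is multiplied by the patch size in pixels and by the ratio between the original and the scaled resolution. A toy worked sketch of that bookkeeping (numbers are illustrative, not from the commit):

# Toy illustration of the flow rescaling above: a patch-level flow predicted on a
# half-resolution copy of the image is converted back to pixel units in the original.
patch_size = 8                                # pixels per patch (e.g. the 8x8 ViT here)
s_w, s_h = 2.0, 2.0                           # original_size / scaled_size per axis
flow_x_patches, flow_y_patches = 1.5, -0.5    # toy patch-level displacements
flow_x_pixels = flow_x_patches * patch_size * s_w   # 24.0 px in the original image
flow_y_pixels = flow_y_patches * patch_size * s_h   # -8.0 px in the original image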
cwm/eval/Flow/generator.py
ADDED
@@ -0,0 +1,579 @@
1 |
+
import kornia
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
import cwm.eval.Flow.masking_flow as masking
|
9 |
+
|
10 |
+
|
11 |
+
def boltzmann(x, beta=1, eps=1e-9):
|
12 |
+
if beta is None:
|
13 |
+
return x
|
14 |
+
x = torch.exp(x * beta)
|
15 |
+
return x / x.amax((-1,-2), keepdim=True).clamp(min=eps)
|
16 |
+
|
17 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
18 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
19 |
+
|
20 |
+
def imagenet_normalize(x, temporal_dim=1):
|
21 |
+
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x.device)[None,None,:,None,None].to(x)
|
22 |
+
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x.device)[None,None,:,None,None].to(x)
|
23 |
+
if temporal_dim == 2:
|
24 |
+
mean = mean.transpose(1,2)
|
25 |
+
std = std.transpose(1,2)
|
26 |
+
return (x - mean) / std
|
27 |
+
|
28 |
+
def imagenet_unnormalize(x, temporal_dim=2):
|
29 |
+
device = x.device
|
30 |
+
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None].to(x)
|
31 |
+
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None].to(x)
|
32 |
+
if temporal_dim == 2:
|
33 |
+
mean = mean.transpose(1,2)
|
34 |
+
std = std.transpose(1,2)
|
35 |
+
x = x*std + mean
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
def coordinate_ims(batch_size, seq_length, imsize, normalize=True, dtype_out=torch.float32):
|
41 |
+
static = False
|
42 |
+
if seq_length == 0:
|
43 |
+
static = True
|
44 |
+
seq_length = 1
|
45 |
+
B = batch_size
|
46 |
+
T = seq_length
|
47 |
+
H,W = imsize
|
48 |
+
ones = torch.ones([B,H,W,1], dtype=dtype_out)
|
49 |
+
if normalize:
|
50 |
+
h = torch.divide(torch.arange(H).to(ones), torch.tensor(H-1, dtype=dtype_out))
|
51 |
+
h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
|
52 |
+
w = torch.divide(torch.arange(W).to(ones), torch.tensor(W-1, dtype=dtype_out))
|
53 |
+
w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
|
54 |
+
else:
|
55 |
+
h = torch.arange(H).to(ones).view(1,H,1,1) * ones
|
56 |
+
w = torch.arange(W).to(ones).view(1,1,W,1) * ones
|
57 |
+
h = torch.stack([h]*T, 1)
|
58 |
+
w = torch.stack([w]*T, 1)
|
59 |
+
hw_ims = torch.cat([h,w], -1)
|
60 |
+
if static:
|
61 |
+
hw_ims = hw_ims[:,0]
|
62 |
+
return hw_ims
|
63 |
+
|
64 |
+
|
65 |
+
def get_distribution_centroid(dist, eps=1e-9, normalize=False):
|
66 |
+
|
67 |
+
B,T,C,H,W = dist.shape
|
68 |
+
assert C == 1
|
69 |
+
dist_sum = dist.sum((-2, -1), keepdim=True).clamp(min=eps)
|
70 |
+
dist = dist / dist_sum
|
71 |
+
|
72 |
+
grid = coordinate_ims(B, T, [H,W], normalize=normalize).to(dist.device)
|
73 |
+
grid = grid.permute(0,1,4,2,3)
|
74 |
+
centroid = (grid * dist).sum((-2,-1))
|
75 |
+
return centroid
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
class FlowToRgb(object):
|
80 |
+
|
81 |
+
def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
|
82 |
+
self.max_speed = max_speed
|
83 |
+
self.from_image_coordinates = from_image_coordinates
|
84 |
+
self.from_sampling_grid = from_sampling_grid
|
85 |
+
|
86 |
+
def __call__(self, flow):
|
87 |
+
assert flow.size(-3) == 2, flow.shape
|
88 |
+
if self.from_sampling_grid:
|
89 |
+
flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
|
90 |
+
flow_y = -flow_y
|
91 |
+
elif not self.from_image_coordinates:
|
92 |
+
flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
|
93 |
+
else:
|
94 |
+
flow_h, flow_w = torch.split(flow, [1,1], dim=-3)
|
95 |
+
flow_x, flow_y = [flow_w, -flow_h]
|
96 |
+
|
97 |
+
angle = torch.atan2(flow_y, flow_x) # in radians from -pi to pi
|
98 |
+
speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed
|
99 |
+
|
100 |
+
hue = torch.fmod(angle, torch.tensor(2 * np.pi))
|
101 |
+
sat = torch.ones_like(hue)
|
102 |
+
val = speed
|
103 |
+
|
104 |
+
hsv = torch.cat([hue, sat, val], -3)
|
105 |
+
rgb = kornia.color.hsv_to_rgb(hsv)
|
106 |
+
return rgb
|
107 |
+
|
108 |
+
class Patchify(nn.Module):
|
109 |
+
"""Convert a set of images or a movie into patch vectors"""
|
110 |
+
|
111 |
+
def __init__(self,
|
112 |
+
patch_size=(16, 16),
|
113 |
+
temporal_dim=1,
|
114 |
+
squeeze_channel_dim=True
|
115 |
+
):
|
116 |
+
super().__init__()
|
117 |
+
self.set_patch_size(patch_size)
|
118 |
+
self.temporal_dim = temporal_dim
|
119 |
+
assert self.temporal_dim in [1, 2], self.temporal_dim
|
120 |
+
self._squeeze_channel_dim = squeeze_channel_dim
|
121 |
+
|
122 |
+
@property
|
123 |
+
def num_patches(self):
|
124 |
+
if (self.T is None) or (self.H is None) or (self.W is None):
|
125 |
+
return None
|
126 |
+
else:
|
127 |
+
return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)
|
128 |
+
|
129 |
+
def set_patch_size(self, patch_size):
|
130 |
+
self.patch_size = patch_size
|
131 |
+
if len(self.patch_size) == 2:
|
132 |
+
self.ph, self.pw = self.patch_size
|
133 |
+
self.pt = 1
|
134 |
+
self._patches_are_3d = False
|
135 |
+
elif len(self.patch_size) == 3:
|
136 |
+
self.pt, self.ph, self.pw = self.patch_size
|
137 |
+
self._patches_are_3d = True
|
138 |
+
else:
|
139 |
+
raise ValueError("patch_size must be a 2- or 3-tuple, but is %s" % self.patch_size)
|
140 |
+
|
141 |
+
self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
|
142 |
+
self.D = self.C = self.E = self.embed_dim = None
|
143 |
+
|
144 |
+
def _check_shape(self, x):
|
145 |
+
self.shape_inp = x.shape
|
146 |
+
self.rank_inp = len(self.shape_inp)
|
147 |
+
self.H, self.W = self.shape_inp[-2:]
|
148 |
+
assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
|
149 |
+
if (self.rank_inp == 5) and self._patches_are_3d:
|
150 |
+
self.T = self.shape_inp[self.temporal_dim]
|
151 |
+
assert (self.T % self.pt) == 0, (self.T, self.pt)
|
152 |
+
elif self.rank_inp == 5:
|
153 |
+
self.T = self.shape_inp[self.temporal_dim]
|
154 |
+
else:
|
155 |
+
self.T = 1
|
156 |
+
|
157 |
+
def split_by_time(self, x):
|
158 |
+
shape = x.shape
|
159 |
+
assert shape[1] % self.T == 0, (shape, self.T)
|
160 |
+
return x.view(shape[0], self.T, shape[1] // self.T, *shape[2:])
|
161 |
+
|
162 |
+
def merge_by_time(self, x):
|
163 |
+
shape = x.shape
|
164 |
+
return x.view(shape[0], shape[1] * shape[2], *shape[3:])
|
165 |
+
|
166 |
+
def video_to_patches(self, x):
|
167 |
+
if self.rank_inp == 4:
|
168 |
+
assert self.pt == 1, (self.pt, x.shape)
|
169 |
+
x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
|
170 |
+
else:
|
171 |
+
assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
|
172 |
+
dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
|
173 |
+
x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)
|
174 |
+
|
175 |
+
self.N, self.D, self.C = x.shape[-3:]
|
176 |
+
self.embed_dim = self.E = self.D * self.C
|
177 |
+
return x
|
178 |
+
|
179 |
+
def patches_to_video(self, x):
|
180 |
+
shape = x.shape
|
181 |
+
rank = len(shape)
|
182 |
+
if rank == 4:
|
183 |
+
B, _N, _D, _C = shape
|
184 |
+
else:
|
185 |
+
assert rank == 3, rank
|
186 |
+
B, _N, _E = shape
|
187 |
+
assert (_E % self.D == 0), (_E, self.D)
|
188 |
+
x = x.view(B, _N, self.D, -1)
|
189 |
+
|
190 |
+
if _N < self.num_patches:
|
191 |
+
masked_patches = self.get_masked_patches(
|
192 |
+
x,
|
193 |
+
num_patches=(self.num_patches - _N),
|
194 |
+
mask_mode=self.mask_mode)
|
195 |
+
x = torch.cat([x, masked_patches], 1)
|
196 |
+
|
197 |
+
x = rearrange(
|
198 |
+
x,
|
199 |
+
'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
|
200 |
+
pt=self.pt, ph=self.ph, pw=self.pw,
|
201 |
+
t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))
|
202 |
+
|
203 |
+
if self.rank_inp == 5 and (self.temporal_dim == 1):
|
204 |
+
x = x.transpose(1, 2)
|
205 |
+
elif self.rank_inp == 4:
|
206 |
+
assert x.shape[2] == 1, x.shape
|
207 |
+
x = x[:, :, 0]
|
208 |
+
return x
|
209 |
+
|
210 |
+
@staticmethod
|
211 |
+
def get_masked_patches(x, num_patches, mask_mode='zeros'):
|
212 |
+
shape = x.shape
|
213 |
+
patches_shape = (shape[0], num_patches, *shape[2:])
|
214 |
+
if mask_mode == 'zeros':
|
215 |
+
return torch.zeros(patches_shape).to(x.device).to(x.dtype).detach()
|
216 |
+
elif mask_mode == 'gray':
|
217 |
+
return 0.5 * torch.ones(patches_shape).to(x.device).to(x.dtype).detach()
|
218 |
+
else:
|
219 |
+
raise NotImplementedError("Haven't implemented mask_mode == %s" % mask_mode)
|
220 |
+
|
221 |
+
def average_within_patches(self, z):
|
222 |
+
if len(z.shape) == 3:
|
223 |
+
z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
|
224 |
+
return z.mean(-2, True).repeat(1, 1, z.shape[-2], 1)
|
225 |
+
|
226 |
+
def forward(self, x, to_video=False, mask_mode='zeros'):
|
227 |
+
if not to_video:
|
228 |
+
self._check_shape(x)
|
229 |
+
x = self.video_to_patches(x)
|
230 |
+
return x if not self._squeeze_channel_dim else x.view(x.size(0), self.N, -1)
|
231 |
+
|
232 |
+
else: # x are patches
|
233 |
+
assert (self.shape_inp is not None) and (self.num_patches is not None)
|
234 |
+
self.mask_mode = mask_mode
|
235 |
+
x = self.patches_to_video(x)
|
236 |
+
return x
|
237 |
+
|
238 |
+
|
239 |
+
class DerivativeFlowGenerator(nn.Module):
|
240 |
+
"""Estimate flow of a two-frame predictor using torch autograd"""
|
241 |
+
|
242 |
+
def __init__(self,
|
243 |
+
predictor,
|
244 |
+
perturbation_patch_size=None,
|
245 |
+
aggregation_patch_size=None,
|
246 |
+
agg_power=None,
|
247 |
+
agg_channel_func=None,
|
248 |
+
num_samples=1,
|
249 |
+
leave_one_out_sampling=False,
|
250 |
+
average_jacobian=True,
|
251 |
+
confidence_thresh=None,
|
252 |
+
temporal_dim=2,
|
253 |
+
imagenet_normalize_inputs=True):
|
254 |
+
|
255 |
+
super(DerivativeFlowGenerator, self).__init__()
|
256 |
+
|
257 |
+
self.predictor = predictor
|
258 |
+
|
259 |
+
self.patchify = Patchify(self.patch_size, temporal_dim=1, squeeze_channel_dim=True)
|
260 |
+
|
261 |
+
self.set_temporal_dim(temporal_dim)
|
262 |
+
|
263 |
+
self.imagenet_normalize_inputs = imagenet_normalize_inputs
|
264 |
+
|
265 |
+
self.perturbation_patch_size = self._get_patch_size(perturbation_patch_size) or self.patch_size
|
266 |
+
self.aggregation_patch_size = self._get_patch_size(aggregation_patch_size) or self.patch_size
|
267 |
+
self.agg_patchify = Patchify(self.aggregation_patch_size,
|
268 |
+
temporal_dim=1,
|
269 |
+
squeeze_channel_dim=False)
|
270 |
+
self.agg_channel_func = agg_channel_func or (lambda x: F.relu(x).sum(-3, True))
|
271 |
+
self.average_jacobian = average_jacobian
|
272 |
+
self.confidence_thresh = confidence_thresh
|
273 |
+
|
274 |
+
self.num_samples = num_samples
|
275 |
+
self.leave_one_out_sampling = leave_one_out_sampling
|
276 |
+
self.agg_power = agg_power
|
277 |
+
self.t_dim = temporal_dim
|
278 |
+
|
279 |
+
def _get_patch_size(self, p):
|
280 |
+
if p is None:
|
281 |
+
return None
|
282 |
+
elif isinstance(p, int):
|
283 |
+
return (1, p, p)
|
284 |
+
elif len(p) == 2:
|
285 |
+
return (1, p[0], p[1])
|
286 |
+
else:
|
287 |
+
assert len(p) == 3, p
|
288 |
+
return (p[0], p[1], p[2])
|
289 |
+
|
290 |
+
def set_temporal_dim(self, t_dim):
|
291 |
+
if t_dim == 1:
|
292 |
+
self.predictor.t_dim = 1
|
293 |
+
self.predictor.c_dim = 2
|
294 |
+
elif t_dim == 2:
|
295 |
+
self.predictor.c_dim = 1
|
296 |
+
self.predictor.t_dim = 2
|
297 |
+
else:
|
298 |
+
raise ValueError("temporal_dim must be 1 or 2")
|
299 |
+
|
300 |
+
@property
|
301 |
+
def c_dim(self):
|
302 |
+
if self.predictor is None:
|
303 |
+
return None
|
304 |
+
return self.predictor.c_dim
|
305 |
+
|
306 |
+
@property
|
307 |
+
def patch_size(self):
|
308 |
+
if self.predictor is None:
|
309 |
+
return None
|
310 |
+
elif hasattr(self.predictor, 'patch_size'):
|
311 |
+
return self.predictor.patch_size
|
312 |
+
elif hasattr(self.predictor.encoder.patch_embed, 'proj'):
|
313 |
+
return self.predictor.encoder.patch_embed.proj.kernel_size
|
314 |
+
else:
|
315 |
+
return None
|
316 |
+
@property
|
317 |
+
def S(self):
|
318 |
+
return self.num_samples
|
319 |
+
|
320 |
+
@property
|
321 |
+
def sequence_length(self):
|
322 |
+
if self.predictor is None:
|
323 |
+
return None
|
324 |
+
elif hasattr(self.predictor, 'sequence_length'):
|
325 |
+
return self.predictor.sequence_length
|
326 |
+
elif hasattr(self.predictor, 'num_frames'):
|
327 |
+
return self.predictor.num_frames
|
328 |
+
else:
|
329 |
+
return 2
|
330 |
+
@property
|
331 |
+
def mask_shape(self):
|
332 |
+
if self.predictor is None:
|
333 |
+
return None
|
334 |
+
elif hasattr(self.predictor, 'mask_shape'):
|
335 |
+
return self.predictor.mask_shape
|
336 |
+
|
337 |
+
assert self.patch_size is not None
|
338 |
+
pt, ph, pw = self.patch_size
|
339 |
+
return (self.sequence_length // pt,
|
340 |
+
self.inp_shape[-2] // ph,
|
341 |
+
self.inp_shape[-1] // pw)
|
342 |
+
|
343 |
+
@property
|
344 |
+
def perturbation_mask_shape(self):
|
345 |
+
return (
|
346 |
+
self.mask_shape[0],
|
347 |
+
self.inp_shape[-2] // self.perturbation_patch_size[-2],
|
348 |
+
self.inp_shape[-1] // self.perturbation_patch_size[-1]
|
349 |
+
)
|
350 |
+
|
351 |
+
|
352 |
+
|
353 |
+
@property
|
354 |
+
def p_mask_shape(self):
|
355 |
+
return self.perturbation_mask_shape
|
356 |
+
|
357 |
+
@property
|
358 |
+
def aggregation_mask_shape(self):
|
359 |
+
return (
|
360 |
+
1,
|
361 |
+
self.inp_shape[-2] // self.aggregation_patch_size[-2],
|
362 |
+
self.inp_shape[-1] // self.aggregation_patch_size[-1]
|
363 |
+
)
|
364 |
+
|
365 |
+
@property
|
366 |
+
def a_mask_shape(self):
|
367 |
+
return self.aggregation_mask_shape
|
368 |
+
|
369 |
+
def get_perturbation_input(self, x):
|
370 |
+
self.set_input(x)
|
371 |
+
y = torch.zeros((self.B, *self.p_mask_shape), dtype=x.dtype, device=x.device, requires_grad=True)
|
372 |
+
y = y.unsqueeze(2).repeat(1, 1, x.shape[2], 1, 1)
|
373 |
+
return y
|
374 |
+
|
375 |
+
def pred_patches_to_video(self, y, x, mask):
|
376 |
+
"""input at visible positions, preds at masked positions"""
|
377 |
+
B, C = y.shape[0], y.shape[-1]
|
378 |
+
self.patchify._check_shape(x)
|
379 |
+
self.patchify.D = np.prod(self.patch_size)
|
380 |
+
x = self.patchify(x)
|
381 |
+
y_out = torch.zeros_like(x)
|
382 |
+
x_vis = x[~mask]
|
383 |
+
|
384 |
+
y_out[~mask] = x_vis.view(-1, C)
|
385 |
+
try:
|
386 |
+
y_out[mask] = y.view(-1, C)
|
387 |
+
except:
|
388 |
+
y_out[mask] = y.reshape(-1, C)
|
389 |
+
|
390 |
+
return self.patchify(y_out, to_video=True)
|
391 |
+
|
392 |
+
def set_image_size(self, *args, **kwargs):
|
393 |
+
assert self.predictor is not None, "Can't set the image size without a predictor"
|
394 |
+
if hasattr(self.predictor, 'set_image_size'):
|
395 |
+
self.predictor.set_image_size(*args, **kwargs)
|
396 |
+
else:
|
397 |
+
self.predictor.image_size = args[0]
|
398 |
+
|
399 |
+
def predict(self, x=None, mask=None, forward_full=False):
|
400 |
+
if x is None:
|
401 |
+
x = self.x
|
402 |
+
if mask is None:
|
403 |
+
mask = self.generate_mask(x)
|
404 |
+
|
405 |
+
self.set_image_size(x.shape[-2:])
|
406 |
+
y = self.predictor(
|
407 |
+
self._preprocess(x),
|
408 |
+
mask if (x.size(0) == 1) else self.mask_rectangularizer(mask), forward_full=forward_full)
|
409 |
+
|
410 |
+
y = self.pred_patches_to_video(y, x, mask=mask)
|
411 |
+
|
412 |
+
frame = -1 % y.size(1)
|
413 |
+
y = y[:, frame:frame + 1]
|
414 |
+
|
415 |
+
return y
|
416 |
+
|
417 |
+
def _get_perturbation_func(self, x=None, mask=None):
|
418 |
+
|
419 |
+
if (x is not None):
|
420 |
+
self.set_input(x, mask)
|
421 |
+
|
422 |
+
def forward_mini_image(y):
|
423 |
+
y = y.repeat_interleave(self.perturbation_patch_size[-2], -2)
|
424 |
+
y = y.repeat_interleave(self.perturbation_patch_size[-1], -1)
|
425 |
+
x_pred = self.predict(self.x + y, self.mask)
|
426 |
+
x_pred = self.agg_patchify(x_pred).mean(-2).sum(-1).view(self.B, *self.a_mask_shape)
|
427 |
+
return x_pred[self.targets]
|
428 |
+
|
429 |
+
return forward_mini_image
|
430 |
+
|
431 |
+
def _postprocess_jacobian(self, jac):
|
432 |
+
_jac = torch.zeros((self.B, *self.a_mask_shape, *jac.shape[1:])).to(jac.device).to(jac.dtype)
|
433 |
+
_jac[self.targets] = jac
|
434 |
+
jac = self.agg_channel_func(_jac)
|
435 |
+
assert jac.size(-3) == 1, jac.shape
|
436 |
+
jac = jac.squeeze(-3)[..., 0, :, :] # derivative w.r.t. first frame and agg channels
|
437 |
+
jac = jac.view(self.B, self.a_mask_shape[-2], self.a_mask_shape[-1],
|
438 |
+
self.B, self.p_mask_shape[-2], self.p_mask_shape[-1])
|
439 |
+
bs = torch.arange(0, self.B).long().to(jac.device)
|
440 |
+
jac = jac[bs, :, :, bs, :, :] # take diagonal
|
441 |
+
return jac
|
442 |
+
|
443 |
+
def _confident_jacobian(self, jac):
|
444 |
+
if self.confidence_thresh is None:
|
445 |
+
return torch.ones_like(jac[:, None, ..., 0, 0])
|
446 |
+
conf = (jac.amax((-2, -1)) > self.confidence_thresh).float()[:, None]
|
447 |
+
return conf
|
448 |
+
|
449 |
+
def set_input(self, x, mask=None, timestamps=None):
|
450 |
+
shape = x.shape
|
451 |
+
if len(shape) == 4:
|
452 |
+
x = x.unsqueeze(1)
|
453 |
+
else:
|
454 |
+
assert len(shape) == 5, \
|
455 |
+
"Input must be a movie of shape [B,T,C,H,W]" + \
|
456 |
+
"or a single frame of shape [B,C,H,W]"
|
457 |
+
|
458 |
+
self.inp_shape = x.shape
|
459 |
+
self.x = x
|
460 |
+
self.B = self.inp_shape[0]
|
461 |
+
self.T = self.inp_shape[1]
|
462 |
+
self.C = self.inp_shape[2]
|
463 |
+
if mask is not None:
|
464 |
+
self.mask = mask
|
465 |
+
|
466 |
+
if timestamps is not None:
|
467 |
+
self.timestamps = timestamps
|
468 |
+
|
469 |
+
def _preprocess(self, x):
|
470 |
+
if self.imagenet_normalize_inputs:
|
471 |
+
x = imagenet_normalize(x)
|
472 |
+
if self.t_dim != 1:
|
473 |
+
x = x.transpose(self.t_dim, self.c_dim)
|
474 |
+
return x
|
475 |
+
|
476 |
+
def _jacobian_to_flows(self, jac):
|
477 |
+
if self.agg_power is None:
|
478 |
+
jac = (jac == jac.amax((-2, -1), True)).float()
|
479 |
+
else:
|
480 |
+
jac = torch.pow(jac, self.agg_power)
|
481 |
+
|
482 |
+
jac = jac.view(self.B * np.prod(self.a_mask_shape[-2:]), 1, 1, *self.p_mask_shape[-2:])
|
483 |
+
centroids = get_distribution_centroid(jac, normalize=False).view(
|
484 |
+
self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], 2)
|
485 |
+
rescale = [self.a_mask_shape[-2] / self.p_mask_shape[-2],
|
486 |
+
self.a_mask_shape[-1] / self.p_mask_shape[-1]]
|
487 |
+
centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2)
|
488 |
+
|
489 |
+
flows = centroids - \
|
490 |
+
coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device)
|
491 |
+
flows = flows.permute(0, 3, 1, 2)
|
492 |
+
px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1)
|
493 |
+
flows *= px_scale
|
494 |
+
|
495 |
+
return flows
|
496 |
+
|
497 |
+
def set_targets(self, targets=None, frame=-1):
|
498 |
+
frame = frame % self.mask_shape[0]
|
499 |
+
if targets is None:
|
500 |
+
targets = self.get_mask_image(self.mask)[:, frame:frame + 1]
|
501 |
+
else:
|
502 |
+
assert len(targets.shape) == 4, targets.shape
|
503 |
+
targets = targets[:, frame:frame + 1]
|
504 |
+
self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:])
|
505 |
+
|
506 |
+
def _get_mask_partition(self, mask):
|
507 |
+
mask = self.get_mask_image(mask)
|
508 |
+
mask_list = masking.partition_masks(
|
509 |
+
mask[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling)
|
510 |
+
return [torch.cat([mask[:, 0:1].view(m.size(0), -1), m], -1)
|
511 |
+
for m in mask_list]
|
512 |
+
|
513 |
+
def _compute_jacobian(self, y):
|
514 |
+
perturbation_func = self._get_perturbation_func()
|
515 |
+
jac = torch.autograd.functional.jacobian(
|
516 |
+
perturbation_func,
|
517 |
+
y,
|
518 |
+
vectorize=False)
|
519 |
+
jac = self._postprocess_jacobian(jac)
|
520 |
+
return jac
|
521 |
+
|
522 |
+
def _upsample_mask(self, mask):
|
523 |
+
return masking.upsample_masks(
|
524 |
+
mask.view(mask.size(0), -1, *self.mask_shape[-2:]).float(), self.inp_shape[-2:])
|
525 |
+
|
526 |
+
def get_mask_image(self, mask, upsample=False, invert=False, shape=None):
|
527 |
+
if shape is None:
|
528 |
+
shape = self.mask_shape
|
529 |
+
mask = mask.view(-1, *shape)
|
530 |
+
if upsample:
|
531 |
+
mask = self._upsample_mask(mask)
|
532 |
+
if invert:
|
533 |
+
mask = 1 - mask
|
534 |
+
return mask
|
535 |
+
|
536 |
+
def forward(self, x, mask, targets=None):
|
537 |
+
self.set_input(x, mask)
|
538 |
+
y = self.get_perturbation_input(x)
|
539 |
+
mask_list = self._get_mask_partition(mask)
|
540 |
+
|
541 |
+
jacobian, flows, confident = [], [], []
|
542 |
+
for s, mask_sample in enumerate(mask_list):
|
543 |
+
self.set_input(x, mask_sample)
|
544 |
+
self.set_targets(targets)
|
545 |
+
|
546 |
+
import time
|
547 |
+
t1 = time.time()
|
548 |
+
jac = self._compute_jacobian(y)
|
549 |
+
conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:])
|
550 |
+
jacobian.append(jac)
|
551 |
+
confident.append(conf_jac)
|
552 |
+
if not self.average_jacobian:
|
553 |
+
flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \
|
554 |
+
masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:])
|
555 |
+
flows.append(flow)
|
556 |
+
t2 = time.time()
|
557 |
+
print(t2 - t1)
|
558 |
+
|
559 |
+
jacobian = torch.stack(jacobian, -1)
|
560 |
+
confident = torch.stack(confident, -1)
|
561 |
+
valid = torch.stack([masking.upsample_masks(
|
562 |
+
self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1)
|
563 |
+
valid = valid * confident
|
564 |
+
|
565 |
+
if self.average_jacobian:
|
566 |
+
_valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2)
|
567 |
+
jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1)
|
568 |
+
flows = self._jacobian_to_flows(jac) * \
|
569 |
+
masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:])
|
570 |
+
if targets is not None:
|
571 |
+
self.set_targets(targets)
|
572 |
+
flows *= self.targets
|
573 |
+
else:
|
574 |
+
flows = torch.stack(flows, -1)
|
575 |
+
flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1)
|
576 |
+
|
577 |
+
valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1)
|
578 |
+
|
579 |
+
return (jacobian, flows, valid)
|
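The `_jacobian_to_flows` step above turns each target patch's jacobian response map into a displacement: take the centroid of the response, subtract the patch's own grid coordinate, and rescale to pixels. A minimal, self-contained sketch of that idea (illustrative names only; the repo's `get_distribution_centroid` and `coordinate_ims` helpers are defined elsewhere and not shown in this diff):

import torch

def responses_to_flow(responses, patch_size=8):
    # responses: [B, H, W, h, w] -- for each target patch on an (H, W) grid,
    # a response map over source patches on an (h, w) grid.
    B, H, W, h, w = responses.shape
    resp = responses.flatten(-2, -1).softmax(-1).view(B, H, W, h, w)

    ys = torch.arange(h, dtype=resp.dtype).view(1, 1, 1, h, 1)
    xs = torch.arange(w, dtype=resp.dtype).view(1, 1, 1, 1, w)
    cy = (resp * ys).sum((-2, -1))   # soft-argmax row of each response map
    cx = (resp * xs).sum((-2, -1))   # soft-argmax column of each response map

    gy = torch.arange(H, dtype=resp.dtype).view(1, H, 1)
    gx = torch.arange(W, dtype=resp.dtype).view(1, 1, W)
    # displacement of the response centroid from the patch's own location,
    # with the source grid rescaled to the target grid resolution
    flow_y = cy * (H / h) - gy
    flow_x = cx * (W / w) - gx
    return torch.stack([flow_x, flow_y], 1) * patch_size   # [B, 2, H, W], in pixels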
cwm/eval/Flow/losses.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torchvision import transforms
|
4 |
+
|
5 |
+
|
6 |
+
def sampling_grid(height, width):
|
7 |
+
H, W = height, width
|
8 |
+
grid = torch.stack([
|
9 |
+
torch.arange(W).view(1, -1).repeat(H, 1),
|
10 |
+
torch.arange(H).view(-1, 1).repeat(1, W)
|
11 |
+
], -1)
|
12 |
+
grid = grid.view(1, H, W, 2)
|
13 |
+
return grid
|
14 |
+
|
15 |
+
|
16 |
+
def normalize_sampling_grid(coords):
|
17 |
+
assert len(coords.shape) == 4, coords.shape
|
18 |
+
assert coords.size(-1) == 2, coords.shape
|
19 |
+
H, W = coords.shape[-3:-1]
|
20 |
+
xs, ys = coords.split([1, 1], -1)
|
21 |
+
xs = 2 * xs / (W - 1) - 1
|
22 |
+
ys = 2 * ys / (H - 1) - 1
|
23 |
+
return torch.cat([xs, ys], -1)
|
24 |
+
|
25 |
+
|
26 |
+
def backward_warp(img2, flow, do_mask=False):
|
27 |
+
"""
|
28 |
+
Grid sample from img2 using the flow from img1->img2 to get a prediction of img1.
|
29 |
+
|
30 |
+
flow: [B,2,H',W'] in units of pixels at its current resolution. The two channels
|
31 |
+
should be (x,y) where larger y values correspond to lower parts of the image.
|
32 |
+
"""
|
33 |
+
|
34 |
+
## resize the flow to the image size.
|
35 |
+
## since flow has units of pixels, its values need to be rescaled accordingly.
|
36 |
+
if list(img2.shape[-2:]) != list(flow.shape[-2:]):
|
37 |
+
scale = [img2.size(-1) / flow.size(-1), # x
|
38 |
+
img2.size(-2) / flow.size(-2)] # y
|
39 |
+
scale = torch.tensor(scale).view(1, 2, 1, 1).to(flow.device)
|
40 |
+
flow = scale * transforms.Resize(img2.shape[-2:])(flow) # defaults to bilinear
|
41 |
+
|
42 |
+
B, C, H, W = img2.shape
|
43 |
+
|
44 |
+
## use flow to warp sampling grid
|
45 |
+
grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1)
|
46 |
+
|
47 |
+
## put grid in normalized image coordinates
|
48 |
+
grid = normalize_sampling_grid(grid)
|
49 |
+
|
50 |
+
## backward warp, i.e. sample pixel (x,y) from (x+flow_x, y+flow_y)
|
51 |
+
img1_pred = F.grid_sample(img2, grid, align_corners=True)
|
52 |
+
|
53 |
+
if do_mask:
|
54 |
+
mask = (grid[..., 0] > -1) & (grid[..., 0] < 1) & \
|
55 |
+
(grid[..., 1] > -1) & (grid[..., 1] < 1)
|
56 |
+
mask = mask[:, None].to(img2.dtype)
|
57 |
+
return (img1_pred, mask)
|
58 |
+
|
59 |
+
else:
|
60 |
+
return (img1_pred, torch.ones_like(grid[..., 0][:, None]).float())
|
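A small usage sketch for `backward_warp` (the import path is assumed from this diff; tensors are dummies): with zero flow, backward warping should reproduce `img2` exactly, which is a quick sanity check.

import torch
from cwm.eval.Flow.losses import backward_warp   # path assumed from this diff

img2 = torch.rand(1, 3, 64, 64)            # the later frame
flow = torch.zeros(1, 2, 64, 64)           # flow img1 -> img2, in pixels (x, y)
img1_pred, valid = backward_warp(img2, flow, do_mask=True)
assert torch.allclose(img1_pred, img2, atol=1e-5)   # zero flow just resamples img2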
cwm/eval/Flow/masking_flow.py
ADDED
@@ -0,0 +1,375 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
|
7 |
+
def upsample_masks(masks, size, thresh=0.5):
|
8 |
+
shape = masks.shape
|
9 |
+
dtype = masks.dtype
|
10 |
+
h, w = shape[-2:]
|
11 |
+
H, W = size
|
12 |
+
if (H == h) and (W == w):
|
13 |
+
return masks
|
14 |
+
elif (H < h) and (W < w):
|
15 |
+
s = (h // H, w // W)
|
16 |
+
return masks[..., ::s[0], ::s[1]]
|
17 |
+
|
18 |
+
masks = masks.unsqueeze(-2).unsqueeze(-1)
|
19 |
+
masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
|
20 |
+
if ((H % h) == 0) and ((W % w) == 0):
|
21 |
+
masks = masks.view(*shape[:-2], H, W)
|
22 |
+
else:
|
23 |
+
_H = np.prod(masks.shape[-4:-2])
|
24 |
+
_W = np.prod(masks.shape[-2:])
|
25 |
+
masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh
|
26 |
+
masks = masks.view(*shape[:2], H, W).to(dtype)  # cast back to the original dtype saved above
|
27 |
+
return masks
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def partition_masks(masks, num_samples=2, leave_one_out=False):
|
33 |
+
B = masks.shape[0]
|
34 |
+
S = num_samples
|
35 |
+
masks = masks.view(B, -1)
|
36 |
+
partitioned = [torch.ones_like(masks) for _ in range(S)]
|
37 |
+
for b in range(B):
|
38 |
+
vis_inds = torch.where(~masks[b])[0]
|
39 |
+
vis_inds = vis_inds[torch.randperm(vis_inds.size(0))]
|
40 |
+
if leave_one_out:
|
41 |
+
for s in range(S):
|
42 |
+
partitioned[s][b][vis_inds] = 0
|
43 |
+
partitioned[s][b][vis_inds[s::S]] = 1
|
44 |
+
else:
|
45 |
+
for s in range(S):
|
46 |
+
partitioned[s][b][vis_inds[s::S]] = 0
|
47 |
+
return partitioned
|
48 |
+
|
49 |
+
|
50 |
+
class RectangularizeMasks(nn.Module):
|
51 |
+
"""Make sure all masks in a batch have same number of 1s and 0s"""
|
52 |
+
|
53 |
+
def __init__(self, truncation_mode='min'):
|
54 |
+
super().__init__()
|
55 |
+
self._mode = truncation_mode
|
56 |
+
assert self._mode in ['min', 'max', 'mean', 'full', 'none', None], (self._mode)
|
57 |
+
|
58 |
+
def set_mode(self, mode):
|
59 |
+
self._mode = mode
|
60 |
+
|
61 |
+
def __call__(self, masks):
|
62 |
+
|
63 |
+
if self._mode in ['none', None]:
|
64 |
+
return masks
|
65 |
+
|
66 |
+
assert isinstance(masks, torch.Tensor), type(masks)
|
67 |
+
if self._mode == 'full':
|
68 |
+
return torch.ones_like(masks)
|
69 |
+
|
70 |
+
shape = masks.shape
|
71 |
+
masks = masks.flatten(1)
|
72 |
+
B, N = masks.shape
|
73 |
+
num_masked = masks.float().sum(-1)
|
74 |
+
M = {
|
75 |
+
'min': torch.amin, 'max': torch.amax, 'mean': torch.mean
|
76 |
+
}[self._mode](num_masked).long()
|
77 |
+
|
78 |
+
num_changes = num_masked.long() - M
|
79 |
+
|
80 |
+
for b in range(B):
|
81 |
+
nc = num_changes[b]
|
82 |
+
if nc > 0:
|
83 |
+
inds = torch.where(masks[b])[0]
|
84 |
+
inds = inds[torch.randperm(inds.size(0))[:nc].to(inds.device)]
|
85 |
+
masks[b, inds] = 0
|
86 |
+
elif nc < 0:
|
87 |
+
inds = torch.where(~masks[b])[0]
|
88 |
+
inds = inds[torch.randperm(inds.size(0))[:-nc].to(inds.device)]
|
89 |
+
masks[b, inds] = 1
|
90 |
+
if list(masks.shape) != list(shape):
|
91 |
+
masks = masks.view(*shape)
|
92 |
+
|
93 |
+
return masks
|
94 |
+
|
95 |
+
|
96 |
+
class UniformMaskingGenerator(object):
|
97 |
+
def __init__(self, input_size, mask_ratio, seed=None, clumping_factor=1, randomize_num_visible=False):
|
98 |
+
self.frames = None
|
99 |
+
if len(input_size) == 3:
|
100 |
+
self.frames, self.height, self.width = input_size
|
101 |
+
elif len(input_size) == 2:
|
102 |
+
self.height, self.width = input_size
|
103 |
+
elif len(input_size) == 1 or isinstance(input_size, int):
|
104 |
+
self.height = self.width = input_size
|
105 |
+
|
106 |
+
self.clumping_factor = clumping_factor
|
107 |
+
self.pad_h = self.height % self.c[0]
|
108 |
+
self.pad_w = self.width % self.c[1]
|
109 |
+
self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
|
110 |
+
self.mask_ratio = mask_ratio
|
111 |
+
|
112 |
+
self.rng = np.random.RandomState(seed=seed)
|
113 |
+
self.randomize_num_visible = randomize_num_visible
|
114 |
+
|
115 |
+
@property
|
116 |
+
def num_masks_per_frame(self):
|
117 |
+
if not hasattr(self, '_num_masks_per_frame'):
|
118 |
+
self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
|
119 |
+
return self._num_masks_per_frame
|
120 |
+
|
121 |
+
@num_masks_per_frame.setter
|
122 |
+
def num_masks_per_frame(self, val):
|
123 |
+
self._num_masks_per_frame = val
|
124 |
+
self._mask_ratio = (val / self.num_patches_per_frame)
|
125 |
+
|
126 |
+
@property
|
127 |
+
def c(self):
|
128 |
+
if isinstance(self.clumping_factor, int):
|
129 |
+
return (self.clumping_factor, self.clumping_factor)
|
130 |
+
else:
|
131 |
+
return self.clumping_factor[:2]
|
132 |
+
|
133 |
+
@property
|
134 |
+
def mask_ratio(self):
|
135 |
+
return self._mask_ratio
|
136 |
+
|
137 |
+
@mask_ratio.setter
|
138 |
+
def mask_ratio(self, val):
|
139 |
+
self._mask_ratio = val
|
140 |
+
self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)
|
141 |
+
|
142 |
+
@property
|
143 |
+
def num_visible(self):
|
144 |
+
return self.num_patches_per_frame - self.num_masks_per_frame
|
145 |
+
|
146 |
+
@num_visible.setter
|
147 |
+
def num_visible(self, val):
|
148 |
+
self.num_masks_per_frame = self.num_patches_per_frame - val
|
149 |
+
|
150 |
+
def __repr__(self):
|
151 |
+
repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, random num num visible? {}".format(
|
152 |
+
self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible
|
153 |
+
)
|
154 |
+
return repr_str
|
155 |
+
|
156 |
+
def sample_mask_per_frame(self):
|
157 |
+
num_masks = self.num_masks_per_frame
|
158 |
+
if self.randomize_num_visible:
|
159 |
+
num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
|
160 |
+
mask = np.hstack([
|
161 |
+
np.zeros(self.num_patches_per_frame - num_masks),
|
162 |
+
np.ones(num_masks)])
|
163 |
+
self.rng.shuffle(mask)
|
164 |
+
if max(*self.c) > 1:
|
165 |
+
mask = mask.reshape(self.height // self.c[0],
|
166 |
+
1,
|
167 |
+
self.width // self.c[1],
|
168 |
+
1)
|
169 |
+
mask = np.tile(mask, (1, self.c[0], 1, self.c[1]))
|
170 |
+
mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w))
|
171 |
+
_pad_h = self.rng.choice(range(self.pad_h + 1))
|
172 |
+
pad_h = (self.pad_h - _pad_h, _pad_h)
|
173 |
+
_pad_w = self.rng.choice(range(self.pad_w + 1))
|
174 |
+
pad_w = (self.pad_w - _pad_w, _pad_w)
|
175 |
+
mask = np.pad(mask,
|
176 |
+
(pad_h, pad_w),
|
177 |
+
constant_values=1
|
178 |
+
).reshape((self.height, self.width))
|
179 |
+
return mask
|
180 |
+
|
181 |
+
def __call__(self, num_frames=None):
|
182 |
+
num_frames = (num_frames or self.frames) or 1
|
183 |
+
masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten()
|
184 |
+
return masks
|
185 |
+
|
186 |
+
|
187 |
+
class TubeMaskingGenerator(UniformMaskingGenerator):
|
188 |
+
|
189 |
+
def __call__(self, num_frames=None):
|
190 |
+
num_frames = (num_frames or self.frames) or 1
|
191 |
+
masks = np.tile(self.sample_mask_per_frame(), (num_frames, 1)).flatten()
|
192 |
+
return masks
|
193 |
+
|
194 |
+
|
195 |
+
class RotatedTableMaskingGenerator(TubeMaskingGenerator):
|
196 |
+
|
197 |
+
def __init__(self, tube_length=None, *args, **kwargs):
|
198 |
+
super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs)
|
199 |
+
self.tube_length = tube_length
|
200 |
+
|
201 |
+
def __call__(self, num_frames=None):
|
202 |
+
num_frames = (num_frames or self.frames) or 2
|
203 |
+
tube_length = self.tube_length or (num_frames - 1)
|
204 |
+
table_thickness = num_frames - tube_length
|
205 |
+
assert tube_length < num_frames, (tube_length, num_frames)
|
206 |
+
|
207 |
+
tubes = super().__call__(num_frames=tube_length)
|
208 |
+
top = np.zeros(table_thickness * self.height * self.width).astype(tubes.dtype).flatten()
|
209 |
+
masks = np.concatenate([top, tubes], 0)
|
210 |
+
return masks
|
211 |
+
|
212 |
+
|
213 |
+
class PytorchMaskGeneratorWrapper(nn.Module):
|
214 |
+
"""Pytorch wrapper for numpy masking generators"""
|
215 |
+
|
216 |
+
def __init__(self,
|
217 |
+
mask_generator=TubeMaskingGenerator,
|
218 |
+
*args, **kwargs):
|
219 |
+
super().__init__()
|
220 |
+
self.mask_generator = mask_generator(*args, **kwargs)
|
221 |
+
|
222 |
+
@property
|
223 |
+
def mask_ratio(self):
|
224 |
+
return self.mask_generator.mask_ratio
|
225 |
+
|
226 |
+
@mask_ratio.setter
|
227 |
+
def mask_ratio(self, value):
|
228 |
+
self.mask_generator.mask_ratio = value
|
229 |
+
|
230 |
+
def forward(self, device='cuda', dtype_out=torch.bool, **kwargs):
|
231 |
+
masks = self.mask_generator(**kwargs)
|
232 |
+
masks = torch.tensor(masks).to(device).to(dtype_out)
|
233 |
+
return masks
|
234 |
+
|
235 |
+
|
236 |
+
class MaskingGenerator(nn.Module):
|
237 |
+
"""Pytorch base class for masking generators"""
|
238 |
+
|
239 |
+
def __init__(self,
|
240 |
+
input_size,
|
241 |
+
mask_ratio,
|
242 |
+
seed=0,
|
243 |
+
visible_frames=0,
|
244 |
+
clumping_factor=1,
|
245 |
+
randomize_num_visible=False,
|
246 |
+
create_on_cpu=True,
|
247 |
+
always_batch=False):
|
248 |
+
super().__init__()
|
249 |
+
self.frames = None
|
250 |
+
|
251 |
+
if len(input_size) == 3:
|
252 |
+
self.frames, self.height, self.width = input_size
|
253 |
+
elif len(input_size) == 2:
|
254 |
+
self.height, self.width = input_size
|
255 |
+
elif len(input_size) == 1 or isinstance(input_size, int):
|
256 |
+
self.height = self.width = input_size
|
257 |
+
|
258 |
+
self.clumping_factor = clumping_factor
|
259 |
+
self.pad_h = self.height % self.c[0]
|
260 |
+
self.pad_w = self.width % self.c[1]
|
261 |
+
self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
|
262 |
+
|
263 |
+
self.mask_ratio = mask_ratio
|
264 |
+
self.visible_frames = visible_frames
|
265 |
+
self.always_batch = always_batch
|
266 |
+
self.create_on_cpu = create_on_cpu
|
267 |
+
|
268 |
+
self.rng = np.random.RandomState(seed=seed)
|
269 |
+
self._set_torch_seed(seed)
|
270 |
+
|
271 |
+
self.randomize_num_visible = randomize_num_visible
|
272 |
+
|
273 |
+
@property
|
274 |
+
def num_masks_per_frame(self):
|
275 |
+
if not hasattr(self, '_num_masks_per_frame'):
|
276 |
+
self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
|
277 |
+
return self._num_masks_per_frame
|
278 |
+
|
279 |
+
@num_masks_per_frame.setter
|
280 |
+
def num_masks_per_frame(self, val):
|
281 |
+
self._num_masks_per_frame = val
|
282 |
+
self._mask_ratio = (val / self.num_patches_per_frame)
|
283 |
+
|
284 |
+
@property
|
285 |
+
def c(self):
|
286 |
+
if isinstance(self.clumping_factor, int):
|
287 |
+
return (self.clumping_factor,) * 2
|
288 |
+
else:
|
289 |
+
return self.clumping_factor[:2]
|
290 |
+
|
291 |
+
@property
|
292 |
+
def mask_ratio(self):
|
293 |
+
return self._mask_ratio
|
294 |
+
|
295 |
+
@mask_ratio.setter
|
296 |
+
def mask_ratio(self, val):
|
297 |
+
self._mask_ratio = val
|
298 |
+
self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)
|
299 |
+
|
300 |
+
@property
|
301 |
+
def num_visible(self):
|
302 |
+
return self.num_patches_per_frame - self.num_masks_per_frame
|
303 |
+
|
304 |
+
@num_visible.setter
|
305 |
+
def num_visible(self, val):
|
306 |
+
self.num_masks_per_frame = self.num_patches_per_frame - val
|
307 |
+
|
308 |
+
def _set_torch_seed(self, seed):
|
309 |
+
self.seed = seed
|
310 |
+
torch.manual_seed(self.seed)
|
311 |
+
|
312 |
+
def __repr__(self):
|
313 |
+
repr_str = ("Class: {}\nMask: total patches per mask {},\n" + \
|
314 |
+
"mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n" + \
|
315 |
+
"randomize num visible? {}").format(
|
316 |
+
type(self).__name__, self.num_patches_per_frame,
|
317 |
+
self.num_masks_per_frame, self.num_visible, self.mask_ratio,
|
318 |
+
self.randomize_num_visible
|
319 |
+
)
|
320 |
+
return repr_str
|
321 |
+
|
322 |
+
def sample_mask_per_frame(self, *args, **kwargs):
|
323 |
+
num_masks = self.num_masks_per_frame
|
324 |
+
if self.randomize_num_visible:
|
325 |
+
num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
|
326 |
+
|
327 |
+
mask = torch.cat([
|
328 |
+
torch.zeros([self.num_patches_per_frame - num_masks]),
|
329 |
+
torch.ones([num_masks])], 0).bool()
|
330 |
+
inds = torch.randperm(mask.size(0)).long()
|
331 |
+
mask = mask[inds]
|
332 |
+
|
333 |
+
if max(*self.c) > 1:
|
334 |
+
mask = mask.view(self.height // self.c[0],
|
335 |
+
1,
|
336 |
+
self.width // self.c[1],
|
337 |
+
1)
|
338 |
+
mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
|
339 |
+
mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
|
340 |
+
_pad_h = self.rng.choice(range(self.pad_h + 1))
|
341 |
+
pad_h = (self.pad_h - _pad_h, _pad_h)
|
342 |
+
_pad_w = self.rng.choice(range(self.pad_w + 1))
|
343 |
+
pad_w = (self.pad_w - _pad_w, _pad_w)
|
344 |
+
mask = F.pad(mask,
|
345 |
+
pad_w + pad_h,
|
346 |
+
mode='constant',
|
347 |
+
value=1)
|
348 |
+
mask = mask.reshape(self.height, self.width)
|
349 |
+
|
350 |
+
return mask
|
351 |
+
|
352 |
+
def forward(self, x=None, num_frames=None):
|
353 |
+
|
354 |
+
num_frames = (num_frames or self.frames) or 1
|
355 |
+
if isinstance(x, torch.Tensor):
|
356 |
+
batch_size = x.size(0)
|
357 |
+
masks = torch.stack([
|
358 |
+
torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
|
359 |
+
for b in range(batch_size)], 0)
|
360 |
+
if not self.create_on_cpu:
|
361 |
+
masks = masks.to(x.device)
|
362 |
+
if batch_size == 1 and not self.always_batch:
|
363 |
+
masks = masks.squeeze(0)
|
364 |
+
else:
|
365 |
+
batch_size = 1
|
366 |
+
masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
|
367 |
+
if self.always_batch:
|
368 |
+
masks = masks[None]
|
369 |
+
|
370 |
+
if self.visible_frames > 0:
|
371 |
+
vis = torch.zeros((batch_size, 1, self.height, self.width), dtype=torch.bool)
|
372 |
+
vis = vis.view(masks.shape).to(masks.device)
|
373 |
+
masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)
|
374 |
+
|
375 |
+
return masks
|
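A short sketch of how these generators can be driven (the 28x28 patch grid, 2 frames, and 0.9 ratio are illustrative, not prescribed by this file): `RotatedTableMaskingGenerator` leaves the first `num_frames - tube_length` frames fully visible and tube-masks the rest, and `upsample_masks` lifts a patch-level mask to pixel resolution.

import torch
from cwm.eval.Flow.masking_flow import RotatedTableMaskingGenerator, upsample_masks

gen = RotatedTableMaskingGenerator(tube_length=1, input_size=(2, 28, 28), mask_ratio=0.9)
mask = torch.tensor(gen(num_frames=2)).bool()    # flat [2 * 28 * 28] mask, True = masked
print(mask.shape, mask.float().mean())           # frame 0 fully visible, frame 1 ~90% masked

# patch-level mask -> pixel-level mask, e.g. for 8x8 patches on a 224x224 frame
pixel_mask = upsample_masks(mask.view(2, 28, 28).float(), (224, 224))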
cwm/eval/Flow/vis_utils.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def imshow(ims, ax=None, t=0, vmin=None, vmax=None, title=None, cmap=None, fontsize=20):
|
7 |
+
if ax is None:
|
8 |
+
fig, ax = plt.subplots(1,1)
|
9 |
+
with torch.no_grad():
|
10 |
+
im = ims[t].float().cpu().numpy().transpose((1,2,0))
|
11 |
+
if (vmin is not None) and (vmax is not None):
|
12 |
+
im = ax.imshow(im, vmin=vmin, vmax=vmax, cmap=(cmap or 'viridis'))
|
13 |
+
else:
|
14 |
+
im = ax.imshow(im)
|
15 |
+
|
16 |
+
if title is not None:
|
17 |
+
ax.set_title(title, fontsize=fontsize)
|
18 |
+
|
19 |
+
return (im, ax)
|
20 |
+
|
21 |
+
def make_colorwheel():
|
22 |
+
"""
|
23 |
+
Generates a color wheel for optical flow visualization as presented in:
|
24 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
25 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
26 |
+
|
27 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
28 |
+
Code follows the Matlab source code of Deqing Sun.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
np.ndarray: Color wheel
|
32 |
+
"""
|
33 |
+
|
34 |
+
RY = 15
|
35 |
+
YG = 6
|
36 |
+
GC = 4
|
37 |
+
CB = 11
|
38 |
+
BM = 13
|
39 |
+
MR = 6
|
40 |
+
|
41 |
+
ncols = RY + YG + GC + CB + BM + MR
|
42 |
+
colorwheel = np.zeros((ncols, 3))
|
43 |
+
col = 0
|
44 |
+
|
45 |
+
# RY
|
46 |
+
colorwheel[0:RY, 0] = 255
|
47 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
48 |
+
col = col+RY
|
49 |
+
# YG
|
50 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
51 |
+
colorwheel[col:col+YG, 1] = 255
|
52 |
+
col = col+YG
|
53 |
+
# GC
|
54 |
+
colorwheel[col:col+GC, 1] = 255
|
55 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
56 |
+
col = col+GC
|
57 |
+
# CB
|
58 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
59 |
+
colorwheel[col:col+CB, 2] = 255
|
60 |
+
col = col+CB
|
61 |
+
# BM
|
62 |
+
colorwheel[col:col+BM, 2] = 255
|
63 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
64 |
+
col = col+BM
|
65 |
+
# MR
|
66 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
67 |
+
colorwheel[col:col+MR, 0] = 255
|
68 |
+
return colorwheel
|
69 |
+
|
70 |
+
|
71 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
72 |
+
"""
|
73 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
74 |
+
|
75 |
+
According to the C++ source code of Daniel Scharstein
|
76 |
+
According to the Matlab source code of Deqing Sun
|
77 |
+
|
78 |
+
Args:
|
79 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
80 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
81 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
85 |
+
"""
|
86 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
87 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
88 |
+
ncols = colorwheel.shape[0]
|
89 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
90 |
+
a = np.arctan2(-v, -u)/np.pi
|
91 |
+
fk = (a+1) / 2*(ncols-1)
|
92 |
+
k0 = np.floor(fk).astype(np.int32)
|
93 |
+
k1 = k0 + 1
|
94 |
+
k1[k1 == ncols] = 0
|
95 |
+
f = fk - k0
|
96 |
+
for i in range(colorwheel.shape[1]):
|
97 |
+
tmp = colorwheel[:,i]
|
98 |
+
col0 = tmp[k0] / 255.0
|
99 |
+
col1 = tmp[k1] / 255.0
|
100 |
+
col = (1-f)*col0 + f*col1
|
101 |
+
idx = (rad <= 1)
|
102 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
103 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
104 |
+
# Note the 2-i => BGR instead of RGB
|
105 |
+
ch_idx = 2-i if convert_to_bgr else i
|
106 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
107 |
+
return flow_image
|
108 |
+
|
109 |
+
|
110 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
111 |
+
"""
|
112 |
+
Expects a two dimensional flow image of shape [H,W,2].
|
113 |
+
|
114 |
+
Args:
|
115 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
116 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
117 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
121 |
+
"""
|
122 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
123 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
124 |
+
if clip_flow is not None:
|
125 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
126 |
+
u = flow_uv[:,:,0]
|
127 |
+
v = flow_uv[:,:,1]
|
128 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
129 |
+
rad_max = np.max(rad)
|
130 |
+
epsilon = 1e-5
|
131 |
+
u = u / (rad_max + epsilon)
|
132 |
+
v = v / (rad_max + epsilon)
|
133 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
134 |
+
|
135 |
+
from decord import VideoReader, cpu
|
136 |
+
from PIL import Image
|
137 |
+
from torchvision import transforms
|
138 |
+
def get_video(video_name, num_frames=2, delta_time=4, frame=None):
|
139 |
+
decord_vr = VideoReader(video_name, num_threads=1, ctx=cpu(0))
|
140 |
+
max_end_ind = len(decord_vr) - num_frames*delta_time - 1
|
141 |
+
start_frame = frame if frame is not None else np.random.randint(1, max_end_ind)  # random start frame if none is given
|
142 |
+
print("fps", decord_vr.get_avg_fps())
|
143 |
+
print("start frame = %d" % start_frame)
|
144 |
+
frame_id_list = list(range(start_frame, start_frame + num_frames*delta_time, delta_time))
|
145 |
+
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
146 |
+
video_data = [Image.fromarray(video_data[t]).convert('RGB') for t, _ in enumerate(frame_id_list)]
|
147 |
+
return (torch.stack([transforms.ToTensor()(im) for im in video_data], 0), start_frame)
|
148 |
+
|
149 |
+
|
150 |
+
|
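A quick sketch of visualizing a flow field with the helpers above (the random flow and the [H, W, 2] layout are illustrative):

import numpy as np
from cwm.eval.Flow.vis_utils import flow_to_image   # path assumed from this diff

flow_uv = (np.random.randn(224, 224, 2) * 5.0).astype(np.float32)   # (u, v) in pixels
rgb = flow_to_image(flow_uv)    # uint8 [224, 224, 3], Middlebury-style color wheel
print(rgb.shape, rgb.dtype)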
cwm/eval/IntPhys/__init__.py
ADDED
File without changes
|
cwm/eval/Physion/__init__.py
ADDED
File without changes
|
cwm/eval/Physion/feature_extractor.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
import numpy as np
|
2 |
+
from physion_evaluator.feature_extract_interface import PhysionFeatureExtractor
|
3 |
+
from physion_evaluator.utils import DataAugmentationForVideoMAE
|
4 |
+
|
5 |
+
from torch.functional import F
|
6 |
+
|
7 |
+
from cwm.eval.Flow.flow_utils import get_occ_masks
|
8 |
+
|
9 |
+
from cwm.model.model_factory import model_factory
|
10 |
+
import torch
|
11 |
+
|
12 |
+
def load_predictor(
|
13 |
+
model_func_,
|
14 |
+
load_path_,
|
15 |
+
**kwargs):
|
16 |
+
predictor = model_func_(**kwargs).eval().requires_grad_(False)
|
17 |
+
|
18 |
+
did_load = predictor.load_state_dict(
|
19 |
+
torch.load(load_path_, map_location=torch.device("cpu"))['model'])
|
20 |
+
predictor._predictor_load_path = load_path_
|
21 |
+
print(did_load, load_path_)
|
22 |
+
return predictor
|
23 |
+
|
24 |
+
|
25 |
+
class CWM(PhysionFeatureExtractor):
|
26 |
+
def __init__(self, model_name, aggregate_embeddings=False):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
self.model = model_factory.load_model(model_name).cuda().half()
|
30 |
+
|
31 |
+
self.num_frames = self.model.num_frames
|
32 |
+
|
33 |
+
self.timestamps = np.arange(self.num_frames)
|
34 |
+
|
35 |
+
ps = (224 // self.model.patch_size[1]) ** 2
|
36 |
+
|
37 |
+
self.bool_masked_pos = np.zeros([ps * self.num_frames])
|
38 |
+
self.bool_masked_pos[ps * (self.num_frames - 1):] = 1
|
39 |
+
|
40 |
+
self.ps = ps
|
41 |
+
|
42 |
+
self.aggregate_embeddings = aggregate_embeddings
|
43 |
+
|
44 |
+
def transform(self):
|
45 |
+
|
46 |
+
return DataAugmentationForVideoMAE(
|
47 |
+
imagenet_normalize=True,
|
48 |
+
rescale_size=224,
|
49 |
+
), 150, 4
|
50 |
+
|
51 |
+
def fwd(self, videos):
|
52 |
+
bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
|
53 |
+
bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
|
54 |
+
x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
|
55 |
+
return_features=True)
|
56 |
+
return x_encoded
|
57 |
+
|
58 |
+
def extract_features(self, videos, for_flow=False):
|
59 |
+
'''
|
60 |
+
videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
|
61 |
+
returns: [B, T, D] extracted features
|
62 |
+
'''
|
63 |
+
|
64 |
+
videos = videos.transpose(1, 2)
|
65 |
+
|
66 |
+
all_features = []
|
67 |
+
|
68 |
+
# repeat the last frame of the video
|
69 |
+
videos = torch.cat([videos, videos[:, :, -1:]], dim=2)
|
70 |
+
|
71 |
+
for x in range(0, 4, self.num_frames - 1):
|
72 |
+
vid = videos[:, :, x:x + self.num_frames, :, :]
|
73 |
+
all_features.append(self.fwd(vid))
|
74 |
+
if self.aggregate_embeddings:
|
75 |
+
feats = all_features[-1].mean(dim=1, keepdim=True)
|
76 |
+
all_features[-1] = feats
|
77 |
+
# feats = feats.view(feats.shape[0], -1, self.model.num_patches_per_frame, feats.shape[-1])
|
78 |
+
# feats = feats.mean(dim=2)
|
79 |
+
# all_features[-1] = feats
|
80 |
+
|
81 |
+
x_encoded = torch.cat(all_features, dim=1)
|
82 |
+
|
83 |
+
return x_encoded
|
84 |
+
|
85 |
+
|
86 |
+
class CWM_Keypoints(PhysionFeatureExtractor):
|
87 |
+
def __init__(self, model_name):
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
self.model = model_factory.load_model(model_name).cuda().half()
|
91 |
+
|
92 |
+
self.frames = [[0, 1, 2], [1, 2, 3]]
|
93 |
+
|
94 |
+
self.num_frames = self.model.num_frames
|
95 |
+
|
96 |
+
self.ps = (224 // self.model.patch_size[1]) ** 2
|
97 |
+
|
98 |
+
self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
|
99 |
+
self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1
|
100 |
+
|
101 |
+
self.frame_gap = 150
|
102 |
+
|
103 |
+
self.num_frames_dataset = 4
|
104 |
+
|
105 |
+
self.res = 224
|
106 |
+
|
107 |
+
|
108 |
+
def transform(self):
|
109 |
+
|
110 |
+
return DataAugmentationForVideoMAE(
|
111 |
+
imagenet_normalize=True,
|
112 |
+
rescale_size=self.res,
|
113 |
+
), self.frame_gap, self.num_frames_dataset
|
114 |
+
|
115 |
+
def fwd(self, videos):
|
116 |
+
bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
|
117 |
+
bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
|
118 |
+
_, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
|
119 |
+
return_features=True)
|
120 |
+
return x_encoded
|
121 |
+
|
122 |
+
def extract_features(self, videos, segments=None):
|
123 |
+
'''
|
124 |
+
videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
|
125 |
+
returns: [B, T, D] extracted features
|
126 |
+
'''
|
127 |
+
|
128 |
+
videos = videos.transpose(1, 2)
|
129 |
+
|
130 |
+
all_features = []
|
131 |
+
|
132 |
+
for x, arr in enumerate(self.frames):
|
133 |
+
|
134 |
+
#use the downsampled videos for keypoints
|
135 |
+
vid = videos[:, :, arr, :, :].half()
|
136 |
+
frame0 = vid[:, :, 0]
|
137 |
+
frame1 = vid[:, :, 1]
|
138 |
+
frame2 = vid[:, :, 2]
|
139 |
+
|
140 |
+
#extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
|
141 |
+
mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)
|
142 |
+
|
143 |
+
#reshape the features to [batch size, num_features]
|
144 |
+
k_feat = k_feat.view(k_feat.shape[0], -1)
|
145 |
+
|
146 |
+
all_features.append(k_feat)
|
147 |
+
|
148 |
+
x_encoded = torch.cat(all_features, dim=1)
|
149 |
+
|
150 |
+
return x_encoded
|
151 |
+
|
152 |
+
|
153 |
+
class CWM_KeypointsFlow(PhysionFeatureExtractor):
|
154 |
+
def __init__(self, model_name):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
self.model = model_factory.load_model(model_name).cuda().half()
|
158 |
+
|
159 |
+
self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]
|
160 |
+
|
161 |
+
self.num_frames = self.model.num_frames
|
162 |
+
|
163 |
+
self.timestamps = np.arange(self.num_frames)
|
164 |
+
|
165 |
+
self.ps = (224 // self.model.patch_size[1]) ** 2
|
166 |
+
|
167 |
+
self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
|
168 |
+
self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1
|
169 |
+
|
170 |
+
self.frame_gap = 50
|
171 |
+
|
172 |
+
self.num_frames_dataset = 9
|
173 |
+
|
174 |
+
self.res = 512
|
175 |
+
|
176 |
+
def transform(self):
|
177 |
+
|
178 |
+
return DataAugmentationForVideoMAE(
|
179 |
+
imagenet_normalize=True,
|
180 |
+
rescale_size=self.res,
|
181 |
+
), self.frame_gap, self.num_frames_dataset
|
182 |
+
|
183 |
+
def fwd(self, videos):
|
184 |
+
bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
|
185 |
+
bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
|
186 |
+
_, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
|
187 |
+
return_features=True)
|
188 |
+
return x_encoded
|
189 |
+
|
190 |
+
def get_forward_flow(self, videos):
|
191 |
+
|
192 |
+
fid = 6
|
193 |
+
|
194 |
+
forward_flow = self.model.get_flow(videos[:, :, fid], videos[:, :, fid + 1], conditioning_img=videos[:, :, fid + 2], mode='cosine')
|
195 |
+
|
196 |
+
backward_flow = self.model.get_flow(videos[:, :, fid + 1], videos[:, :, fid], conditioning_img=videos[:, :, fid - 1], mode='cosine')
|
197 |
+
|
198 |
+
occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]
|
199 |
+
|
200 |
+
forward_flow = forward_flow * occlusion_mask
|
201 |
+
|
202 |
+
forward_flow = torch.stack([forward_flow, forward_flow, forward_flow], dim=1)
|
203 |
+
|
204 |
+
forward_flow = forward_flow.to(videos.device)
|
205 |
+
|
206 |
+
forward_flow = F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')
|
207 |
+
|
208 |
+
return forward_flow
|
209 |
+
|
210 |
+
def extract_features(self, videos, segments=None):
|
211 |
+
'''
|
212 |
+
videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
|
213 |
+
returns: [B, T, D] extracted features
|
214 |
+
Note:
|
215 |
+
For efficiency, the optical flow is computed and added for a single frame (300ms) as we found this to be sufficient
|
216 |
+
for capturing temporal dynamics in our experiments. This approach can be extended to multiple frames if needed,
|
217 |
+
depending on the complexity of the task.
|
218 |
+
'''
|
219 |
+
|
220 |
+
|
221 |
+
#resize to 224 to get keypoints and features
|
222 |
+
videos_downsampled = F.interpolate(videos.flatten(0, 1), size=(224, 224), mode='bilinear', align_corners=False)
|
223 |
+
videos_downsampled = videos_downsampled.view(videos.shape[0], videos.shape[1], videos.shape[2], 224, 224)
|
224 |
+
|
225 |
+
#for computing flow at higher resolution
|
226 |
+
videos_ = F.interpolate(videos.flatten(0, 1), size=(1024, 1024), mode='bilinear', align_corners=False)
|
227 |
+
videos = videos_.view(videos.shape[0], videos.shape[1], videos.shape[2], 1024, 1024)
|
228 |
+
|
229 |
+
videos = videos.transpose(1, 2).half()
|
230 |
+
videos_downsampled = videos_downsampled.transpose(1, 2).half()
|
231 |
+
|
232 |
+
# Get the forward flow for the frame at 300ms
|
233 |
+
forward_flow = self.get_forward_flow(videos)
|
234 |
+
|
235 |
+
# Verify that there are no nans forward flow
|
236 |
+
assert not torch.isnan(forward_flow).any(), "Forward flow is nan"
|
237 |
+
|
238 |
+
all_features = []
|
239 |
+
|
240 |
+
for x, arr in enumerate(self.frames):
|
241 |
+
|
242 |
+
#use the downsampled videos for keypoints
|
243 |
+
vid = videos_downsampled[:, :, arr, :, :]
|
244 |
+
frame0 = vid[:, :, 0]
|
245 |
+
frame1 = vid[:, :, 1]
|
246 |
+
frame2 = vid[:, :, 2]
|
247 |
+
|
248 |
+
#extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
|
249 |
+
mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)
|
250 |
+
|
251 |
+
#for the last set of frames only use features at keypoint regions of frame2
|
252 |
+
if (x == 2):
|
253 |
+
k_feat = k_feat[:, -10:, :]
|
254 |
+
|
255 |
+
#reshape the features to [batch size, num_features]
|
256 |
+
k_feat = k_feat.view(k_feat.shape[0], -1)
|
257 |
+
|
258 |
+
choices_image_resolution = choices * self.model.patch_size[1]
|
259 |
+
|
260 |
+
# At 300ms, add optical flow patches at the detected keypoint locations
|
261 |
+
# For the first frame (x == 0)
|
262 |
+
if x == 0:
|
263 |
+
# Extract the optical flow information from the forward flow matrix for the second channel (index 2)
|
264 |
+
flow_keyp = forward_flow[:, 2]
|
265 |
+
|
266 |
+
# Initialize a result tensor to store the flow patches
|
267 |
+
# Tensor shape: [batch_size, 8x8 patch (flattened to 64) * 2 channels, 10 keypoints]
|
268 |
+
flow = torch.zeros(vid.shape[0], 8 * 8 * 2, 10).to(videos.device)
|
269 |
+
|
270 |
+
# Patch size shift (since 8x8 patches are being extracted)
|
271 |
+
shift = 8
|
272 |
+
|
273 |
+
# Loop over each element in the batch to process individual video frames
|
274 |
+
for b in range(flow_keyp.size(0)):
|
275 |
+
# Extract the x and y coordinates of the keypoint locations for this batch element
|
276 |
+
x_indices = choices_image_resolution[b, :, 0]
|
277 |
+
y_indices = choices_image_resolution[b, :, 1]
|
278 |
+
|
279 |
+
# For each keypoint (10 total keypoints in this case)
|
280 |
+
for ind in range(10):
|
281 |
+
# Extract the 8x8 patch of optical flow at each keypoint's (x, y) location
|
282 |
+
# Flatten the patch and assign it to the corresponding slice in the result tensor
|
283 |
+
flow[b, :, ind] = flow_keyp[b, :, y_indices[ind]:y_indices[ind] + shift,
|
284 |
+
x_indices[ind]:x_indices[ind] + shift].flatten()
|
285 |
+
|
286 |
+
# Reshape the flow tensor for easier concatenation (flatten across all patches)
|
287 |
+
flow = flow.view(flow.shape[0], -1)
|
288 |
+
|
289 |
+
# Concatenate the extracted optical flow features with the existing feature tensor (k_feat)
|
290 |
+
k_feat = torch.cat([k_feat, flow], dim=1)
|
291 |
+
|
292 |
+
all_features.append(k_feat)
|
293 |
+
|
294 |
+
x_encoded = torch.cat(all_features, dim=1)
|
295 |
+
|
296 |
+
return x_encoded
|
297 |
+
|
298 |
+
|
299 |
+
class CWM_base_8x8_3frame(CWM):
|
300 |
+
def __init__(self,):
|
301 |
+
super().__init__('vitb_8x8patch_3frames')
|
302 |
+
|
303 |
+
class CWM_base_8x8_3frame_mean_embed(CWM):
|
304 |
+
def __init__(self,):
|
305 |
+
super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)
|
306 |
+
|
307 |
+
# CWM* (keypoints only) 74.7
|
308 |
+
class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
|
309 |
+
def __init__(self,):
|
310 |
+
super().__init__('vitb_8x8patch_3frames')
|
311 |
+
|
312 |
+
|
313 |
+
# CWM* (keypoints + Flow) 75.4
|
314 |
+
class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
|
315 |
+
def __init__(self,):
|
316 |
+
super().__init__('vitb_8x8patch_3frames')
|
317 |
+
|
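For orientation, the Physion evaluator is expected to drive these extractor classes roughly as follows (a sketch only: it assumes the 'vitb_8x8patch_3frames' checkpoint is reachable through the repo's model_factory, and uses random tensors in place of imagenet-normalized clips):

import torch
from cwm.eval.Physion.feature_extractor import CWM_base_8x8_3frame   # path from this diff

extractor = CWM_base_8x8_3frame()                  # loads 'vitb_8x8patch_3frames' on the GPU
transform, frame_gap, num_frames = extractor.transform()

videos = torch.randn(2, 4, 3, 224, 224).cuda()     # [B, T, C, H, W], normally imagenet-normalized
with torch.no_grad():
    feats = extractor.extract_features(videos)     # features concatenated over the 3-frame windows
print(feats.shape)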
cwm/eval/Physion/flow_utils.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import math
|
6 |
+
|
7 |
+
def create_weighted_mask_batched(h, w):
|
8 |
+
y_mask = np.linspace(0, 1, h)
|
9 |
+
y_mask = np.minimum(y_mask, 1 - y_mask)
|
10 |
+
x_mask = np.linspace(0, 1, w)
|
11 |
+
x_mask = np.minimum(x_mask, 1 - x_mask)
|
12 |
+
weighted_mask = np.outer(y_mask, x_mask)
|
13 |
+
return torch.from_numpy(weighted_mask).float()
|
14 |
+
|
15 |
+
def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
|
16 |
+
B, T, C, H, W = original_shape
|
17 |
+
|
18 |
+
# Initialize an empty tensor to store the reconstructed video
|
19 |
+
reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
|
20 |
+
|
21 |
+
# Create a tensor to store the sum of weighted masks
|
22 |
+
weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
|
23 |
+
|
24 |
+
# Create a weighted mask for the crops
|
25 |
+
weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
|
26 |
+
weighted_mask = weighted_mask[None, None, None, :, :] # Extend dimensions to match the cropped tensor.
|
27 |
+
|
28 |
+
for idx, crop in enumerate(cropped_tensors):
|
29 |
+
start_h, start_w = crop_positions[idx]
|
30 |
+
|
31 |
+
# Multiply the crop with the weighted mask
|
32 |
+
weighted_crop = crop * weighted_mask
|
33 |
+
|
34 |
+
# Add the weighted crop to the corresponding location in the reconstructed_video tensor
|
35 |
+
reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop
|
36 |
+
|
37 |
+
# Update the weighted_masks_sum tensor
|
38 |
+
weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask
|
39 |
+
|
40 |
+
# Add a small epsilon value to avoid division by zero
|
41 |
+
epsilon = 1e-8
|
42 |
+
|
43 |
+
# Normalize the reconstructed video by dividing each pixel by its corresponding weighted_masks_sum value plus epsilon
|
44 |
+
reconstructed_video /= (weighted_masks_sum + epsilon)
|
45 |
+
|
46 |
+
return reconstructed_video
|
47 |
+
|
48 |
+
import torch.nn.functional as F
|
49 |
+
|
50 |
+
resize = lambda x,a: F.interpolate(x, [int(a*x.shape[-2]), int(a*x.shape[-1])], mode='bilinear', align_corners=False)
|
51 |
+
|
52 |
+
upsample = lambda x,H,W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
#
|
57 |
+
def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
|
58 |
+
# Unroll the mask tensor and find the indices of the masked and unmasked values in the second frame
|
59 |
+
mask_unrolled = mask_tensor.view(-1)
|
60 |
+
|
61 |
+
second_frame_unmask_indices = torch.where(mask_unrolled[frame_size**2:] == False)[0]
|
62 |
+
|
63 |
+
# Divide the embedding tensor into two parts: corresponding to the first and the second frame
|
64 |
+
first_frame_embeddings = embedding_tensor[0, :frame_size**2, :]
|
65 |
+
second_frame_embeddings = embedding_tensor[0, frame_size**2:, :]
|
66 |
+
|
67 |
+
# Compute the cosine similarity between the unmasked embeddings from the second frame and the embeddings from the first frame
|
68 |
+
dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
|
69 |
+
norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
|
70 |
+
cos_sim_matrix = dot_product / norms
|
71 |
+
|
72 |
+
# Find the indices of pixels in the first frame that are most similar to each unmasked pixel in the second frame
|
73 |
+
first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)
|
74 |
+
|
75 |
+
# Convert the 1D pixel indices into 2D coordinates
|
76 |
+
second_frame_y = second_frame_unmask_indices // frame_size
|
77 |
+
second_frame_x = second_frame_unmask_indices % frame_size
|
78 |
+
first_frame_y = first_frame_most_similar_indices // frame_size
|
79 |
+
first_frame_x = first_frame_most_similar_indices % frame_size
|
80 |
+
|
81 |
+
# Compute the x and y displacements and convert them to float
|
82 |
+
displacements_x = (second_frame_x - first_frame_x).float()
|
83 |
+
displacements_y = (second_frame_y - first_frame_y).float()
|
84 |
+
|
85 |
+
# Initialize optical flow tensor
|
86 |
+
optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
|
87 |
+
|
88 |
+
# Assign the computed displacements to the corresponding pixels in the optical flow tensor
|
89 |
+
optical_flow[0, second_frame_y, second_frame_x] = displacements_x
|
90 |
+
optical_flow[1, second_frame_y, second_frame_x] = displacements_y
|
91 |
+
|
92 |
+
return optical_flow
|
93 |
+
|
94 |
+
def get_minimal_224_crops_new_batched(video_tensor, N):
|
95 |
+
B, T, C, H, W = video_tensor.shape
|
96 |
+
|
97 |
+
# Calculate the number of crops needed in both the height and width dimensions
|
98 |
+
num_crops_h = math.ceil(H / 224) if H > 224 else 1
|
99 |
+
num_crops_w = math.ceil(W / 224) if W > 224 else 1
|
100 |
+
|
101 |
+
# Calculate the step size for the height and width dimensions
|
102 |
+
step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
|
103 |
+
step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))
|
104 |
+
|
105 |
+
# Create a list to store the cropped tensors and their start positions
|
106 |
+
cropped_tensors = []
|
107 |
+
crop_positions = []
|
108 |
+
|
109 |
+
# Iterate over the height and width dimensions, extract the 224x224 crops, and append to the cropped_tensors list
|
110 |
+
for i in range(num_crops_h):
|
111 |
+
for j in range(num_crops_w):
|
112 |
+
start_h = i * step_size_h
|
113 |
+
start_w = j * step_size_w
|
114 |
+
end_h = min(start_h + 224, H)
|
115 |
+
end_w = min(start_w + 224, W)
|
116 |
+
crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
|
117 |
+
cropped_tensors.append(crop)
|
118 |
+
crop_positions.append((start_h, start_w))
|
119 |
+
|
120 |
+
D = len(cropped_tensors)
|
121 |
+
|
122 |
+
# If N is greater than D, generate additional random crops
|
123 |
+
if N > D and H > 224 and W > 224: # check if H and W are greater than 224
|
124 |
+
for _ in range(N - D):
|
125 |
+
start_h = random.randint(0, H - 224)
|
126 |
+
start_w = random.randint(0, W - 224)
|
127 |
+
crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
|
128 |
+
cropped_tensors.append(crop)
|
129 |
+
crop_positions.append((start_h, start_w))
|
130 |
+
|
131 |
+
# Reshape the cropped tensors to fit the required output shape (B, T, C, 224, 224)
|
132 |
+
cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
|
133 |
+
|
134 |
+
return cropped_tensors, crop_positions
|
135 |
+
|
136 |
+
def get_honglin_3frame_vmae_optical_flow_crop_batched(generator,
|
137 |
+
mask_generator,
|
138 |
+
img1,
|
139 |
+
img2,
|
140 |
+
img3,
|
141 |
+
neg_back_flow=True,
|
142 |
+
num_scales=1,
|
143 |
+
min_scale=400,
|
144 |
+
N_mask_samples=100,
|
145 |
+
mask_ratio=0.8,
|
146 |
+
flow_frames='23'):
|
147 |
+
B = img1.shape[0]
|
148 |
+
assert len(img1.shape) == 4
|
149 |
+
assert num_scales >= 1
|
150 |
+
|
151 |
+
# For scaling
|
152 |
+
h1 = img2.shape[-2]
|
153 |
+
w1 = img2.shape[-1]
|
154 |
+
assert min_scale < h1
|
155 |
+
|
156 |
+
if neg_back_flow is False:
|
157 |
+
print('WARNING: Not calculating negative backward flow')
|
158 |
+
|
159 |
+
alpha = (min_scale / img1.shape[-2]) ** (1 / 4)
|
160 |
+
|
161 |
+
frame_size = 224 // generator.patch_size[-1]
|
162 |
+
|
163 |
+
patch_size = generator.patch_size[-1]
|
164 |
+
|
165 |
+
all_fwd_flows_e2d = []
|
166 |
+
|
167 |
+
for aidx in range(num_scales):
|
168 |
+
|
169 |
+
# print('aidx: ', aidx)
|
170 |
+
|
171 |
+
img1_scaled = resize(img1.clone(), alpha ** aidx)
|
172 |
+
img2_scaled = resize(img2.clone(), alpha ** aidx)
|
173 |
+
img3_scaled = resize(img3.clone(), alpha ** aidx)
|
174 |
+
|
175 |
+
h2 = img2_scaled.shape[-2]
|
176 |
+
w2 = img2_scaled.shape[-1]
|
177 |
+
|
178 |
+
s_h = h1 / h2
|
179 |
+
s_w = w1 / w2
|
180 |
+
|
181 |
+
# Because technically the compute_optical_flow function returns neg back flow
|
182 |
+
if neg_back_flow is True:
|
183 |
+
video = torch.cat([img3_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
|
184 |
+
else:
|
185 |
+
video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img3_scaled.unsqueeze(1)], 1)
|
186 |
+
|
187 |
+
# Should work, even if the incoming video is already 224x224
|
188 |
+
crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
|
189 |
+
|
190 |
+
# print(len(crops1), crops1[0].shape)
|
191 |
+
|
192 |
+
num_crops = len(crops1)
|
193 |
+
|
194 |
+
crop_flows_enc = []
|
195 |
+
crop_flows_enc2dec = []
|
196 |
+
N_samples = N_mask_samples
|
197 |
+
|
198 |
+
crop = torch.cat(crops1, 0).cuda()
|
199 |
+
# print(crop.shape)
|
200 |
+
|
201 |
+
optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
|
202 |
+
mask_counts = torch.zeros(frame_size, frame_size).cuda()
|
203 |
+
|
204 |
+
i = 0
|
205 |
+
while i < N_samples or (mask_counts == 0).any().item():
|
206 |
+
if i % 100 == 0:
|
207 |
+
pass # print(i)
|
208 |
+
mask_generator.mask_ratio = mask_ratio
|
209 |
+
|
210 |
+
# breakpoint()
|
211 |
+
# This would be that every sample has the same mask. For now that's okay I think
|
212 |
+
mask = mask_generator(num_frames=3)[None]
|
213 |
+
mask_2f = ~mask[0, frame_size * frame_size * 2:]
|
214 |
+
mask_counts += mask_2f.reshape(frame_size, frame_size)
|
215 |
+
|
216 |
+
with torch.cuda.amp.autocast(enabled=True):
|
217 |
+
|
218 |
+
processed_x = crop.transpose(1, 2)
|
219 |
+
|
220 |
+
# print("crop", processed_x.max())
|
221 |
+
|
222 |
+
encoder_out = generator.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
|
223 |
+
encoder_to_decoder = generator.encoder_to_decoder(encoder_out)
|
224 |
+
# print(encoder_to_decoder.shape)
|
225 |
+
|
226 |
+
if flow_frames == '23':
|
227 |
+
encoder_to_decoder = encoder_to_decoder[:, frame_size * frame_size:, :]
|
228 |
+
flow_mask = mask[:, frame_size * frame_size:]
|
229 |
+
# print(encoder_to_decoder.shape)
|
230 |
+
elif flow_frames == '12':
|
231 |
+
encoder_to_decoder = encoder_to_decoder[:, :frame_size * frame_size * 2, :]
|
232 |
+
# print(encoder_to_decoder.shape)
|
233 |
+
flow_mask = mask[:, :frame_size * frame_size * 2]
|
234 |
+
# print(mask.shape)
|
235 |
+
# print(flow_mask.shape)
|
236 |
+
# print()
|
237 |
+
|
238 |
+
optical_flow_e2d = []
|
239 |
+
# one per batch element for now
|
240 |
+
for b in range(B * num_crops):
|
241 |
+
batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
|
242 |
+
optical_flow_e2d.append(batch_flow.unsqueeze(0))
|
243 |
+
|
244 |
+
optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
|
245 |
+
optical_flows_enc2dec += optical_flow_e2d
|
246 |
+
i += 1
|
247 |
+
|
248 |
+
optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
|
249 |
+
|
250 |
+
scale_factor_y = video.shape[-2] / 224
|
251 |
+
scale_factor_x = video.shape[-1] / 224
|
252 |
+
|
253 |
+
scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
|
254 |
+
scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
|
255 |
+
scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h
|
256 |
+
|
257 |
+
# split the crops back up
|
258 |
+
crop_flows_enc2dec = scaled_optical_flow.split(B, 0)
|
259 |
+
# print(len(crop_flows_enc2dec))
|
260 |
+
|
261 |
+
optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
|
262 |
+
[_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
|
263 |
+
crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
|
264 |
+
|
265 |
+
all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
|
266 |
+
|
267 |
+
all_fwd_flows_e2d_new = []
|
268 |
+
|
269 |
+
for r in all_fwd_flows_e2d:
|
270 |
+
new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
|
271 |
+
all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
|
272 |
+
return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
|
273 |
+
|
274 |
+
if neg_back_flow is True:
|
275 |
+
return_flow = -return_flow
|
276 |
+
all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
|
277 |
+
|
278 |
+
return return_flow, all_fwd_flows_e2d_new
|
279 |
+
|
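The core of the routine above is `compute_optical_flow`: unmasked patch embeddings of the later frame are matched to first-frame patch embeddings by cosine similarity, and the index offset between matched patches is read off as a displacement in patch units. A toy call with random embeddings (shapes only, no model involved):

import torch
from cwm.eval.Physion.flow_utils import compute_optical_flow   # path from this diff

frame_size = 28                              # e.g. 224 / 8 patches per side
N = frame_size * frame_size
emb = torch.randn(1, 2 * N, 768)             # [1, frame-1 tokens + frame-2 tokens, D]
mask = torch.zeros(2 * N, dtype=torch.bool)  # nothing masked -> every frame-2 patch gets matched
flow = compute_optical_flow(emb, mask, frame_size)
print(flow.shape)                            # [2, 28, 28], (x, y) displacements in patch units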
cwm/eval/Physion/run_eval.sh
ADDED
@@ -0,0 +1,17 @@
1 |
+
#physion_feature_extract \
|
2 |
+
#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
|
3 |
+
#--data_root_path /ccn2/u/rmvenkat/data/testing_physion/regenerate_from_old_commit/ \
|
4 |
+
#--model_class feature_extractor.CWM_base_8x8_3frame \
|
5 |
+
#--gpu 1 \
|
6 |
+
#--batch_size 8 \
|
7 |
+
#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release/ \
|
8 |
+
#--mode ocp
|
9 |
+
|
10 |
+
physion_train_readout \
|
11 |
+
--train-path /ccn2/u/rmvenkat/data/physion_release/ocp/train_features.hdf5 \
|
12 |
+
--test-path /ccn2/u/rmvenkat/data/physion_release/ocp/test_features.hdf5 \
|
13 |
+
--model-name CWM_base_8x8_3frame \
|
14 |
+
--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/train_json.json \
|
15 |
+
--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/test_json.json \
|
16 |
+
--test-scenario-map /ccn2/u/rmvenkat/data/physion_release/ocp/test_scenario_map.json \
|
17 |
+
--save_path /ccn2/u/rmvenkat/data/physion_release/
|
cwm/eval/Physion/run_eval_kfflow.sh
ADDED
@@ -0,0 +1,18 @@
1 |
+
#physion_feature_extract \
|
2 |
+
#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
|
3 |
+
#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
|
4 |
+
#--model_class feature_extractor.CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \
|
5 |
+
#--gpu 1 \
|
6 |
+
#--batch_size 8 \
|
7 |
+
#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_kfflow/ \
|
8 |
+
#--mode ocp
|
9 |
+
|
10 |
+
|
11 |
+
physion_train_readout \
|
12 |
+
--train-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_features.hdf5 \
|
13 |
+
--test-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_features.hdf5 \
|
14 |
+
--model-name CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \
|
15 |
+
--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_json.json \
|
16 |
+
--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_json.json \
|
17 |
+
--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_scenario_map.json \
|
18 |
+
--save_path /ccn2/u/rmvenkat/data/physion_release_kfflow/
|
cwm/eval/Physion/run_eval_mp4s.sh
ADDED
@@ -0,0 +1,19 @@
dir_for_saving=/ccn2/u/rmvenkat/data/physion_release/
model_name=CWM_base_8x8_3frame

physion_feature_extract \
--data_root_path /ccn2/u/rmvenkat/data/download_test/physion_mp4s/ \
--model_class feature_extractor.$model_name \
--gpu 1 \
--batch_size 8 \
--dir_for_saving $dir_for_saving \
--mode ocd

physion_train_readout \
--train-path ${dir_for_saving}ocd/train_features.hdf5 \
--test-path ${dir_for_saving}ocd/test_features.hdf5 \
--model-name $model_name \
--train-scenario-indices ${dir_for_saving}ocd/train_json.json \
--test-scenario-indices ${dir_for_saving}ocd/test_json.json \
--test-scenario-map ${dir_for_saving}ocd/test_scenario_map.json \
--save_path $dir_for_saving
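This script chains the two stages end to end: physion_feature_extract writes train/test HDF5 feature files under ${dir_for_saving}ocd/, and physion_train_readout fits the readout on them. A quick way to sanity-check the first stage is to list what the HDF5 file actually contains; the snippet below only assumes the output path produced by the script above, not any particular internal layout of the file:

import h5py

# Inspect the extracted features before training the readout.
path = '/ccn2/u/rmvenkat/data/physion_release/ocd/train_features.hdf5'
with h5py.File(path, 'r') as f:
    f.visit(print)                    # print every group/dataset name
    for name, item in f.items():      # top-level entries with shapes/dtypes if present
        print(name, getattr(item, 'shape', None), getattr(item, 'dtype', None))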
cwm/eval/Physion/run_eval_mp4s_keyp.sh
ADDED
@@ -0,0 +1,17 @@
#physion_feature_extract \
#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
#--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints \
#--gpu 3 \
#--batch_size 8 \
#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp/ \
#--mode ocp

physion_train_readout \
--train-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_features.hdf5 \
--test-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_features.hdf5 \
--model-name CWM_base_8x8_3frame \
--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_json.json \
--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_json.json \
--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_scenario_map.json \
--save_path /ccn2/u/rmvenkat/data/physion_release_keyp/
cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh
ADDED
@@ -0,0 +1,17 @@
#physion_feature_extract \
#--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
#--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
#--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints_flow \
#--gpu 7 \
#--batch_size 8 \
#--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ \
#--mode ocp

physion_train_readout \
--train-path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/train_features.hdf5 \
--test-path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_features.hdf5 \
--model-name CWM_base_8x8_3frame \
--train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/train_json.json \
--test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_json.json \
--test-scenario-map /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_scenario_map.json \
--save_path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/
cwm/eval/Segmentation/__init__.py
ADDED
File without changes
|
cwm/eval/Segmentation/archive/__init__.py
ADDED
File without changes
|
cwm/eval/Segmentation/archive/common/__init__.py
ADDED
File without changes
|
cwm/eval/Segmentation/archive/common/coco_loader_lsj.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
import detectron2.data.transforms as T
|
2 |
+
from detectron2 import model_zoo
|
3 |
+
from detectron2.config import LazyCall as L
|
4 |
+
|
5 |
+
# Data using LSJ
|
6 |
+
image_size = 512
|
7 |
+
dataloader = model_zoo.get_config("common/data/coco.py").dataloader
|
8 |
+
dataloader.train.mapper.augmentations = [
|
9 |
+
L(T.RandomFlip)(horizontal=True), # flip first
|
10 |
+
L(T.ResizeScale)(
|
11 |
+
min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
|
12 |
+
),
|
13 |
+
L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
|
14 |
+
]
|
15 |
+
dataloader.train.mapper.image_format = "RGB"
|
16 |
+
dataloader.train.total_batch_size = 64
|
17 |
+
dataloader.train.num_workers = 0
|
18 |
+
# recompute boxes due to cropping
|
19 |
+
dataloader.train.mapper.recompute_boxes = True
|
20 |
+
|
21 |
+
dataloader.test.mapper.augmentations = [
|
22 |
+
L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
import copy
|
29 |
+
import logging
|
30 |
+
import numpy as np
|
31 |
+
from typing import List, Optional, Union
|
32 |
+
import torch
|
33 |
+
|
34 |
+
from detectron2.config import configurable
|
35 |
+
|
36 |
+
from detectron2.data import detection_utils as utils
|
37 |
+
from detectron2.data import transforms as T
|
38 |
+
|
39 |
+
"""
|
40 |
+
This file contains the default mapping that's applied to "dataset dicts".
|
41 |
+
"""
|
42 |
+
|
43 |
+
__all__ = ["DatasetMapper"]
|
44 |
+
|
45 |
+
|
46 |
+
class DatasetMapper:
|
47 |
+
"""
|
48 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
49 |
+
and map it into a format used by the model.
|
50 |
+
|
51 |
+
This is the default callable to be used to map your dataset dict into training data.
|
52 |
+
You may need to follow it to implement your own one for customized logic,
|
53 |
+
such as a different way to read or transform images.
|
54 |
+
See :doc:`/tutorials/data_loading` for details.
|
55 |
+
|
56 |
+
The callable currently does the following:
|
57 |
+
|
58 |
+
1. Read the image from "file_name"
|
59 |
+
2. Applies cropping/geometric transforms to the image and annotations
|
60 |
+
3. Prepare data and annotations to Tensor and :class:`Instances`
|
61 |
+
"""
|
62 |
+
|
63 |
+
@configurable
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
is_train: bool,
|
67 |
+
*,
|
68 |
+
augmentations: List[Union[T.Augmentation, T.Transform]],
|
69 |
+
image_format: str,
|
70 |
+
use_instance_mask: bool = False,
|
71 |
+
use_keypoint: bool = False,
|
72 |
+
instance_mask_format: str = "polygon",
|
73 |
+
keypoint_hflip_indices: Optional[np.ndarray] = None,
|
74 |
+
precomputed_proposal_topk: Optional[int] = None,
|
75 |
+
recompute_boxes: bool = False,
|
76 |
+
):
|
77 |
+
"""
|
78 |
+
NOTE: this interface is experimental.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
is_train: whether it's used in training or inference
|
82 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
83 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
84 |
+
use_instance_mask: whether to process instance segmentation annotations, if available
|
85 |
+
use_keypoint: whether to process keypoint annotations if available
|
86 |
+
instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
|
87 |
+
masks into this format.
|
88 |
+
keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
|
89 |
+
precomputed_proposal_topk: if given, will load pre-computed
|
90 |
+
proposals from dataset_dict and keep the top k proposals for each image.
|
91 |
+
recompute_boxes: whether to overwrite bounding box annotations
|
92 |
+
by computing tight bounding boxes from instance mask annotations.
|
93 |
+
"""
|
94 |
+
if recompute_boxes:
|
95 |
+
assert use_instance_mask, "recompute_boxes requires instance masks"
|
96 |
+
# fmt: off
|
97 |
+
self.is_train = is_train
|
98 |
+
self.augmentations = T.AugmentationList(augmentations)
|
99 |
+
self.image_format = image_format
|
100 |
+
self.use_instance_mask = use_instance_mask
|
101 |
+
self.instance_mask_format = instance_mask_format
|
102 |
+
self.use_keypoint = use_keypoint
|
103 |
+
self.keypoint_hflip_indices = keypoint_hflip_indices
|
104 |
+
self.proposal_topk = precomputed_proposal_topk
|
105 |
+
self.recompute_boxes = recompute_boxes
|
106 |
+
# fmt: on
|
107 |
+
logger = logging.getLogger(__name__)
|
108 |
+
mode = "training" if is_train else "inference"
|
109 |
+
logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def from_config(cls, cfg, is_train: bool = True):
|
113 |
+
augs = utils.build_augmentation(cfg, is_train)
|
114 |
+
if cfg.INPUT.CROP.ENABLED and is_train:
|
115 |
+
augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
|
116 |
+
recompute_boxes = cfg.MODEL.MASK_ON
|
117 |
+
else:
|
118 |
+
recompute_boxes = False
|
119 |
+
|
120 |
+
ret = {
|
121 |
+
"is_train": is_train,
|
122 |
+
"augmentations": augs,
|
123 |
+
"image_format": cfg.INPUT.FORMAT,
|
124 |
+
"use_instance_mask": cfg.MODEL.MASK_ON,
|
125 |
+
"instance_mask_format": cfg.INPUT.MASK_FORMAT,
|
126 |
+
"use_keypoint": cfg.MODEL.KEYPOINT_ON,
|
127 |
+
"recompute_boxes": recompute_boxes,
|
128 |
+
}
|
129 |
+
|
130 |
+
if cfg.MODEL.KEYPOINT_ON:
|
131 |
+
ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
|
132 |
+
|
133 |
+
if cfg.MODEL.LOAD_PROPOSALS:
|
134 |
+
ret["precomputed_proposal_topk"] = (
|
135 |
+
cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
|
136 |
+
if is_train
|
137 |
+
else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
|
138 |
+
)
|
139 |
+
return ret
|
140 |
+
|
141 |
+
def _transform_annotations(self, dataset_dict, transforms, image_shape):
|
142 |
+
# USER: Modify this if you want to keep them for some reason.
|
143 |
+
for anno in dataset_dict["annotations"]:
|
144 |
+
if not self.use_instance_mask:
|
145 |
+
anno.pop("segmentation", None)
|
146 |
+
if not self.use_keypoint:
|
147 |
+
anno.pop("keypoints", None)
|
148 |
+
|
149 |
+
# USER: Implement additional transformations if you have other types of data
|
150 |
+
annos = [
|
151 |
+
utils.transform_instance_annotations(
|
152 |
+
obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
|
153 |
+
)
|
154 |
+
for obj in dataset_dict.pop("annotations")
|
155 |
+
if obj.get("iscrowd", 0) == 0
|
156 |
+
]
|
157 |
+
instances = utils.annotations_to_instances(
|
158 |
+
annos, image_shape, mask_format=self.instance_mask_format
|
159 |
+
)
|
160 |
+
|
161 |
+
# After transforms such as cropping are applied, the bounding box may no longer
|
162 |
+
# tightly bound the object. As an example, imagine a triangle object
|
163 |
+
# [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
|
164 |
+
# bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
|
165 |
+
# the intersection of original bounding box and the cropping box.
|
166 |
+
if self.recompute_boxes:
|
167 |
+
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
|
168 |
+
dataset_dict["instances"] = utils.filter_empty_instances(instances)
|
169 |
+
|
170 |
+
def __call__(self, dataset_dict):
|
171 |
+
"""
|
172 |
+
Args:
|
173 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
dict: a format that builtin models in detectron2 accept
|
177 |
+
"""
|
178 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
179 |
+
# USER: Write your own image loading if it's not from a file
|
180 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
|
181 |
+
utils.check_image_size(dataset_dict, image)
|
182 |
+
|
183 |
+
# USER: Remove if you don't do semantic/panoptic segmentation.
|
184 |
+
if "sem_seg_file_name" in dataset_dict:
|
185 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
|
186 |
+
else:
|
187 |
+
sem_seg_gt = None
|
188 |
+
|
189 |
+
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
|
190 |
+
transforms = self.augmentations(aug_input)
|
191 |
+
image, sem_seg_gt = aug_input.image, aug_input.sem_seg
|
192 |
+
|
193 |
+
image_shape = image.shape[:2] # h, w
|
194 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
195 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
196 |
+
# Therefore it's important to use torch.Tensor.
|
197 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
198 |
+
if sem_seg_gt is not None:
|
199 |
+
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
|
200 |
+
|
201 |
+
# USER: Remove if you don't use pre-computed proposals.
|
202 |
+
# Most users would not need this feature.
|
203 |
+
if self.proposal_topk is not None:
|
204 |
+
utils.transform_proposals(
|
205 |
+
dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
|
206 |
+
)
|
207 |
+
|
208 |
+
if not self.is_train:
|
209 |
+
# USER: Modify this if you want to keep them for some reason.
|
210 |
+
dataset_dict.pop("annotations", None)
|
211 |
+
dataset_dict.pop("sem_seg_file_name", None)
|
212 |
+
return dataset_dict
|
213 |
+
|
214 |
+
if "annotations" in dataset_dict:
|
215 |
+
self._transform_annotations(dataset_dict, transforms, image_shape)
|
216 |
+
|
217 |
+
# Modified by Honglin Chen: change it to class-agnostic instance labels
|
218 |
+
dataset_dict['instances'].gt_classes *= 0
|
219 |
+
return dataset_dict
|
220 |
+
|
221 |
+
|
222 |
+
dataloader.train.mapper._target_ = DatasetMapper
|
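The last line of this file reroutes the LazyConfig dataloader to the local DatasetMapper defined above, whose only functional change from the stock detectron2 mapper is the class-agnostic step (instances.gt_classes *= 0), so every annotation is treated as a single foreground class. A sketch of how that override is consumed, assuming this config module's dataloader object is in scope and the COCO datasets are registered locally:

from detectron2.config import instantiate

# dataloader.train.mapper is a lazy config node; because its _target_ now
# points at the DatasetMapper above, instantiate() builds that class with
# the LSJ augmentations configured at the top of the file.
mapper = instantiate(dataloader.train.mapper)
train_loader = instantiate(dataloader.train)   # requires COCO to be registered
batch = next(iter(train_loader))
print(batch[0]['instances'].gt_classes)        # all zeros: class-agnostic labels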
cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py
ADDED
@@ -0,0 +1,19 @@
import torch
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
args = parser.parse_args()

state_dict = torch.load(args.input, map_location='cpu')['model']

new_state_dict = {}
for k, v in state_dict.items():
    if 'encoder' in k and not 'decoder' in k:
        new_k = 'backbone.net.model.' + k
        new_state_dict[new_k] = v

output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)
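This converter keeps only the encoder weights and re-keys them under the backbone.net.model. namespace so detectron2 can load them into the ViTDet backbone. A small check of the converted file; the filename is an example, not one taken from the diff:

import torch

ckpt = torch.load('checkpoint-399-encoder.pth', map_location='cpu')
# Every surviving tensor should live under the ViTDet backbone prefix.
assert all(k.startswith('backbone.net.model.') for k in ckpt)
print(len(ckpt), 'tensors, e.g.', next(iter(ckpt)))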
cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py
ADDED
@@ -0,0 +1,54 @@
import torch
import argparse
import sys
sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
args = parser.parse_args()

state_dict = torch.load(args.input, map_location='cpu')['model']
mae = True
# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
C = 768
pos_embed = get_sinusoid_encoding_table(14*14, C)
cls_token = torch.zeros(1, 1, C)
pos_embed = torch.cat([cls_token, pos_embed], dim=1)


new_state_dict = {'backbone.net.pos_embed': pos_embed}
for k, v in state_dict.items():

    if mae or ('encoder' in k and not 'decoder' in k or 'patch_embed' in k):

        if 'patch_embed.proj.weight' in k:

            if len(v.shape) == 5:
                if v.shape[2] == 1:
                    v = v.squeeze(2)  # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
                else:
                    v = v[:, :, 0]

        old_k = k
        k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.'+k

        if 'attn' in k and '_bias' in k:
            old_attn = '.'.join(old_k.split('.')[:-1])
            attn = '.'.join(k.split('.')[:-1])
            k = attn + '.qkv.bias'
            if k in new_state_dict:
                continue

            v = torch.cat([
                state_dict[old_attn + '.q_bias'],
                state_dict[old_attn + '.k_bias'] if (old_attn + '.k_bias') in state_dict else torch.zeros_like(state_dict[old_attn + '.q_bias']),
                state_dict[old_attn + '.v_bias'],
            ], dim=0)
        print(k, v.shape)
        new_state_dict[k] = v

output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)
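The q/k/v handling above is the one non-trivial remapping: VideoMAE-style attention blocks keep separate q_bias and v_bias parameters (with k_bias either absent or fixed at zero), while detectron2's ViT expects a single fused qkv.bias of length 3*C. In isolation, and with made-up tensors, the fusion looks like this:

import torch

C = 768
q_bias, v_bias = torch.randn(C), torch.randn(C)
k_bias = torch.zeros(C)                       # zero-filled when missing, as above
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
assert qkv_bias.shape == (3 * C,)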
cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py
ADDED
@@ -0,0 +1,26 @@
import torch
import argparse
import sys
sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
args = parser.parse_args()
breakpoint()
state_dict = torch.load(args.input, map_location='cpu')

new_state_dict = {}

for k, v in state_dict.items():
    if 'pos_embed' in k:
        breakpoint()
    else:
        pass
    k = 'backbone.net.' + k
    new_state_dict[k] = v

output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)
cwm/eval/Segmentation/archive/competition.py
ADDED
@@ -0,0 +1,673 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torchvision import transforms
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.distributions.categorical import Categorical
|
7 |
+
|
8 |
+
from kornia.filters.kernels import (get_spatial_gradient_kernel2d,
|
9 |
+
normalize_kernel2d)
|
10 |
+
|
11 |
+
def l2_normalize(x):
|
12 |
+
return F.normalize(x, p=2.0, dim=-1, eps=1e-6)
|
13 |
+
|
14 |
+
def reduce_max(x, dim, keepdim=True):
|
15 |
+
return torch.max(x, dim=dim, keepdim=keepdim)[0]
|
16 |
+
|
17 |
+
def coordinate_ims(batch_size, seq_length, imsize):
|
18 |
+
static = False
|
19 |
+
if seq_length == 0:
|
20 |
+
static = True
|
21 |
+
seq_length = 1
|
22 |
+
B = batch_size
|
23 |
+
T = seq_length
|
24 |
+
H,W = imsize
|
25 |
+
ones = torch.ones([B,H,W,1], dtype=torch.float32)
|
26 |
+
h = torch.divide(torch.arange(H).to(ones), torch.tensor(H-1, dtype=torch.float32))
|
27 |
+
h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
|
28 |
+
w = torch.divide(torch.arange(W).to(ones), torch.tensor(W-1, dtype=torch.float32))
|
29 |
+
w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
|
30 |
+
h = torch.stack([h]*T, 1)
|
31 |
+
w = torch.stack([w]*T, 1)
|
32 |
+
hw_ims = torch.cat([h,w], -1)
|
33 |
+
if static:
|
34 |
+
hw_ims = hw_ims[:,0]
|
35 |
+
return hw_ims
|
36 |
+
|
37 |
+
def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
|
38 |
+
"""
|
39 |
+
Compute the normalized dot product between two PyTorch tensors
|
40 |
+
"""
|
41 |
+
B,N,D_q = queries.size()
|
42 |
+
_B,N_k,D_k = keys.size()
|
43 |
+
assert D_q == D_k, (queries.shape, keys.shape)
|
44 |
+
if normalize:
|
45 |
+
queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
|
46 |
+
keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)
|
47 |
+
|
48 |
+
outputs = torch.matmul(queries, torch.transpose(keys, 1, 2)) # [B, N, N_k]
|
49 |
+
attention = torch.transpose(outputs, 1, 2) # [B, N_k, N]
|
50 |
+
|
51 |
+
return outputs
|
52 |
+
|
53 |
+
def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
|
54 |
+
|
55 |
+
B,H,W = probs.shape
|
56 |
+
P = num_points
|
57 |
+
N = H*W
|
58 |
+
|
59 |
+
probs = probs.reshape(B,N)
|
60 |
+
probs = torch.maximum(probs + eps, torch.tensor(0., device=probs.device)) / (probs.sum(dim=-1, keepdim=True) + eps)
|
61 |
+
dist = Categorical(probs=probs, validate_args=False)
|
62 |
+
|
63 |
+
indices = dist.sample([P]).permute(1,0).to(torch.int32) # [B,P]
|
64 |
+
|
65 |
+
indices_h = torch.minimum(torch.maximum(torch.div(indices, W, rounding_mode='floor'), torch.tensor(0)), torch.tensor(H-1))
|
66 |
+
indices_w = torch.minimum(torch.maximum(torch.fmod(indices, W), torch.tensor(0)), torch.tensor(W-1))
|
67 |
+
indices = torch.stack([indices_h, indices_w], dim=-1) # [B,P,2]
|
68 |
+
return indices
|
69 |
+
|
70 |
+
def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
|
71 |
+
|
72 |
+
B,C,H,W = list(image.size())
|
73 |
+
|
74 |
+
# prepare kernel
|
75 |
+
kernel = get_spatial_gradient_kernel2d(mode, order)
|
76 |
+
if normalize_kernel:
|
77 |
+
kernel = normalize_kernel2d(kernel)
|
78 |
+
tmp_kernel = kernel.to(image).detach()
|
79 |
+
tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
|
80 |
+
kernel_flip = tmp_kernel.flip(-3)
|
81 |
+
|
82 |
+
# pad spatial dims of image
|
83 |
+
padding = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
|
84 |
+
out_channels = 3 if (order == 2) else 2
|
85 |
+
padded_image = F.pad(image.reshape(B*C, 1, H, W), padding, 'replicate')[:, :, None] # [B*C,1,1,H+p,W+p]
|
86 |
+
gradient_image = F.conv3d(padded_image, kernel_flip, padding=0).view(B, C, out_channels, H, W)
|
87 |
+
return gradient_image
|
88 |
+
|
89 |
+
def sample_coordinates_at_borders(image, num_points=16, mask=None, sum_edges=True, normalized_coordinates=True):
|
90 |
+
"""
|
91 |
+
Sample num_points in normalized (h,w) coordinates from the borders of the input image
|
92 |
+
"""
|
93 |
+
B,C,H,W = list(image.size())
|
94 |
+
if mask is not None:
|
95 |
+
assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())
|
96 |
+
else:
|
97 |
+
mask = torch.ones(size=(B,1,H,W)).to(image)
|
98 |
+
|
99 |
+
gradient_image = get_gradient_image(image * mask, mode='sobel', order=1) # [B,C,2,H,W]
|
100 |
+
gradient_magnitude = torch.sqrt(torch.square(gradient_image).sum(dim=2))
|
101 |
+
if sum_edges:
|
102 |
+
edges = gradient_magnitude.sum(1) # [B,H,W]
|
103 |
+
else:
|
104 |
+
edges = gradient_magnitude.max(1)[0]
|
105 |
+
|
106 |
+
if mask is not None:
|
107 |
+
edges = edges * mask[:,0]
|
108 |
+
|
109 |
+
coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
|
110 |
+
if normalized_coordinates:
|
111 |
+
coordinates = coordinates.to(torch.float32)
|
112 |
+
coordinates /= torch.tensor([H-1,W-1], dtype=torch.float32)[None,None].to(coordinates.device)
|
113 |
+
coordinates = 2.0 * coordinates - 1.0
|
114 |
+
return coordinates
|
115 |
+
|
116 |
+
def index_into_images(images, indices, channels_last=False):
|
117 |
+
"""
|
118 |
+
index into an image at P points to get its values
|
119 |
+
|
120 |
+
images: [B,C,H,W]
|
121 |
+
indices: [B,P,2]
|
122 |
+
"""
|
123 |
+
assert indices.size(-1) == 2, indices.size()
|
124 |
+
if channels_last:
|
125 |
+
images = images.permute(0,3,1,2) # [B,C,H,W]
|
126 |
+
B,C,H,W = images.shape
|
127 |
+
_,P,_ = indices.shape
|
128 |
+
inds_h, inds_w = list(indices.to(torch.long).permute(2,0,1)) # [B,P] each
|
129 |
+
inds_b = torch.arange(B, dtype=torch.long).unsqueeze(-1).expand(-1,P).to(indices)
|
130 |
+
inds = torch.stack([inds_b, inds_h, inds_w], 0).to(torch.long)
|
131 |
+
values = images.permute(0,2,3,1)[list(inds)] # [B,P,C]
|
132 |
+
return values
|
133 |
+
|
134 |
+
def soft_index(images, indices, scale_by_imsize=True):
|
135 |
+
assert indices.shape[-1] == 2, indices.shape
|
136 |
+
B,C,H,W = images.shape
|
137 |
+
_,P,_ = indices.shape
|
138 |
+
|
139 |
+
# h_inds, w_inds = indices.split([1,1], dim=-1)
|
140 |
+
h_inds, w_inds = list(indices.permute(2,0,1))
|
141 |
+
if scale_by_imsize:
|
142 |
+
h_inds = (h_inds + 1.0) * torch.tensor(H).to(h_inds) * 0.5
|
143 |
+
w_inds = (w_inds + 1.0) * torch.tensor(W).to(w_inds) * 0.5
|
144 |
+
|
145 |
+
h_inds = torch.maximum(torch.minimum(h_inds, torch.tensor(H-1).to(h_inds)), torch.tensor(0.).to(h_inds))
|
146 |
+
w_inds = torch.maximum(torch.minimum(w_inds, torch.tensor(W-1).to(w_inds)), torch.tensor(0.).to(w_inds))
|
147 |
+
|
148 |
+
h_floor = torch.floor(h_inds)
|
149 |
+
w_floor = torch.floor(w_inds)
|
150 |
+
h_ceil = torch.ceil(h_inds)
|
151 |
+
w_ceil = torch.ceil(w_inds)
|
152 |
+
|
153 |
+
bot_right_weight = (h_inds - h_floor) * (w_inds - w_floor)
|
154 |
+
bot_left_weight = (h_inds - h_floor) * (w_ceil - w_inds)
|
155 |
+
top_right_weight = (h_ceil - h_inds) * (w_inds - w_floor)
|
156 |
+
top_left_weight = (h_ceil - h_inds) * (w_ceil - w_inds)
|
157 |
+
|
158 |
+
in_bounds = (bot_right_weight + bot_left_weight + top_right_weight + top_left_weight) > 0.95
|
159 |
+
in_bounds = in_bounds.to(torch.float32)
|
160 |
+
|
161 |
+
top_left_vals = index_into_images(images, torch.stack([h_floor, w_floor], -1))
|
162 |
+
top_right_vals = index_into_images(images, torch.stack([h_floor, w_ceil], -1))
|
163 |
+
bot_left_vals = index_into_images(images, torch.stack([h_ceil, w_floor], -1))
|
164 |
+
bot_right_vals = index_into_images(images, torch.stack([h_ceil, w_ceil], -1))
|
165 |
+
|
166 |
+
im_vals = top_left_vals * top_left_weight[...,None]
|
167 |
+
im_vals += top_right_vals * top_right_weight[...,None]
|
168 |
+
im_vals += bot_left_vals * bot_left_weight[...,None]
|
169 |
+
im_vals += bot_right_vals * bot_right_weight[...,None]
|
170 |
+
|
171 |
+
im_vals = im_vals.view(B,P,C)
|
172 |
+
|
173 |
+
return im_vals
|
174 |
+
|
175 |
+
def compute_compatibility(positions, plateau, phenotypes=None, availability=None, noise=0.1):
|
176 |
+
"""
|
177 |
+
Compute how well "fit" each agent is for the position it's at on the plateau,
|
178 |
+
according to its "phenotype"
|
179 |
+
|
180 |
+
positions: [B,P,2]
|
181 |
+
plateau: [B,H,W,Q]
|
182 |
+
phenotypes: [B,P,D] or None
|
183 |
+
availability: [B,H,W,A]
|
184 |
+
"""
|
185 |
+
B,H,W,Q = plateau.shape
|
186 |
+
P = positions.shape[1]
|
187 |
+
if phenotypes is None:
|
188 |
+
phenotypes = soft_index(plateau, positions)
|
189 |
+
|
190 |
+
if availability is not None:
|
191 |
+
assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], (availability.shape, plateau.shape)
|
192 |
+
A = availability.size(-1)
|
193 |
+
assert P % A == 0, (P, A)
|
194 |
+
S = P // A # population size
|
195 |
+
print("computing availability -- needlessly?", [B,H,W,A,Q])
|
196 |
+
plateau = availability[...,None] * plateau[...,None,:] # [B,H,W,A,Q]
|
197 |
+
plateau = plateau.view(B,H,W,A*Q)
|
198 |
+
|
199 |
+
plateau_values = soft_index(plateau.permute(0,3,1,2), positions, scale_by_imsize=True)
|
200 |
+
if noise > 0:
|
201 |
+
plateau_values += noise * torch.rand(size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)
|
202 |
+
|
203 |
+
if availability is not None:
|
204 |
+
plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
|
205 |
+
inds = torch.tile(torch.eye(A)[None].expand(B,-1,-1), (1,S,1))[...,None] # [B,P,A,1]
|
206 |
+
plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2) # [B,P,Q]
|
207 |
+
else:
|
208 |
+
plateau_values = l2_normalize(plateau_values)
|
209 |
+
|
210 |
+
compatibility = torch.sum(
|
211 |
+
l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True) # [B,P,1]
|
212 |
+
|
213 |
+
return compatibility
|
214 |
+
|
215 |
+
def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
|
216 |
+
"""Find overlaps between masks"""
|
217 |
+
B,N,P = masks.shape
|
218 |
+
if masks_target is None:
|
219 |
+
masks_target = masks
|
220 |
+
if mask_thresh is not None:
|
221 |
+
masks = (masks > mask_thresh).to(torch.float32)
|
222 |
+
masks_target = (masks_target > mask_thresh).to(torch.float32)
|
223 |
+
|
224 |
+
## union and intersection
|
225 |
+
overlaps = masks[...,None] * masks_target[...,None,:] # [B,N,P,P]
|
226 |
+
I = overlaps.sum(dim=1)
|
227 |
+
U = torch.maximum(masks[...,None], masks_target[...,None,:]).sum(dim=1)
|
228 |
+
iou = I / torch.maximum(U, torch.tensor(eps, dtype=torch.float32)) # [B,P,P]
|
229 |
+
|
230 |
+
return iou
|
231 |
+
|
232 |
+
def compete_agents(masks, fitnesses, alive,
|
233 |
+
mask_thresh=0.5, compete_thresh=0.2,
|
234 |
+
sticky_winners=True):
|
235 |
+
"""
|
236 |
+
Kill off agents (which mask dimensions are "alive") based on mask overlap and fitnesses of each
|
237 |
+
|
238 |
+
args:
|
239 |
+
masks: [B,N,P]
|
240 |
+
fitnesses: [B,P,1]
|
241 |
+
alive: [B,P,1]
|
242 |
+
|
243 |
+
returns:
|
244 |
+
still_alive: [B,P,1]
|
245 |
+
|
246 |
+
"""
|
247 |
+
B,N,P = masks.shape
|
248 |
+
assert list(alive.shape) == [B,P,1], alive.shape
|
249 |
+
assert list(fitnesses.shape) == [B,P,1], fitnesses.shape
|
250 |
+
|
251 |
+
## find territorial disputes
|
252 |
+
overlaps = compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=mask_thresh)
|
253 |
+
disputes = overlaps > compete_thresh # [B,P,P] <bool>
|
254 |
+
|
255 |
+
## agents don't fight themselves
|
256 |
+
disputes = torch.logical_and(
|
257 |
+
disputes, torch.logical_not(
|
258 |
+
torch.eye(P, dtype=torch.bool, device=disputes.device).unsqueeze(0).expand(B,-1,-1)))
|
259 |
+
|
260 |
+
## kill off the agents with lower fitness in each dispute
|
261 |
+
killed = torch.logical_and(disputes, fitnesses < torch.transpose(fitnesses, 1, 2))
|
262 |
+
|
263 |
+
## once an agent wins, it always wins again
|
264 |
+
if sticky_winners:
|
265 |
+
winners = (alive > 0.5)
|
266 |
+
losers = torch.logical_not(winners)
|
267 |
+
|
268 |
+
## winners can't lose to last round's losers
|
269 |
+
winners_vs_losers = torch.logical_and(winners, torch.transpose(losers, 1, 2)) # [B,P,P]
|
270 |
+
killed = torch.logical_and(killed, torch.logical_not(winners_vs_losers))
|
271 |
+
|
272 |
+
## losers can't overtake last round's winners
|
273 |
+
losers_vs_winners = torch.logical_and(losers, torch.transpose(winners, 1, 2))
|
274 |
+
losers_vs_winners_disputes = torch.logical_and(losers_vs_winners, disputes)
|
275 |
+
killed = torch.logical_or(killed, losers_vs_winners_disputes)
|
276 |
+
|
277 |
+
## if an agent was killed by *any* competitor, it's dead
|
278 |
+
killed = torch.any(killed, dim=2, keepdim=True)
|
279 |
+
alive = torch.logical_not(killed).to(torch.float32)
|
280 |
+
|
281 |
+
return alive
|
282 |
+
|
283 |
+
def compute_distance_weighted_vectors(vector_map, positions, mask=None, beta=1.0, eps=1e-8):
|
284 |
+
"""
|
285 |
+
compute vectors whose values are a weighted mean of vector_map, where weights are given by distance.
|
286 |
+
"""
|
287 |
+
B,H,W,D = vector_map.shape
|
288 |
+
assert positions.size(-1) == 2, positions.size()
|
289 |
+
B,P,_ = positions.shape
|
290 |
+
N = H*W
|
291 |
+
|
292 |
+
if mask is None:
|
293 |
+
mask = torch.ones_like(vector_map[...,0:1]).to(vector_map.device)
|
294 |
+
else:
|
295 |
+
assert list(mask.shape) == [B,H,W,1]
|
296 |
+
|
297 |
+
hw_grid = coordinate_ims(B, 0, [H,W]).view(B, N, 2).to(vector_map.device)
|
298 |
+
delta_positions = hw_grid[:,None] - positions[:,:,None] # [B,P,N,2]
|
299 |
+
distances = torch.sqrt(delta_positions[...,0]**2 + delta_positions[...,1]**2 + eps) # [B,P,N]
|
300 |
+
|
301 |
+
## max distance is 2*sqrt(2)
|
302 |
+
inv_distances = (2.0 * np.sqrt(2.0)) / (distances + eps)
|
303 |
+
inv_distances = F.softmax(beta * inv_distances * mask.view(B, 1, N), dim=-1) # [B,P,N]
|
304 |
+
distance_weighted_vectors = torch.sum(
|
305 |
+
vector_map.view(B, 1, N, D) * inv_distances[...,None], dim=2, keepdim=False) # [B,P,D]
|
306 |
+
return distance_weighted_vectors
|
307 |
+
|
308 |
+
def masks_from_phenotypes(plateau, phenotypes, normalize=True):
|
309 |
+
|
310 |
+
B,H,W,Q = plateau.shape
|
311 |
+
N = H*W
|
312 |
+
masks = dot_product_attention(
|
313 |
+
queries=plateau.view(B,N,Q),
|
314 |
+
keys=phenotypes,
|
315 |
+
normalize=normalize)
|
316 |
+
masks = F.relu(masks)
|
317 |
+
return masks
|
318 |
+
|
319 |
+
class Competition(nn.Module):
|
320 |
+
|
321 |
+
def __init__(
|
322 |
+
self,
|
323 |
+
size=None,
|
324 |
+
num_masks=16,
|
325 |
+
num_competition_rounds=5,
|
326 |
+
mask_beta=10.0,
|
327 |
+
reduce_func=reduce_max,
|
328 |
+
stop_gradient=True,
|
329 |
+
stop_gradient_phenotypes=True,
|
330 |
+
normalization_func=l2_normalize,
|
331 |
+
sum_edges=True,
|
332 |
+
mask_thresh=0.5,
|
333 |
+
compete_thresh=0.2,
|
334 |
+
sticky_winners=True,
|
335 |
+
selection_strength=100.0,
|
336 |
+
homing_strength=10.0,
|
337 |
+
mask_dead_segments=True
|
338 |
+
):
|
339 |
+
super().__init__()
|
340 |
+
self.num_masks = self.M = num_masks
|
341 |
+
self.num_competition_rounds = num_competition_rounds
|
342 |
+
self.mask_beta = mask_beta
|
343 |
+
self.reduce_func = reduce_func
|
344 |
+
self.normalization_func = normalization_func
|
345 |
+
|
346 |
+
## stop gradients
|
347 |
+
self.sg_func = lambda x: (x.detach() if stop_gradient else x)
|
348 |
+
self.sg_phenotypes_func = lambda x: (x.detach() if stop_gradient_phenotypes else x)
|
349 |
+
|
350 |
+
## agent sampling kwargs
|
351 |
+
self.sum_edges = sum_edges
|
352 |
+
|
353 |
+
## competition kwargs
|
354 |
+
self.mask_thresh = mask_thresh
|
355 |
+
self.compete_thresh = compete_thresh
|
356 |
+
self.sticky_winners = sticky_winners
|
357 |
+
self.selection_strength = selection_strength
|
358 |
+
self.homing_strength = homing_strength
|
359 |
+
self.mask_dead_segments = mask_dead_segments
|
360 |
+
|
361 |
+
## shapes
|
362 |
+
self.B = self.T = self.BT = self.N = self.Q = None
|
363 |
+
self.size = size # [H,W]
|
364 |
+
if self.size:
|
365 |
+
assert len(self.size) == 2, self.size
|
366 |
+
|
367 |
+
def reshape_batch_time(self, x, merge=True):
|
368 |
+
|
369 |
+
if merge:
|
370 |
+
self.is_temporal = True
|
371 |
+
B, T = x.size()[0:2]
|
372 |
+
if self.B:
|
373 |
+
assert (B == self.B), (B, self.B)
|
374 |
+
else:
|
375 |
+
self.B = B
|
376 |
+
|
377 |
+
if self.T:
|
378 |
+
assert (T == self.T), (T, self.T)
|
379 |
+
else:
|
380 |
+
self.T = T
|
381 |
+
|
382 |
+
assert B*T == (self.B * self.T), (B*T, self.B*self.T)
|
383 |
+
if self.BT is None:
|
384 |
+
self.BT = self.B * self.T
|
385 |
+
|
386 |
+
return torch.reshape(x, [self.BT] + list(x.size())[2:])
|
387 |
+
|
388 |
+
else: # split
|
389 |
+
BT = x.size()[0]
|
390 |
+
assert self.B and self.T, (self.B, self.T)
|
391 |
+
if self.BT is not None:
|
392 |
+
assert BT == self.BT, (BT, self.BT)
|
393 |
+
else:
|
394 |
+
self.BT = BT
|
395 |
+
|
396 |
+
return torch.reshape(x, [self.B, self.T] + list(x.size())[1:])
|
397 |
+
|
398 |
+
def process_plateau_input(self, plateau):
|
399 |
+
|
400 |
+
shape = plateau.size()
|
401 |
+
if len(shape) == 5:
|
402 |
+
self.is_temporal = True
|
403 |
+
self.B, self.T, self.H, self.W, self.Q = shape
|
404 |
+
self.N = self.H * self.W
|
405 |
+
self.BT = self.B * self.T
|
406 |
+
plateau = self.reshape_batch_time(plateau)
|
407 |
+
elif (len(shape) == 4) and (self.size is None):
|
408 |
+
self.is_temporal = False
|
409 |
+
self.B, self.H, self.W, self.Q = shape
|
410 |
+
self.N = self.H * self.W
|
411 |
+
self.T = 1
|
412 |
+
self.BT = self.B*self.T
|
413 |
+
elif (len(shape) == 4) and (self.size is not None):
|
414 |
+
self.is_temporal = True
|
415 |
+
self.B, self.T, self.N, self.Q = shape
|
416 |
+
self.BT = self.B * self.T
|
417 |
+
self.H, self.W = self.size
|
418 |
+
plateau = self.reshape_batch_time(plateau)
|
419 |
+
plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
|
420 |
+
elif len(shape) == 3:
|
421 |
+
assert self.size is not None, \
|
422 |
+
"You need to specify an image size to reshape the plateau of shape %s" % shape
|
423 |
+
self.is_temporal = False
|
424 |
+
self.B, self.N, self.Q = shape
|
425 |
+
self.T = 1
|
426 |
+
self.BT = self.B
|
427 |
+
self.H, self.W = self.size
|
428 |
+
plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
|
429 |
+
else:
|
430 |
+
raise ValueError("input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % shape)
|
431 |
+
|
432 |
+
return plateau
|
433 |
+
|
434 |
+
def forward(self,
|
435 |
+
plateau,
|
436 |
+
agents=None,
|
437 |
+
alive=None,
|
438 |
+
phenotypes=None,
|
439 |
+
compete=True,
|
440 |
+
update_pointers=True,
|
441 |
+
yoke_phenotypes_to_agents=True,
|
442 |
+
noise=0.1
|
443 |
+
):
|
444 |
+
"""
|
445 |
+
Find the uniform regions within the plateau map
|
446 |
+
by competition between visual "indices."
|
447 |
+
|
448 |
+
args:
|
449 |
+
plateau: [B,[T],H,W,Q] feature map with smooth "plateaus"
|
450 |
+
|
451 |
+
returns:
|
452 |
+
masks: [B, [T], H, W, M] <float> one mask in each of M channels
|
453 |
+
agents: [B, [T], M, 2] <float> positions of agents in normalized coordinates
|
454 |
+
alive: [B, [T], M] <float> binary vector indicating which masks are valid
|
455 |
+
phenotypes: [B, [T], M, Q]
|
456 |
+
unharvested: [B, [T], H, W] <float> map of regions that weren't covered
|
457 |
+
|
458 |
+
"""
|
459 |
+
|
460 |
+
## preprocess
|
461 |
+
plateau = self.process_plateau_input(plateau) # [BT,H,W,Q]
|
462 |
+
plateau = self.normalization_func(plateau)
|
463 |
+
|
464 |
+
## sample initial indices ("agents") from borders of the plateau map
|
465 |
+
if agents is None:
|
466 |
+
agents = sample_coordinates_at_borders(
|
467 |
+
plateau.permute(0,3,1,2),
|
468 |
+
num_points=self.M,
|
469 |
+
mask=None,
|
470 |
+
sum_edges=self.sum_edges)
|
471 |
+
else:
|
472 |
+
if self.is_temporal:
|
473 |
+
agents = agents.view(self.BT, *agents.shape[2:])
|
474 |
+
|
475 |
+
## the agents have "phenotypes" depending on where they're situated on the plateau map
|
476 |
+
if phenotypes is None:
|
477 |
+
phenotypes = self.sg_phenotypes_func(
|
478 |
+
self.normalization_func(
|
479 |
+
soft_index(plateau.permute(0,3,1,2),
|
480 |
+
agents, scale_by_imsize=True)))
|
481 |
+
elif self.is_temporal:
|
482 |
+
phenotypes = phenotypes.view(self.BT, *phenotypes.shape[2:])
|
483 |
+
|
484 |
+
## the "fitness" of an agent -- how likely it is to survive competition --
|
485 |
+
## is how well its phenotype matches the plateau vector at its current position
|
486 |
+
## initially all of these agents are "alive"
|
487 |
+
if alive is None:
|
488 |
+
alive = torch.ones_like(agents[...,-1:]) # [BT,M,1]
|
489 |
+
fitnesses = compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise)
|
490 |
+
alive_mask = None
|
491 |
+
else:
|
492 |
+
if self.is_temporal:
|
493 |
+
alive = alive.view(self.BT, *alive.shape[2:])
|
494 |
+
alive_mask = (alive > 0.5).float()
|
495 |
+
fitnesses = alive_mask + compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise) * (1 - alive_mask)
|
496 |
+
|
497 |
+
alive_t = torch.transpose(alive, 1, 2) # [BT, 1, M]
|
498 |
+
|
499 |
+
## compute the masks at initialization
|
500 |
+
masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)
|
501 |
+
|
502 |
+
## find the "unharvested" regions of the plateau map not covered by agents
|
503 |
+
unharvested = torch.minimum(self.reduce_func(masks_pred, dim=-1, keepdim=True), torch.tensor(1.0))
|
504 |
+
unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)
|
505 |
+
|
506 |
+
if alive_mask is not None:
|
507 |
+
new_agents = sample_coordinates_at_borders(
|
508 |
+
plateau.permute(0,3,1,2), num_points=self.M,
|
509 |
+
mask=unharvested.permute(0,3,1,2),
|
510 |
+
sum_edges=self.sum_edges)
|
511 |
+
agents = agents * alive_mask + new_agents * (1.0 - alive_mask)
|
512 |
+
|
513 |
+
new_phenotypes = self.sg_phenotypes_func(
|
514 |
+
self.normalization_func(
|
515 |
+
soft_index(plateau.permute(0,3,1,2),
|
516 |
+
new_agents, scale_by_imsize=True)))
|
517 |
+
phenotypes = phenotypes * alive_mask + new_phenotypes * (1.0 - alive_mask)
|
518 |
+
|
519 |
+
for r in range(self.num_competition_rounds):
|
520 |
+
# print("Evolution round {}".format(r+1))
|
521 |
+
|
522 |
+
## compute the "availability" of the plateau map for each agent (i.e. where it can harvest from)
|
523 |
+
alive_t = torch.transpose(alive, 1, 2) # [BT, 1, M]
|
524 |
+
# availability = alive_t * masks_pred + (1.0 - alive_t) * unharvested.view(self.BT, self.N, 1)
|
525 |
+
# availability = availability.view(self.BT, self.H, self.W, self.M)
|
526 |
+
|
527 |
+
## update the fitnesses
|
528 |
+
if update_pointers and compete:
|
529 |
+
fitnesses = compute_compatibility(
|
530 |
+
positions=agents,
|
531 |
+
plateau=plateau,
|
532 |
+
phenotypes=phenotypes,
|
533 |
+
# availability=availability)
|
534 |
+
availability=None,
|
535 |
+
noise=noise
|
536 |
+
)
|
537 |
+
|
538 |
+
|
539 |
+
## kill agents that have wandered off the map
|
540 |
+
in_bounds = torch.all(
|
541 |
+
torch.logical_and(agents < 1.0, agents > -1.0),
|
542 |
+
dim=-1, keepdim=True) # [BT,M,1]
|
543 |
+
fitnesses *= in_bounds.to(fitnesses)
|
544 |
+
|
545 |
+
## break ties in fitness
|
546 |
+
fitnesses -= 0.001 * torch.arange(self.M, dtype=torch.float32)[None,:,None].expand(self.BT,-1,-1).to(fitnesses.device)
|
547 |
+
|
548 |
+
## recompute the masks (why?)
|
549 |
+
if yoke_phenotypes_to_agents:
|
550 |
+
occupied_regions = self.sg_phenotypes_func(
|
551 |
+
soft_index(plateau.permute(0,3,1,2), agents, scale_by_imsize=True))
|
552 |
+
masks_pred = masks_from_phenotypes(plateau, occupied_regions, normalize=True) # [BT,N,M]
|
553 |
+
|
554 |
+
## have each pair of agents compete.
|
555 |
+
## If their masks overlap, the winner is the one with higher fitness
|
556 |
+
if compete:
|
557 |
+
alive = compete_agents(masks_pred, fitnesses, alive,
|
558 |
+
mask_thresh=self.mask_thresh,
|
559 |
+
compete_thresh=self.compete_thresh,
|
560 |
+
sticky_winners=self.sticky_winners)
|
561 |
+
|
562 |
+
alive *= in_bounds.to(alive)
|
563 |
+
alive_t = torch.transpose(alive, 1, 2)
|
564 |
+
|
565 |
+
# print("Num alive masks", alive.sum(), "which ones --> ", np.where(alive[0,:,0].detach().cpu().numpy()))
|
566 |
+
if not yoke_phenotypes_to_agents:
|
567 |
+
masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)
|
568 |
+
|
569 |
+
## update which parts of the plateau are "unharvested"
|
570 |
+
unharvested = torch.minimum(self.reduce_func(masks_pred * alive_t, dim=-1, keepdim=True),
|
571 |
+
torch.tensor(1.0, dtype=torch.float32))
|
572 |
+
unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)
|
573 |
+
|
574 |
+
|
575 |
+
## update phenotypes of the winners
|
576 |
+
if update_pointers:
|
577 |
+
if self.mask_thresh is not None:
|
578 |
+
winner_phenotypes = (masks_pred[...,None] > self.mask_thresh).to(plateau)
|
579 |
+
if self.selection_strength > 0:
|
580 |
+
winner_phenotypes = winner_phenotypes * plateau.view(self.BT, self.N, 1, self.Q)
|
581 |
+
winner_phenotypes = self.normalization_func(winner_phenotypes.mean(dim=1)) # [BT,M,Q]
|
582 |
+
phenotypes += (alive * winner_phenotypes) * self.selection_strength
|
583 |
+
|
584 |
+
## reinitialize losing agent positions
|
585 |
+
alive_mask = (alive > 0.5).to(torch.float32)
|
586 |
+
loser_agents = sample_coordinates_at_borders(
|
587 |
+
plateau.permute(0,3,1,2), num_points=self.M,
|
588 |
+
mask=unharvested.permute(0,3,1,2),
|
589 |
+
sum_edges=self.sum_edges)
|
590 |
+
agents = agents * alive_mask + loser_agents * (1.0 - alive_mask)
|
591 |
+
|
592 |
+
|
593 |
+
## reinitialize loser agent phenotypes
|
594 |
+
loser_phenotypes = self.normalization_func(
|
595 |
+
compute_distance_weighted_vectors(plateau, agents, mask=unharvested, beta=self.homing_strength))
|
596 |
+
phenotypes = alive_mask * phenotypes + (1.0 - alive_mask) * loser_phenotypes
|
597 |
+
phenotypes = self.normalization_func(phenotypes)
|
598 |
+
|
599 |
+
## that's it for this round!
|
600 |
+
# print("round %d" % r, alive.shape, torch.where(alive[0,:,0]))
|
601 |
+
|
602 |
+
## run a final competition between the surviving masks
|
603 |
+
if self.mask_beta is not None:
|
604 |
+
masks_pred = F.softmax(
|
605 |
+
self.mask_beta * masks_pred * alive_t - \
|
606 |
+
self.mask_beta * (1.0 - alive_t), dim=-1)
|
607 |
+
if self.mask_dead_segments:
|
608 |
+
masks_pred *= alive_t
|
609 |
+
|
610 |
+
masks_pred = masks_pred.view(self.BT,self.H,self.W,self.M)
|
611 |
+
if self.is_temporal:
|
612 |
+
masks_pred = self.reshape_batch_time(masks_pred, merge=False)
|
613 |
+
agents = self.reshape_batch_time(agents, merge=False)
|
614 |
+
alive = self.reshape_batch_time(alive, merge=False)
|
615 |
+
phenotypes = self.reshape_batch_time(phenotypes, merge=False)
|
616 |
+
unharvested = self.reshape_batch_time(unharvested, merge=False)
|
617 |
+
|
618 |
+
return (masks_pred, agents, alive, phenotypes, unharvested)
|
619 |
+
|
620 |
+
@staticmethod
|
621 |
+
def masks_to_segments(masks):
|
622 |
+
return masks.argmax(-1)
|
623 |
+
|
624 |
+
@staticmethod
|
625 |
+
def flatten_plateau_with_masks(plateau, masks, alive, flatten_masks=True):
|
626 |
+
B,M,_ = alive.shape
|
627 |
+
Q = plateau.shape[-1]
|
628 |
+
if flatten_masks:
|
629 |
+
masks = F.one_hot((alive[...,None,None,:,0] * masks).argmax(-1), num_classes=M).float()
|
630 |
+
|
631 |
+
flat_plateau = torch.zeros_like(plateau)
|
632 |
+
phenotypes = torch.zeros((B,M,Q), device=plateau.device).float()
|
633 |
+
for b in range(B):
|
634 |
+
m_inds = torch.where(alive[b,:,0])[0]
|
635 |
+
masks_b = masks[b,...,m_inds]
|
636 |
+
num_px = masks_b.sum((0,1)).clamp(min=1)[:,None] # [K,1]
|
637 |
+
phenos_b = torch.einsum('hwk,hwq->kq', masks_b, plateau[b]) / num_px # [K,Q]
|
638 |
+
flat_plateau_b = (masks_b[...,None] * phenos_b[None,None]).sum(-2) # [H,W,Q]
|
639 |
+
|
640 |
+
phenotypes[b,m_inds,:] = phenos_b
|
641 |
+
flat_plateau[b] = flat_plateau_b
|
642 |
+
|
643 |
+
_norm = lambda x: F.normalize(x, p=2, dim=-1)
|
644 |
+
return (_norm(flat_plateau), _norm(phenotypes))
|
645 |
+
|
646 |
+
@staticmethod
|
647 |
+
def plot_agents(agents, alive, size=[128,128]):
|
648 |
+
B,M,_ = alive.shape
|
649 |
+
agent_map = -1 * torch.ones((B,*size), device=alive.device, dtype=torch.long)
|
650 |
+
for b in range(B):
|
651 |
+
inds = torch.where(alive[b,:,0])
|
652 |
+
for i in inds[0]:
|
653 |
+
pos = agents[b,i]*0.5 + 0.5
|
654 |
+
pos = pos * torch.tensor(size, device=pos.device)
|
655 |
+
hmin, wmin = list(torch.floor(pos).long())
|
656 |
+
hmax, wmax = list(torch.ceil(pos).long())
|
657 |
+
agent_map[b,[hmin,hmin,hmax,hmax],[wmin,wmax,wmin,wmax]] = i
|
658 |
+
|
659 |
+
return agent_map
|
660 |
+
|
661 |
+
if __name__ == '__main__':
|
662 |
+
|
663 |
+
Comp = Competition(num_masks=32, num_competition_rounds=5)
|
664 |
+
|
665 |
+
left = torch.ones(size=(32,8)).unsqueeze(-1) * torch.tensor([1.,0.2,0.])
|
666 |
+
middle = torch.ones(size=(32,16)).unsqueeze(-1) * torch.tensor([0.,1.,0.2])
|
667 |
+
right = torch.ones(size=(32,8)).unsqueeze(-1) * torch.tensor([0.1,0.,1.])
|
668 |
+
plateau = torch.cat([left, middle, right], dim=-2).unsqueeze(0)
|
669 |
+
masks, agents, alive, phenotypes, unharvested = Comp(plateau)
|
670 |
+
mask_inds = np.where(alive[0,:,0].numpy())[0]
|
671 |
+
print(np.argmax(masks[0,...], axis=-1))
|
672 |
+
for ind in mask_inds:
|
673 |
+
print("num pixels in mask %d ---> %d" % (ind, (np.argmax(masks[0], -1) == ind).sum()))
|
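Competition's __main__ block above already exercises the module on a toy three-region plateau; the same call pattern applies to plateau maps produced by a model. A short usage sketch with a random plateau, purely to show the expected shapes, assuming the class is importable from cwm.eval.Segmentation.archive.competition:

import torch
import torch.nn.functional as F

comp = Competition(num_masks=16, num_competition_rounds=5)
plateau = F.normalize(torch.randn(1, 32, 32, 8), dim=-1)        # [B, H, W, Q]
masks, agents, alive, phenotypes, unharvested = comp(plateau)   # masks: [B, H, W, M]
segments = Competition.masks_to_segments(masks)                 # [B, H, W] integer ids
agent_map = Competition.plot_agents(agents, alive, size=[32, 32])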
cwm/eval/Segmentation/archive/configs/__init__.py
ADDED
File without changes
|
cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py
ADDED
@@ -0,0 +1,56 @@
from functools import partial
from fvcore.common.param_scheduler import MultiStepParamScheduler

from detectron2 import model_zoo
from detectron2.config import LazyCall as L
from detectron2.config import CfgNode, LazyConfig
from detectron2.solver import WarmupParamScheduler
from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate

from ..common.coco_loader_lsj import dataloader


# model = model_zoo.get_config("./models/mask_rcnn_cwm.py").model

cfg_file = "./models/mask_rcnn_cwm.py"
model = LazyConfig.load(cfg_file).model

# url = get_checkpoint_url(config_path)
# if "train" in cfg and "init_checkpoint" in cfg.train:
#     cfg.train.init_checkpoint = url
# else:
#     raise NotImplementedError

# Initialization and trainer settings
train = model_zoo.get_config("common/train.py").train
train.amp.enabled = True
train.ddp.fp16_compression = True
train.init_checkpoint = (
    #"/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping/checkpoint-799-encoder.pth"
    './output/model_0004999.pth'
)
train.eval_period = 1e9

# Schedule
# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
# train.max_iter = 184375
# milestones = [163889, 177546]

# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep
train.max_iter = 61458
milestones = [54629, 59182]

lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01],
        milestones=milestones,
        num_updates=train.max_iter,
    ),
    warmup_length=250 / train.max_iter,
    warmup_factor=0.001,
)

# Optimizer
optimizer = model_zoo.get_config("common/optim.py").AdamW
optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
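Configs in this folder are detectron2 lazy configs: plain Python files whose top-level names (model, train, dataloader, lr_multiplier, optimizer) are read with LazyConfig and typically launched through detectron2's tools/lazyconfig_train_net.py. A minimal sketch of loading this file directly; whether the relative model and dataloader imports resolve depends on how the repo is installed, so treat the paths as assumptions:

from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("mask_rcnn_cwm_vitdet_b_100ep.py")
model = instantiate(cfg.model)          # Mask R-CNN with the CWM ViT backbone
print(cfg.train.max_iter, cfg.train.init_checkpoint)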
cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py
ADDED
@@ -0,0 +1,59 @@
from functools import partial
from fvcore.common.param_scheduler import MultiStepParamScheduler

from detectron2 import model_zoo
from detectron2.config import LazyCall as L
from detectron2.solver import WarmupParamScheduler
from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
import os
from ..common.coco_loader_lsj import dataloader
from detectron2.data.datasets import register_coco_instances
from detectron2.config import CfgNode, LazyConfig
# model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
# model.backbone.square_pad = 512  # change input size to 512x512

cfg_file = "./models/mask_rcnn_vitdet_v2.py"
model = LazyConfig.load(cfg_file).model
model.backbone.square_pad = 512  # change input size to 512x512

# Initialization and trainer settings
train = model_zoo.get_config("common/train.py").train
train.amp.enabled = True
train.ddp.fp16_compression = True
train.init_checkpoint = (
    #"detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True"
    #"/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_16x16_no_clumping_mr0.98/checkpoint-799-encoder.pth"
    "/ccn2/u/honglinc/cwm_checkpoints/mae_vitb/mae_pretrain_vit_base-encoder.pth"
)
train.output_dir = os.path.dirname(train.init_checkpoint) + "/coco_finetune_512_v3"

root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
register_coco_instances("cls_agnostic_coco", {},
    os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"),
    os.path.join(root, "coco/val2017")
)
dataloader.test.dataset.names = 'cls_agnostic_coco'
# Schedule
# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
# train.max_iter = 184375
# milestones = [163889, 177546]

# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep
train.max_iter = 61458
milestones = [54629, 59182]

lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1, 0.01],
        milestones=milestones,
        num_updates=train.max_iter,
    ),
    warmup_length=250 / train.max_iter,
    warmup_factor=0.001,
)

# Optimizer
optimizer = model_zoo.get_config("common/optim.py").AdamW
optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py
ADDED
@@ -0,0 +1,56 @@
1 |
+
from functools import partial
|
2 |
+
from fvcore.common.param_scheduler import MultiStepParamScheduler
|
3 |
+
|
4 |
+
from detectron2 import model_zoo
|
5 |
+
from detectron2.config import LazyCall as L
|
6 |
+
from detectron2.config import CfgNode, LazyConfig
|
7 |
+
from detectron2.solver import WarmupParamScheduler
|
8 |
+
|
9 |
+
from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
|
10 |
+
import os
|
11 |
+
from ..common.coco_loader_lsj import dataloader
|
12 |
+
|
13 |
+
# model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
|
14 |
+
# model.backbone.square_pad = 512 # change input size to 512x512
|
15 |
+
|
16 |
+
cfg_file = "./models/mask_rcnn_cwm.py"
|
17 |
+
model = LazyConfig.load(cfg_file).model
|
18 |
+
|
19 |
+
# Initialization and trainer settings
|
20 |
+
train = model_zoo.get_config("common/train.py").train
|
21 |
+
train.amp.enabled = True
|
22 |
+
train.ddp.fp16_compression = True
|
23 |
+
train.init_checkpoint = (
|
24 |
+
'/home/honglinc/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth'
|
25 |
+
)
|
26 |
+
train.output_dir = '/ccn2/u/honglinc/cwm_checkpoints/dinov2_coco_finetune_512'
|
27 |
+
|
28 |
+
# model.backbone.net.window_size = 0
|
29 |
+
# model.backbone.net.window_block_indexes = []
|
30 |
+
# model.backbone.net.use_rel_pos = False
|
31 |
+
# model.backbone.net.drop_path_rate = 0.
|
32 |
+
|
33 |
+
# Schedule
|
34 |
+
# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
|
35 |
+
# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
|
36 |
+
# train.max_iter = 184375
|
37 |
+
# milestones = [163889, 177546]
|
38 |
+
|
39 |
+
# 50 ep = 30730 iters * 96 images/iter / 118000 images/ep
|
40 |
+
train.max_iter = 61458
|
41 |
+
milestones = [54629, 59182]
|
42 |
+
|
43 |
+
lr_multiplier = L(WarmupParamScheduler)(
|
44 |
+
scheduler=L(MultiStepParamScheduler)(
|
45 |
+
values=[1.0, 0.1, 0.01],
|
46 |
+
milestones=milestones,
|
47 |
+
num_updates=train.max_iter,
|
48 |
+
),
|
49 |
+
warmup_length=250 / train.max_iter,
|
50 |
+
warmup_factor=0.001,
|
51 |
+
)
|
52 |
+
|
53 |
+
# Optimizer
|
54 |
+
optimizer = model_zoo.get_config("common/optim.py").AdamW
|
55 |
+
optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
|
56 |
+
optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
|
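
All three configs set optimizer.params.lr_factor_func with get_vit_lr_decay_rate(num_layers=12, lr_decay_rate=0.7), which gives earlier ViT-B blocks geometrically smaller learning rates than later blocks while non-backbone parameters keep the full rate. A simplified stand-in (not detectron2's implementation) to show the effect; the parameter names below are made up for illustration:

def layerwise_lr_factor(param_name, num_layers=12, lr_decay_rate=0.7):
    # Non-backbone parameters (RPN, ROI heads, ...) keep the full learning rate.
    layer_id = num_layers + 1
    if "pos_embed" in param_name or "patch_embed" in param_name:
        layer_id = 0
    elif ".blocks." in param_name:
        # Block index k maps to layer_id k + 1, so deeper blocks decay less.
        layer_id = int(param_name.split(".blocks.")[1].split(".")[0]) + 1
    return lr_decay_rate ** (num_layers + 1 - layer_id)

print(layerwise_lr_factor("backbone.net.blocks.0.attn.qkv.weight"))  # ~0.014, strongest decay
print(layerwise_lr_factor("backbone.net.blocks.11.mlp.fc1.weight"))  # 0.7
print(layerwise_lr_factor("roi_heads.box_head.fc1.weight"))          # 1.0
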
cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py
ADDED
@@ -0,0 +1,58 @@
+from functools import partial
+from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+from detectron2 import model_zoo
+from detectron2.config import LazyCall as L
+from detectron2.config import CfgNode, LazyConfig
+from detectron2.solver import WarmupParamScheduler
+
+from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
+import os
+from ..common.coco_loader_lsj import dataloader
+from detectron2.data.datasets import register_coco_instances
+# model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
+# model.backbone.square_pad = 512  # change input size to 512x512
+
+cfg_file = "./models/mask_rcnn_cwm.py"
+model = LazyConfig.load(cfg_file).model
+
+# Initialization and trainer settings
+train = model_zoo.get_config("common/train.py").train
+train.amp.enabled = True
+train.ddp.fp16_compression = True
+train.init_checkpoint = (
+    # "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True"
+    '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_16x16_no_clumping_mr0.90/checkpoint-799-encoder.pth'
+)
+train.output_dir = os.path.dirname(train.init_checkpoint) + "/coco_finetune_512_v3"
+train.eval_period = 1e9
+
+root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
+register_coco_instances("cls_agnostic_coco", {},
+                        os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"),
+                        os.path.join(root, "coco/val2017"),
+                        )
+dataloader.test.dataset.names = 'cls_agnostic_coco'
+# Schedule
+# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
+# train.max_iter = 184375
+# milestones = [163889, 177546]
+
+# 50 ep = 61458 iters * 96 images/iter / 118000 images/ep
+train.max_iter = 61458
+milestones = [54629, 59182]
+
+lr_multiplier = L(WarmupParamScheduler)(
+    scheduler=L(MultiStepParamScheduler)(
+        values=[1.0, 0.1, 0.01],
+        milestones=milestones,
+        num_updates=train.max_iter,
+    ),
+    warmup_length=250 / train.max_iter,
+    warmup_factor=0.001,
+)
+
+# Optimizer
+optimizer = model_zoo.get_config("common/optim.py").AdamW
+optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
+optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
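
Both class-agnostic evaluation configs expect coco/annotations/coco_cls_agnostic_instances_val2017.json under $DETECTRON2_DATASETS, which is not a standard COCO download. One plausible way to produce it, assuming the intent is simply to collapse every val2017 annotation onto a single category; this is a sketch of an assumption, not the authors' script:

# Sketch (assumed): build a class-agnostic copy of the COCO val2017 annotations
# by mapping every annotation to a single "object" category.
import json

src = "datasets/coco/annotations/instances_val2017.json"
dst = "datasets/coco/annotations/coco_cls_agnostic_instances_val2017.json"

with open(src) as f:
    coco = json.load(f)

coco["categories"] = [{"id": 1, "name": "object", "supercategory": "object"}]
for ann in coco["annotations"]:
    ann["category_id"] = 1

with open(dst, "w") as f:
    json.dump(coco, f)

With the file in place, register_coco_instances("cls_agnostic_coco", ...) in the configs above points the test dataloader at it for class-agnostic COCO evaluation.
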