rahulvenkk committed
Commit
6dfcb0f
1 Parent(s): 0ba04fa

app.py updated

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +3 -0
  3. README.md +47 -14
  4. assets/color_wheel.png +0 -0
  5. assets/cwm_teaser.gif +3 -0
  6. assets/desk_1.jpg +3 -0
  7. assets/flow_test_videos/libby.mp4 +3 -0
  8. assets/flow_test_videos/weight_lifter.mp4 +3 -0
  9. cwm/__init__.py +0 -0
  10. cwm/data/__init__.py +0 -0
  11. cwm/data/dataset.py +453 -0
  12. cwm/data/dataset_utils.py +73 -0
  13. cwm/data/masking_generator.py +86 -0
  14. cwm/data/transforms.py +206 -0
  15. cwm/data/video_file_lists/kinetics_400_train_list.txt +3 -0
  16. cwm/data/video_file_lists/kinetics_400_train_list_sing.txt +3 -0
  17. cwm/engine_for_pretraining.py +92 -0
  18. cwm/eval/Action_recognition/__init__.py +0 -0
  19. cwm/eval/Flow/__init__.py +0 -0
  20. cwm/eval/Flow/create_spring_submission_parallel.sh +36 -0
  21. cwm/eval/Flow/create_spring_submission_unified.py +111 -0
  22. cwm/eval/Flow/flow_extraction_classes.py +122 -0
  23. cwm/eval/Flow/flow_utils.py +569 -0
  24. cwm/eval/Flow/flow_utils_legacy.py +152 -0
  25. cwm/eval/Flow/generator.py +579 -0
  26. cwm/eval/Flow/losses.py +60 -0
  27. cwm/eval/Flow/masking_flow.py +375 -0
  28. cwm/eval/Flow/vis_utils.py +150 -0
  29. cwm/eval/IntPhys/__init__.py +0 -0
  30. cwm/eval/Physion/__init__.py +0 -0
  31. cwm/eval/Physion/feature_extractor.py +317 -0
  32. cwm/eval/Physion/flow_utils.py +279 -0
  33. cwm/eval/Physion/run_eval.sh +17 -0
  34. cwm/eval/Physion/run_eval_kfflow.sh +18 -0
  35. cwm/eval/Physion/run_eval_mp4s.sh +19 -0
  36. cwm/eval/Physion/run_eval_mp4s_keyp.sh +17 -0
  37. cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh +17 -0
  38. cwm/eval/Segmentation/__init__.py +0 -0
  39. cwm/eval/Segmentation/archive/__init__.py +0 -0
  40. cwm/eval/Segmentation/archive/common/__init__.py +0 -0
  41. cwm/eval/Segmentation/archive/common/coco_loader_lsj.py +222 -0
  42. cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py +19 -0
  43. cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py +54 -0
  44. cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py +26 -0
  45. cwm/eval/Segmentation/archive/competition.py +673 -0
  46. cwm/eval/Segmentation/archive/configs/__init__.py +0 -0
  47. cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py +56 -0
  48. cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py +59 -0
  49. cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py +56 -0
  50. cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py +59 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.txt filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .idea
+
+ .pth
README.md CHANGED
@@ -1,14 +1,47 @@
- ---
- title: Counterfactual World Models
- emoji: 📊
- colorFrom: purple
- colorTo: yellow
- sdk: gradio
- sdk_version: 4.44.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: Vision foundation model that unifies vision structures
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <h2>Understanding Physical Dynamics with Counterfactual World Modeling</h2>
+
+ [**Rahul Venkatesh***](https://rahulvenkk.github.io/)<sup>1</sup> · [**Honglin Chen***](https://web.stanford.edu/~honglinc/)<sup>1</sup> · [**Kevin Feigelis***](https://neuroscience.stanford.edu/people/kevin-t-feigelis)<sup>1</sup> · [**Daniel M. Bear**](https://twitter.com/recursus?lang=en)<sup>1</sup> · [**Khaled Jedoui**](https://web.stanford.edu/~thekej/)<sup>1</sup> · [**Klemen Kotar**](https://klemenkotar.github.io/)<sup>1</sup> · [**Felix Binder**](https://ac.felixbinder.net/)<sup>2</sup> · [**Wanhee Lee**](https://www.linkedin.com/in/wanhee-lee-31102820b/)<sup>1</sup> · [**Sherry Liu**](https://neuroailab.github.io/cwm-physics/)<sup>1</sup> · [**Kevin A. Smith**](https://www.mit.edu/~k2smith/)<sup>3</sup> · [**Judith E. Fan**](https://cogtoolslab.github.io/)<sup>1</sup> · [**Daniel L. K. Yamins**](https://stanford.edu/~yamins/)<sup>1</sup>
+
+ (* equal contribution)
+
+ <sup>1</sup>Stanford&emsp;&emsp;&emsp;&emsp;<sup>2</sup>UCSD&emsp;&emsp;&emsp;&emsp;<sup>3</sup>MIT
+
+ <a href="https://arxiv.org/abs/2312.06721"><img src='https://img.shields.io/badge/arXiv-CWM-red' alt='Paper PDF'></a>
+ <a href='https://neuroailab.github.io/cwm-physics/'><img src='https://img.shields.io/badge/Project_Page-CWM-green' alt='Project Page'></a>
+ <a href='https://neuroailab.github.io/cwm-physics/'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
+ <a href='https://neuroailab.github.io/cwm-physics/'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Colab-yellow'></a>
+ </div>
+
+ This work presents the Counterfactual World Modeling (CWM) framework. CWM supports counterfactual prediction and the extraction of vision structures useful for understanding physical dynamics.
+
+ ![](assets/cwm_teaser.gif)
+
+ ## 📣 News
+
+ - 2024-06-01: Released the [project page](https://neuroailab.github.io) and [code](https://github.com/rahulvenkk/cwm_release.git)
+
+ ## 🔨 Installation
+
+ ```
+ git clone https://github.com/rahulvenkk/cwm_release.git
+ pip install -e .
+ ```
+
+ ## ✨ Usage
+ To download and use a pre-trained model, run the following:
+ ```
+ from cwm.model.model_factory import model_factory
+ model = model_factory.load_model('vitbase_8x8patch_3frames_1tube')
+ ```
+ This will automatically initialize the appropriate model class and download the specified weights to your `$CACHE` directory.
+
+ ## 🔄 Pre-training
+ To train the model, run the following script:
+
+ ```
+ ./scripts/pretrain/3frame_patch8x8_mr0.90_gpu.sh
+ ```
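
A minimal smoke test of the usage above, pairing the loaded predictor with the `RotatedTableMaskingGenerator` added in this commit. The `[B, C, T, H, W]` clip layout, 224x224 input size, and 28x28 patch grid are assumptions inferred from the 3-frame, 8x8-patch configuration used elsewhere in this diff, not a documented API:

```python
import torch
from cwm.model.model_factory import model_factory
from cwm.data.masking_generator import RotatedTableMaskingGenerator

model = model_factory.load_model('vitbase_8x8patch_3frames_1tube').cuda().eval()

# Dummy 3-frame RGB clip in channels-first video layout: [B, C, T, H, W].
video = torch.randn(1, 3, 3, 224, 224).cuda()

# Mask ~90% of the last frame's patches; earlier frames stay fully visible.
mask_generator = RotatedTableMaskingGenerator(
    input_size=(3, 28, 28),  # (frames, H/patch, W/patch) for 224x224 inputs
    mask_ratio=0.9,
    tube_length=1,
    batch_size=1,
)
mask = mask_generator().bool().cuda()  # [1, 3*28*28] boolean patch mask

with torch.no_grad():
    pred = model(video, mask)  # predictions for the masked patches (assumed signature)
```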
assets/color_wheel.png ADDED
assets/cwm_teaser.gif ADDED

Git LFS Details

  • SHA256: 4fac6e545660c695f81f360a87f1060c44eea95f3ae7dbcf1fecbe2b097e3b6a
  • Pointer size: 133 Bytes
  • Size of remote file: 12.9 MB
assets/desk_1.jpg ADDED

Git LFS Details

  • SHA256: 84a3bfdd40841e8d291b4eb638ecc29b99f054ae6d3ea51b4cdc3090741987c8
  • Pointer size: 132 Bytes
  • Size of remote file: 4.73 MB
assets/flow_test_videos/libby.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d1887bc796d1883e8a63405398125476257572c0c8d7f1862bf309e422b4828
+ size 671950
assets/flow_test_videos/weight_lifter.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:525988b314936079f904236c614e2f987ba64fd5c69f4f12dd5b6b9076311854
+ size 1176790
cwm/__init__.py ADDED
File without changes
cwm/data/__init__.py ADDED
File without changes
cwm/data/dataset.py ADDED
@@ -0,0 +1,453 @@
+ import os
+ import decord
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+
+ class VideoMAE(torch.utils.data.Dataset):
+     """Load your own video classification dataset.
+     Parameters
+     ----------
+     root : str, required.
+         Path to the root folder storing the dataset.
+     setting : str, required.
+         A text file describing the dataset, one line per video sample.
+         There are three items in each line: (1) video path; (2) video length; and (3) video label.
+     train : bool, default True.
+         Whether to load the training or validation set.
+     test_mode : bool, default False.
+         Whether to perform evaluation on the test set.
+         Usually a three-crop or ten-crop evaluation strategy is involved.
+     name_pattern : str, default None.
+         The naming pattern of the decoded video frames.
+         For example, img_00012.jpg.
+     video_ext : str, default 'mp4'.
+         If video_loader is set to True, please specify the video format accordingly.
+     is_color : bool, default True.
+         Whether the loaded image is color or grayscale.
+     modality : str, default 'rgb'.
+         Input modalities; we support only rgb video frames for now.
+         Will add support for rgb difference images and optical flow images later.
+     num_segments : int, default 1.
+         Number of segments to evenly divide the video into clips.
+         A useful technique to obtain global video-level information.
+         Limin Wang et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
+     num_crop : int, default 1.
+         Number of crops for each image. Default is 1.
+         Common choices are three crops and ten crops during evaluation.
+     new_length : int, default 1.
+         The length of the input video clip. Default is a single image, but it can be multiple video frames.
+         For example, new_length=16 means we will extract a video clip of 16 consecutive frames.
+     new_step : int, default 1.
+         Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
+         new_step=2 means we will extract a video clip of every other frame.
+     temporal_jitter : bool, default False.
+         Whether to temporally jitter if new_step > 1.
+     video_loader : bool, default False.
+         Whether to use a video loader to load data.
+     use_decord : bool, default True.
+         Whether to use the Decord video loader to load data. Otherwise use the mmcv video loader.
+     transform : function, default None.
+         A function that takes data and label and transforms them.
+     data_aug : str, default 'v1'.
+         Type of data augmentation. Supports v1, v2, v3 and v4.
+     lazy_init : bool, default False.
+         If set to True, build a dataset instance without loading any dataset.
+     """
+
+     def __init__(self,
+                  root,
+                  setting,
+                  train=True,
+                  test_mode=False,
+                  name_pattern='img_%05d.jpg',
+                  video_ext='mp4',
+                  is_color=True,
+                  modality='rgb',
+                  num_segments=1,
+                  num_crop=1,
+                  new_length=1,
+                  new_step=1,
+                  randomize_interframes=False,
+                  transform=None,
+                  temporal_jitter=False,
+                  video_loader=False,
+                  use_decord=False,
+                  lazy_init=False,
+                  is_video_dataset=True):
+
+         super(VideoMAE, self).__init__()
+         self.root = root
+         self.setting = setting
+         self.train = train
+         self.test_mode = test_mode
+         self.is_color = is_color
+         self.modality = modality
+         self.num_segments = num_segments
+         self.num_crop = num_crop
+         self.new_length = new_length
+
+         self.randomize_interframes = randomize_interframes
+         self._new_step = new_step  # If randomize_interframes is True, this is the max; otherwise it's just the skip
+         self.temporal_jitter = temporal_jitter
+         self.name_pattern = name_pattern
+         self.video_loader = video_loader
+         self.video_ext = video_ext
+         self.use_decord = use_decord
+         self.transform = transform
+         self.lazy_init = lazy_init
+
+         if (not self.lazy_init) and is_video_dataset:
+             self.clips = self._make_dataset(root, setting)
+             if len(self.clips) == 0:
+                 raise RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
+                                    "Check your data directory (opt.data-dir).")
+
+     def __getitem__(self, index):
+
+         directory, target = self.clips[index]
+
+         if self.video_loader:
+             if '.' in directory.split('/')[-1]:
+                 # data in the "setting" file already has an extension, e.g., demo.mp4
+                 video_name = directory
+             else:
+                 # data in the "setting" file does not have an extension, e.g., demo
+                 # So we need to provide the extension (i.e., .mp4) to complete the file name.
+                 video_name = '{}.{}'.format(directory, self.video_ext)
+
+             try:
+                 decord_vr = decord.VideoReader(video_name, num_threads=1)
+             except Exception:
+                 return (self.__getitem__(index + 1))
+             duration = len(decord_vr)
+
+         segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration)
+
+         images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets,
+                                                      new_step, skip_length)
+
+         process_data, mask = self.transform((images, None))  # T*C,H,W
+         process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)  # T*C,H,W -> T,C,H,W -> C,T,H,W
+
+         return (process_data, mask)
+
+     def __len__(self):
+         return len(self.clips)
+
+     def _make_dataset(self, directory, setting):
+         if not os.path.exists(setting):
+             raise RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))
+         clips = []
+         with open(setting) as split_f:
+             data = split_f.readlines()
+             for line in data:
+                 line_info = line.split(' ')
+                 # line format: video_path, video_duration, video_label
+                 if len(line_info) < 2:
+                     raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
+                 elif len(line_info) > 2:
+                     line_info = (' '.join(line_info[:-1]), line_info[-1])  # filename has spaces
+                 clip_path = os.path.join(line_info[0])
+                 target = int(line_info[1])
+                 item = (clip_path, target)
+                 clips.append(item)
+         return clips
+
+     def _sample_train_indices(self, num_frames):
+         if self.randomize_interframes is False:
+             new_step = self._new_step
+         else:
+             new_step = np.random.randint(1, self._new_step + 1)
+
+         skip_length = self.new_length * new_step
+
+         average_duration = (num_frames - skip_length + 1) // self.num_segments
+         if average_duration > 0:
+             offsets = np.multiply(list(range(self.num_segments)),
+                                   average_duration)
+             offsets = offsets + np.random.randint(average_duration,
+                                                   size=self.num_segments)
+         elif num_frames > max(self.num_segments, skip_length):
+             offsets = np.sort(np.random.randint(
+                 num_frames - skip_length + 1,
+                 size=self.num_segments))
+         else:
+             offsets = np.zeros((self.num_segments,))
+
+         if self.temporal_jitter:
+             skip_offsets = np.random.randint(
+                 new_step, size=skip_length // new_step)
+         else:
+             skip_offsets = np.zeros(
+                 skip_length // new_step, dtype=int)
+         return offsets + 1, skip_offsets, new_step, skip_length
+
+     def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets, new_step,
+                                        skip_length):
+         sampled_list = []
+         frame_id_list = []
+         for seg_ind in indices:
+             offset = int(seg_ind)
+             for i, _ in enumerate(range(0, skip_length, new_step)):
+                 if offset + skip_offsets[i] <= duration:
+                     frame_id = offset + skip_offsets[i] - 1
+                 else:
+                     frame_id = offset - 1
+                 frame_id_list.append(frame_id)
+                 if offset + new_step < duration:
+                     offset += new_step
+         try:
+             video_data = video_reader.get_batch(frame_id_list).asnumpy()
+             sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in
+                             enumerate(frame_id_list)]
+         except Exception:
+             raise RuntimeError(
+                 'Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory,
+                                                                                            duration))
+         return sampled_list
+
+
+ class ContextAndTargetVideoDataset(VideoMAE):
+     """
+     A video dataset whose provided videos consist of (1) a "context" sequence of length Tc
+     and (2) a "target" sequence of length Tt.
+
+     These two sequences have the same frame rate (specifiable in real units) but are
+     separated by a specified gap (which may vary across examples).
+
+     The main use case is training models to predict ahead by some variable amount,
+     given the context.
+     """
+
+     standard_fps = [12, 24, 30, 48, 60, 100]
+
+     def __init__(self,
+                  root,
+                  setting,
+                  train=True,
+                  test_mode=False,
+                  transform=None,
+                  step_units='ms',
+                  new_step=150,
+                  start_frame=0,
+                  context_length=2,
+                  target_length=1,
+                  channels_first=True,
+                  generate_masks=True,
+                  mask_generator=None,
+                  context_target_gap=[400, 600],
+                  normalize_timestamps=True,
+                  default_fps=30,
+                  min_fps=0.1,
+                  seed=0,
+                  *args,
+                  **kwargs):
+         super(ContextAndTargetVideoDataset, self).__init__(
+             root=root,
+             setting=setting,
+             train=train,
+             test_mode=test_mode,
+             transform=transform,
+             new_length=context_length,
+             use_decord=True,
+             lazy_init=False,
+             video_loader=True,
+             *args, **kwargs)
+
+         self.context_length = self.new_length
+         self.target_length = target_length
+
+         ## convert from fps and step size to frames
+         self._fps = None
+         self._min_fps = min_fps
+         self._default_fps = default_fps
+         self._step_units = step_units
+         self.new_step = new_step
+
+         ## sampling for train and test
+         self._start_frame = start_frame
+         self.gap = context_target_gap
+         self.seed = seed
+         self.rng = np.random.RandomState(seed=seed)
+
+         ## output formatting
+         self._channels_first = channels_first
+         self._normalize_timestamps = normalize_timestamps
+         self._generate_masks = generate_masks
+         self.mask_generator = mask_generator
+
+     def _get_frames_per_t(self, t):
+         if self._step_units == 'frames' or (self._step_units is None):
+             return int(t)
+
+         assert self._fps is not None
+         t_per_frame = 1 / self._fps
+         if self._step_units in ['ms', 'milliseconds']:
+             t_per_frame *= 1000.0
+
+         return max(int(np.round(t / t_per_frame)), 1)
+
+     @property
+     def new_step(self):
+         if self._fps is None:
+             return None
+         else:
+             return self._get_frames_per_t(self._new_step)
+
+     @new_step.setter
+     def new_step(self, v):
+         self._new_step = v
+
+     @property
+     def gap(self):
+         if self._fps is None:
+             return [1, 2]
+         else:
+             gap = [self._get_frames_per_t(self._gap[0]),
+                    self._get_frames_per_t(self._gap[1])]
+             gap[1] = max(gap[1], gap[0] + 1)
+             return gap
+
+     @gap.setter
+     def gap(self, v):
+         if v is None:
+             v = self._new_step
+         if not isinstance(v, (list, tuple)):
+             v = [v, v]
+         self._gap = v
+
+     def _get_video_name(self, directory):
+         if ''.join(['.', self.video_ext]) in directory.split('/')[-1]:
+             # data in the "setting" file has an extension, e.g. demo.mp4
+             video_name = directory
+         else:
+             # data doesn't have an extension
+             video_name = '{}.{}'.format(directory, self.video_ext)
+         return video_name
+
+     def _set_fps(self, reader):
+         """snap fps to a standard value"""
+         if self._step_units == 'frames' or self._step_units is None:
+             self._fps = None
+         else:
+             self._fps = None
+             fps = reader.get_avg_fps()
+             for st in self.standard_fps:
+                 if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st):
+                     self._fps = st
+             if self._fps is None:
+                 self._fps = int(np.round(fps))
+
+             if self._fps < self._min_fps:
+                 self._fps = self._default_fps
+
+     def _get_step_and_gap(self):
+         step = self.new_step
+         if self.randomize_interframes and self.train:
+             step = self.rng.randint(1, step + 1)
+
+         if self.train:
+             gap = self.rng.randint(*self.gap)
+         else:
+             gap = sum(self.gap) // 2
+         return (step, gap)
+
+     def _sample_frames(self):
+         step, gap = self._get_step_and_gap()
+
+         ## compute total length of sample
+         ## e.g. if context_length = 2, step = 1, gap = 10, target_length = 2:
+         ## total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13
+         ## so len(video) must be >= 13
+         self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
+         if self._total_length > (self._num_frames - self._start_frame):
+             if self.train:
+                 return None
+             else:
+                 raise ValueError(
+                     "movie of length %d starting at fr=%d is too long for video of %d frames" % \
+                     (self._total_length, self._start_frame, self._num_frames))
+
+         ## sample the frames randomly (if training) or from the start frame (if test)
+         if self.train:
+             self.start_frame_now = self.rng.randint(
+                 min(self._start_frame, self._num_frames - self._total_length),
+                 self._num_frames - self._total_length + 1)
+         else:
+             self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)
+
+         frames = [self.start_frame_now + i * step for i in range(self.context_length)]
+         frames += [frames[-1] + gap + i * step for i in range(self.target_length)]
+
+         return frames
+
+     def _decode_frame_images(self, reader, frames):
+         try:
+             video_data = reader.get_batch(frames).asnumpy()
+             video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
+                           for t, _ in enumerate(frames)]
+         except Exception:
+             raise RuntimeError(
+                 "Error occurred in reading frames {} from video {} of duration {}".format(
+                     frames, self.index, self._num_frames))
+         return video_data
+
+     def __getitem__(self, index):
+
+         self.index = index
+         self.directory, target = self.clips[index]
+
+         self.video_name = self._get_video_name(self.directory)
+
+         ## build decord loader
+         try:
+             decord_vr = decord.VideoReader(self.video_name, num_threads=1)
+             self._set_fps(decord_vr)
+         except Exception:
+             return (self.__getitem__(index + 1))
+
+         ## sample the video
+         self._num_frames = len(decord_vr)
+         self.frames = self._sample_frames()
+         if self.frames is None:
+             print("no movie of length %d for video idx=%d" % (self._total_length, self.index))
+             return self.__getitem__(index + 1)
+
+         ## decode to PIL.Image
+         image_list = self._decode_frame_images(decord_vr, self.frames)
+
+         ## postproc to torch.Tensor and mask generation
+         if self.transform is None:
+             image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
+         else:
+             image_tensor = self.transform((image_list, None))
+
+         image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])
+
+         ## VMAE expects [B,C,T,H,W] rather than [B,T,C,H,W]
+         if self._channels_first:
+             image_tensor = image_tensor.transpose(0, 1)
+
+         if self._generate_masks and self.mask_generator is not None:
+             mask = self.mask_generator()
+             return image_tensor, mask.bool()
+         else:
+             return image_tensor
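
The frame-sampling arithmetic in `_sample_frames` can be checked against the worked example from the code comment; this is a standalone sketch, not part of the module:

```python
# Worked example of the sampling arithmetic (values mirror the code comment).
context_length, target_length, step, gap = 2, 2, 1, 10
total_length = context_length * step + gap + (target_length - 1) * step
assert total_length == 13  # the video must have at least 13 frames

start = 0
frames = [start + i * step for i in range(context_length)]
frames += [frames[-1] + gap + i * step for i in range(target_length)]
print(frames)  # [0, 1, 11, 12] -- two context frames, a gap, two target frames
```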
cwm/data/dataset_utils.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ from torchvision import transforms
+ from cwm.data.transforms import *
+ from cwm.data.dataset import ContextAndTargetVideoDataset
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from cwm.data.masking_generator import RotatedTableMaskingGenerator
+
+
+ class DataAugmentationForVideoMAE(object):
+     def __init__(self, augmentation_type, input_size, augmentation_scales):
+
+         transform_list = []
+
+         self.scale = GroupScale(input_size)
+         transform_list.append(self.scale)
+
+         if augmentation_type == 'multiscale':
+             self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
+         elif augmentation_type == 'center':
+             self.train_augmentation = GroupCenterCrop(input_size)
+
+         transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)])
+
+         # Normalize input images
+         normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+         transform_list.append(normalize)
+
+         self.transform = transforms.Compose(transform_list)
+
+     def __call__(self, images):
+         process_data, _ = self.transform(images)
+         return process_data
+
+     def __repr__(self):
+         repr_str = "(DataAugmentationForVideoMAE,\n"
+         repr_str += "  transform = %s,\n" % str(self.transform)
+         repr_str += ")"
+         return repr_str
+
+
+ def build_pretraining_dataset(args):
+
+     dataset_list = []
+     data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales)
+
+     mask_generator = RotatedTableMaskingGenerator(
+         input_size=args.mask_input_size,
+         mask_ratio=args.mask_ratio,
+         tube_length=args.tubelet_size,
+         batch_size=args.batch_size,
+         mask_type=args.mask_type
+     )
+
+     for data_path in [args.data_path] if args.data_path_list is None else args.data_path_list:
+         dataset = ContextAndTargetVideoDataset(
+             root=None,
+             setting=data_path,
+             video_ext='mp4',
+             is_color=True,
+             modality='rgb',
+             context_length=args.context_frames,
+             target_length=args.target_frames,
+             step_units=args.temporal_units,
+             new_step=args.sampling_rate,
+             context_target_gap=args.context_target_gap,
+             transform=data_transform,
+             randomize_interframes=False,
+             channels_first=True,
+             temporal_jitter=False,
+             train=True,
+             mask_generator=mask_generator,
+         )
+         dataset_list.append(dataset)
+     dataset = torch.utils.data.ConcatDataset(dataset_list)
+     return dataset
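
A minimal sketch of driving `build_pretraining_dataset` without the full training CLI. The `Namespace` fields mirror exactly the attributes the function reads above; the values are illustrative assumptions, and the list file must point at real videos for indexing to succeed:

```python
from argparse import Namespace
from cwm.data.dataset_utils import build_pretraining_dataset

args = Namespace(
    augmentation_type='multiscale',
    input_size=224,
    augmentation_scales=(1.0, 0.875, 0.75),
    mask_input_size=(3, 28, 28),   # (frames, H/patch, W/patch), assumed geometry
    mask_ratio=0.9,
    tubelet_size=1,
    batch_size=8,
    mask_type='rotated_table',
    data_path='cwm/data/video_file_lists/kinetics_400_train_list_sing.txt',
    data_path_list=None,
    context_frames=2,
    target_frames=1,
    temporal_units='ms',
    sampling_rate=150,
    context_target_gap=[400, 600],
)
dataset = build_pretraining_dataset(args)
video, mask = dataset[0]  # [C, T, H, W] clip and a boolean patch mask
```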
cwm/data/masking_generator.py ADDED
@@ -0,0 +1,86 @@
+ import numpy as np
+ import torch
+
+
+ def get_tubes(masks_per_frame, tube_length):
+     rp = torch.randperm(len(masks_per_frame))
+     masks_per_frame = masks_per_frame[rp]
+
+     tubes = [masks_per_frame]
+     for x in range(tube_length - 1):
+         masks_per_frame = masks_per_frame.clone()
+         rp = torch.randperm(len(masks_per_frame))
+         masks_per_frame = masks_per_frame[rp]
+         tubes.append(masks_per_frame)
+
+     tubes = torch.vstack(tubes)
+
+     return tubes
+
+
+ class RotatedTableMaskingGenerator:
+     def __init__(self,
+                  input_size,
+                  mask_ratio,
+                  tube_length,
+                  batch_size,
+                  mask_type='rotated_table',
+                  seed=None,
+                  randomize_num_visible=False):
+
+         self.batch_size = batch_size
+
+         self.mask_ratio = mask_ratio
+         self.tube_length = tube_length
+
+         self.frames, self.height, self.width = input_size
+         self.num_patches_per_frame = self.height * self.width
+         self.total_patches = self.frames * self.num_patches_per_frame
+
+         self.seed = seed
+         self.randomize_num_visible = randomize_num_visible
+
+         self.mask_type = mask_type
+
+     def __repr__(self):
+         repr_str = "Inverted Table Mask: total patches {}, tube length {}, randomize num visible? {}, seed {}".format(
+             self.total_patches, self.tube_length, self.randomize_num_visible, self.seed
+         )
+         return repr_str
+
+     def __call__(self, m=None):
+
+         if self.mask_type == 'rotated_table_magvit':
+             self.mask_ratio = np.random.uniform(low=0.0, high=1)
+             self.mask_ratio = np.cos(self.mask_ratio * np.pi / 2)
+         elif self.mask_type == 'rotated_table_maskvit':
+             self.mask_ratio = np.random.uniform(low=0.5, high=1)
+
+         all_masks = []
+         for b in range(self.batch_size):
+
+             self.num_masks_per_frame = max(0, int(self.mask_ratio * self.num_patches_per_frame))
+             self.total_masks = self.tube_length * self.num_masks_per_frame
+
+             num_masks = self.num_masks_per_frame
+
+             if self.randomize_num_visible:
+                 raise NotImplementedError("randomize_num_visible is not implemented")
+
+             if self.mask_ratio == 0:
+                 mask_per_frame = torch.hstack([
+                     torch.zeros(self.num_patches_per_frame - num_masks),
+                 ])
+             else:
+                 mask_per_frame = torch.hstack([
+                     torch.zeros(self.num_patches_per_frame - num_masks),
+                     torch.ones(num_masks),
+                 ])
+
+             tubes = get_tubes(mask_per_frame, self.tube_length)
+             top = torch.zeros(self.height * self.width).to(tubes.dtype)
+
+             top = torch.tile(top, (self.frames - self.tube_length, 1))
+             mask = torch.cat([top, tubes])
+             mask = mask.flatten()
+             all_masks.append(mask)
+         return torch.stack(all_masks)
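
A sketch of what the generator returns, with shapes grounded in the code above (the batch size and 28x28 patch grid are illustrative):

```python
import torch
from cwm.data.masking_generator import RotatedTableMaskingGenerator

gen = RotatedTableMaskingGenerator(
    input_size=(3, 28, 28),  # (frames, height, width) in patch units
    mask_ratio=0.9,
    tube_length=1,
    batch_size=4,
)
masks = gen()
print(masks.shape)  # torch.Size([4, 2352]) = batch x frames*28*28
# The first (frames - tube_length) frames are fully visible (all zeros);
# the last tube_length frame(s) mask ~90% of patches (ones) in random order.
print(masks[0, :28 * 28].sum(), masks[0, -28 * 28:].sum())  # tensor(0.) tensor(705.)
```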
cwm/data/transforms.py ADDED
@@ -0,0 +1,206 @@
+ import torch
+ import torchvision.transforms.functional as F
+ import warnings
+ import random
+ import numpy as np
+ import torchvision
+ from PIL import Image, ImageOps
+ import numbers
+
+
+ class GroupRandomCrop(object):
+     def __init__(self, size):
+         if isinstance(size, numbers.Number):
+             self.size = (int(size), int(size))
+         else:
+             self.size = size
+
+     def __call__(self, img_tuple):
+         img_group, label = img_tuple
+
+         w, h = img_group[0].size
+         th, tw = self.size
+
+         out_images = list()
+
+         x1 = random.randint(0, w - tw)
+         y1 = random.randint(0, h - th)
+
+         for img in img_group:
+             assert (img.size[0] == w and img.size[1] == h)
+             if w == tw and h == th:
+                 out_images.append(img)
+             else:
+                 out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+         return (out_images, label)
+
+
+ class GroupCenterCrop(object):
+     def __init__(self, size):
+         self.worker = torchvision.transforms.CenterCrop(size)
+
+     def __call__(self, img_tuple):
+         img_group, label = img_tuple
+         return ([self.worker(img) for img in img_group], label)
+
+
+ class GroupNormalize(object):
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, tensor_tuple):
+         tensor, label = tensor_tuple
+         rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
+         rep_std = self.std * (tensor.size()[0] // len(self.std))
+
+         # TODO: make efficient
+         for t, m, s in zip(tensor, rep_mean, rep_std):
+             t.sub_(m).div_(s)
+
+         return (tensor, label)
+
+
+ class GroupGrayScale(object):
+     def __init__(self, size):
+         self.worker = torchvision.transforms.Grayscale(size)
+
+     def __call__(self, img_tuple):
+         img_group, label = img_tuple
+         return ([self.worker(img) for img in img_group], label)
+
+
+ class GroupScale(object):
+     """Rescales the input PIL.Image to the given 'size'.
+     'size' will be the size of the smaller edge.
+     For example, if height > width, then the image will be
+     rescaled to (size * height / width, size).
+     size: size of the smaller edge
+     interpolation: Default: PIL.Image.BILINEAR
+     """
+
+     def __init__(self, size, interpolation=Image.BILINEAR):
+         self.worker = torchvision.transforms.Resize(size, interpolation)
+
+     def __call__(self, img_tuple):
+         img_group, label = img_tuple
+         return ([self.worker(img) for img in img_group], label)
+
+
+ class GroupMultiScaleCrop(object):
+
+     def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
+         self.scales = scales if scales is not None else [1, .875, .75, .66]
+         self.max_distort = max_distort
+         self.fix_crop = fix_crop
+         self.more_fix_crop = more_fix_crop
+         self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
+         self.interpolation = Image.BILINEAR
+
+     def __call__(self, img_tuple):
+         img_group, label = img_tuple
+
+         im_size = img_group[0].size
+
+         crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
+         crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
+         ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
+         return (ret_img_group, label)
+
+     def _sample_crop_size(self, im_size):
+         image_w, image_h = im_size[0], im_size[1]
+
+         # find a crop size
+         base_size = min(image_w, image_h)
+         crop_sizes = [int(base_size * x) for x in self.scales]
+         crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
+         crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
+
+         pairs = []
+         for i, h in enumerate(crop_h):
+             for j, w in enumerate(crop_w):
+                 if abs(i - j) <= self.max_distort:
+                     pairs.append((w, h))
+
+         crop_pair = random.choice(pairs)
+         if not self.fix_crop:
+             w_offset = random.randint(0, image_w - crop_pair[0])
+             h_offset = random.randint(0, image_h - crop_pair[1])
+         else:
+             w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
+
+         return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+     def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
+         offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
+         return random.choice(offsets)
+
+     @staticmethod
+     def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
+         w_step = (image_w - crop_w) // 4
+         h_step = (image_h - crop_h) // 4
+
+         ret = list()
+         ret.append((0, 0))  # upper left
+         ret.append((4 * w_step, 0))  # upper right
+         ret.append((0, 4 * h_step))  # lower left
+         ret.append((4 * w_step, 4 * h_step))  # lower right
+         ret.append((2 * w_step, 2 * h_step))  # center
+
+         if more_fix_crop:
+             ret.append((0, 2 * h_step))  # center left
+             ret.append((4 * w_step, 2 * h_step))  # center right
+             ret.append((2 * w_step, 4 * h_step))  # lower center
+             ret.append((2 * w_step, 0 * h_step))  # upper center
+
+             ret.append((1 * w_step, 1 * h_step))  # upper left quarter
+             ret.append((3 * w_step, 1 * h_step))  # upper right quarter
+             ret.append((1 * w_step, 3 * h_step))  # lower left quarter
+             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
+         return ret
+
+
+ class Stack(object):
+
+     def __init__(self, roll=False):
+         self.roll = roll
+
+     def __call__(self, img_tuple):
+         img_group, label = img_tuple
+
+         if img_group[0].mode == 'L':
+             return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
+         elif img_group[0].mode == 'RGB':
+             if self.roll:
+                 return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
+             else:
+                 return (np.concatenate(img_group, axis=2), label)
+
+
+ class ToTorchFormatTensor(object):
+     """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
+     to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]."""
+
+     def __init__(self, div=True):
+         self.div = div
+
+     def __call__(self, pic_tuple):
+         pic, label = pic_tuple
+
+         if isinstance(pic, np.ndarray):
+             # handle numpy array
+             img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
+         else:
+             # handle PIL Image
+             img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+             img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+             # put it from HWC to CHW format
+             # yikes, this transpose takes 80% of the loading time/CPU
+             img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         return (img.float().div(255.) if self.div else img.float(), label)
+
+
+ class IdentityTransform(object):
+
+     def __call__(self, data):
+         return data
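
The group transforms all operate on `(list_of_PIL_images, label)` tuples, so they chain under `torchvision.transforms.Compose` exactly as `DataAugmentationForVideoMAE` does. A small sketch with dummy frames:

```python
from PIL import Image
from torchvision import transforms
from cwm.data.transforms import GroupScale, GroupMultiScaleCrop, Stack, ToTorchFormatTensor

pipeline = transforms.Compose([
    GroupScale(224),
    GroupMultiScaleCrop(224, [1, .875, .75, .66]),
    Stack(roll=False),
    ToTorchFormatTensor(div=True),
])

frames = [Image.new('RGB', (320, 240)) for _ in range(3)]  # dummy 3-frame clip
tensor, _ = pipeline((frames, None))
print(tensor.shape)  # torch.Size([9, 224, 224]) -- T*C stacked along channels
```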
cwm/data/video_file_lists/kinetics_400_train_list.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65e14c0735b4c90c57022add2407a8524d246cc09b3d5a7e83b963ac3b231032
+ size 19539143
cwm/data/video_file_lists/kinetics_400_train_list_sing.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b18ccdce4616fb32a0aababc2342a640a11f7d73439f49358a16cc99e7eaed3
+ size 1943
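
These two file lists are stored as LFS pointers; the underlying text format is what `VideoMAE._make_dataset` parses, one `path label` pair per line. An illustrative writer (filenames are hypothetical):

```python
# Sketch: build a video file list in the expected "path label" format.
with open('my_train_list.txt', 'w') as f:
    f.write('/data/videos/clip_0001.mp4 0\n')
    f.write('/data/videos/clip_0002.mp4 1\n')
```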
cwm/engine_for_pretraining.py ADDED
@@ -0,0 +1,92 @@
+ import json
+ import os
+ import time
+ from typing import Iterable
+
+ import torch
+ import torch.nn as nn
+ from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+ import utils
+
+ from datetime import datetime
+
+
+ def train_one_epoch(model: torch.nn.Module,
+                     data_loader: Iterable,
+                     optimizer: torch.optim.Optimizer,
+                     device: torch.device,
+                     epoch: int,
+                     loss_scaler,
+                     start_steps=None,
+                     lr_schedule_values=None,
+                     wd_schedule_values=None,
+                     global_rank=None,
+                     args=None,
+                     loss_func=nn.MSELoss(),
+                     ):
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+
+     if args.eval:
+         model.eval()
+     else:
+         model.train()
+         metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+
+     header = f'Epoch [{epoch}]'
+     patch_size = model.module.encoder.patch_size[-2:]
+     tubelet_size = model.module.encoder.patch_size[0]
+     mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
+     std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
+
+     for step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+
+         # assign learning rate & weight decay for each iteration
+         it = start_steps + step  # global training iteration
+         if (lr_schedule_values is not None or wd_schedule_values is not None) and (step % args.accum_iter == 0):
+             for i, param_group in enumerate(optimizer.param_groups):
+                 if lr_schedule_values is not None:
+                     param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
+                 if wd_schedule_values is not None and param_group["weight_decay"] > 0:
+                     param_group["weight_decay"] = wd_schedule_values[it]
+
+         # prepare input
+         videos, bool_masked_pos = batch
+         videos = videos.to(device, non_blocking=True)
+         bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1)
+
+         # prepare target
+         with torch.no_grad():
+             unnorm_videos = videos * std + mean  # in [0, 1]
+             videos_patch = utils.patchify(unnorm_videos, tubelet_size, patch_size)
+             B, _, C = videos_patch.shape
+             labels = videos_patch[bool_masked_pos].reshape(B, -1, C)
+
+         # feedforward
+         with torch.cuda.amp.autocast(enabled=True):
+             outputs = model(videos, bool_masked_pos)
+             loss = loss_func(input=outputs, target=labels)
+
+         loss_value = loss.item()
+
+         # backward
+         is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+         loss /= args.accum_iter
+         loss_scaler(loss, optimizer, clip_grad=None,
+                     parameters=model.parameters(), create_graph=is_second_order,
+                     update_grad=(step + 1) % args.accum_iter == 0)
+
+         torch.cuda.synchronize()
+         metric_logger.update(loss=loss_value)
+
+         if (step + 1) % args.accum_iter == 0:
+             optimizer.zero_grad()
+
+         lr = optimizer.param_groups[0]["lr"]
+         metric_logger.update(lr=lr)
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger)
+     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
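
`utils.patchify` is not part of this diff; the following reimplementation is an assumption consistent with how the loop above gathers `labels` from `bool_masked_pos`, shown only to make the target construction concrete:

```python
import torch

def patchify(videos, tubelet_size, patch_size):
    # videos: [B, C, T, H, W] -> [B, num_patches, tubelet*ph*pw*C] (assumed layout)
    B, C, T, H, W = videos.shape
    ph, pw = patch_size
    x = videos.reshape(B, C, T // tubelet_size, tubelet_size, H // ph, ph, W // pw, pw)
    x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)  # B, t, h, w, tubelet, ph, pw, C
    return x.reshape(B, -1, tubelet_size * ph * pw * C)

videos = torch.randn(2, 3, 3, 224, 224)
print(patchify(videos, 1, (8, 8)).shape)  # torch.Size([2, 2352, 192])
```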
cwm/eval/Action_recognition/__init__.py ADDED
File without changes
cwm/eval/Flow/__init__.py ADDED
File without changes
cwm/eval/Flow/create_spring_submission_parallel.sh ADDED
@@ -0,0 +1,36 @@
+ #!/bin/bash
+
+ # Define the path to the dataset and the Python script
+ DATASET_PATH="/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/"
+ SCRIPT_PATH="./create_spring_submission_unified.py"
+ SAVE_DATA_PATH=${1}
+ MODEL=${2}
+ # Counter for GPUs
+ GPU_COUNTER=0
+
+ # Number of GPUs available
+ NUM_GPUS=8
+
+ # Kill any existing session
+ tmux kill-session -t extraction
+
+ tmux new-session -d -s "extraction"
+
+ # Iterate through each folder in the dataset
+ for FOLDER in $(find $DATASET_PATH -mindepth 1 -maxdepth 1 -type d | sort); do
+     # Extract the folder name for the tmux window name
+     FOLDER_NAME=$(basename $FOLDER)
+
+     # Create a new tmux window for each folder
+     tmux new-window -t extraction -n "$FOLDER" "ulimit -n 65535; CUDA_VISIBLE_DEVICES=$GPU_COUNTER python $SCRIPT_PATH --folder $FOLDER --gpu $GPU_COUNTER --save_data_path $SAVE_DATA_PATH --model $MODEL; echo 'Press Enter to continue...'; read -p ''"
+     # Increment the GPU counter and reset if it exceeds the number of GPUs
+     GPU_COUNTER=$((GPU_COUNTER + 1))
+     if [ $GPU_COUNTER -ge $NUM_GPUS ]; then
+         GPU_COUNTER=0
+     fi
+
+     sleep 1
+ done
+
+ tmux attach-session -t extraction
cwm/eval/Flow/create_spring_submission_unified.py ADDED
@@ -0,0 +1,111 @@
+ import argparse
+ import importlib
+ import time
+
+ # Parse command-line arguments
+ parser = argparse.ArgumentParser(description='Extract optical flow for the SPRING test set')
+ parser.add_argument('--folder', type=str, required=True, help='Folder to process')
+ parser.add_argument('--model', type=str, required=True, help='Model used to extract flow')
+ parser.add_argument('--save_data_path', type=str, required=True, help='Where to save the data')
+ parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
+ args = parser.parse_args()
+
+ # Set the visible GPU before importing torch so that device 0 maps to it
+ import os
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+ import torch
+ torch.cuda.set_device(0)
+
+ import h5py
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torchvision.transforms as transforms
+
+
+ def writeFlo5File(flow, filename):
+     with h5py.File(filename, "w") as f:
+         f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
+
+
+ def l2norm(x):
+     return np.sqrt((x ** 2).sum(-1))
+
+
+ if __name__ == '__main__':
+     module_name, class_name = args.model.rsplit(".", 1)
+     module = importlib.import_module(module_name)
+
+     model = getattr(module, class_name)
+     model = model().cuda().eval()
+
+     folder = args.folder.split('/')[-1]
+
+     # Path for the dataset
+     dataset_path = '/ccn2/dataset/Flows_Kinetics/SPRING/spring/test/'
+
+     save_data_path = args.save_data_path
+
+     if not os.path.exists(save_data_path):
+         os.makedirs(save_data_path)
+
+     resize_crop = transforms.Compose([
+         transforms.ToTensor(),
+     ])
+
+     # Iterate over the forward/backward directions and stereo views
+     for dir in ['FW', 'BW']:
+         for stereo in ['left', 'right']:
+             files = sorted(os.listdir(os.path.join(dataset_path, folder, f'frame_{stereo}')))
+             output_folder = os.path.join(save_data_path, folder, f'flow_{dir}_{stereo}')
+
+             if not os.path.exists(output_folder):
+                 os.makedirs(output_folder)
+
+             for ct_f in range(len(files) - 1):
+                 # Read images (swap the pair for backward flow)
+                 if dir == 'FW':
+                     f1 = files[ct_f]
+                     f2 = files[ct_f + 1]
+                 else:
+                     f2 = files[ct_f]
+                     f1 = files[ct_f + 1]
+                 t = time.time()
+                 image1_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f1)
+                 image2_path = os.path.join(dataset_path, folder, f'frame_{stereo}', f2)
+
+                 idx = image1_path.split('/')[-1].split('.')[0].split('_')[-1]
+                 flow_save_path = os.path.join(output_folder, f'flow_{dir}_{stereo}_' + idx + '.flo5')
+
+                 image1_ = plt.imread(image1_path)
+                 image2_ = plt.imread(image2_path)
+
+                 image1 = resize_crop(image1_)
+                 image2 = resize_crop(image2_)
+
+                 forward_flow = model.forward(image1, image2)
+
+                 writeFlo5File(forward_flow, flow_save_path)
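
Reading a submission file back is the mirror of `writeFlo5File`; the filename below is illustrative, and the `flow` dataset name matches the writer above:

```python
import h5py

with h5py.File('flow_FW_left_0001.flo5', 'r') as f:  # hypothetical filename
    flow = f['flow'][:]  # H x W x 2 flow array stored by writeFlo5File
```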
cwm/eval/Flow/flow_extraction_classes.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+
+ import cwm.model.model_pretrain as vmae_tranformers
+ from cwm.data.masking_generator import RotatedTableMaskingGenerator
+ from . import flow_utils
+ from . import losses as bblosses
+
+
+ def l2_norm(x):
+     return x.square().sum(-3, True).sqrt()
+
+
+ def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
+     fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
+     flow_diff_fwd = flow_fwd + fwd_bck_cycle
+
+     bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
+     flow_diff_bck = flow_bck + bck_fwd_cycle
+
+     norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
+     norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2
+
+     occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
+     occ_thresh_bck = occ_thresh * norm_bck + 0.5
+
+     occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
+     occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()
+
+     return occ_mask_fwd, occ_mask_bck
+
+
+ class ExtractFlow(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, img1, img2):
+         '''
+         img1: first frame
+         img2: second frame
+         returns: flow map (h, w, 2)
+         '''
+
+
+ class CWM(ExtractFlow):
+     def __init__(self, model_name, patch_size, weights_path):
+         super().__init__()
+
+         self.patch_size = patch_size
+         model = getattr(vmae_tranformers, model_name)
+         vmae_8x8_full = model().cuda().eval().requires_grad_(False)
+
+         VMAE_LOAD_PATH = weights_path
+         did_load = vmae_8x8_full.load_state_dict(torch.load(VMAE_LOAD_PATH)['model'], strict=False)
+         print(did_load, VMAE_LOAD_PATH)
+
+         self.predictor = vmae_8x8_full
+
+         self.mask_generator = RotatedTableMaskingGenerator(
+             input_size=(vmae_8x8_full.num_frames, 28, 28),
+             mask_ratio=0.0,
+             tube_length=1,
+             batch_size=1,
+             mask_type='rotated_table'
+         )
+
+     def forward(self, img1, img2):
+         '''
+         img1: [3, 1024, 1024]
+         img2: [3, 1024, 1024]
+         Both images are ImageNet-normalized.
+         '''
+
+         with torch.no_grad():
+             FF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self.predictor,
+                                                                                          self.mask_generator,
+                                                                                          img1[None],
+                                                                                          img2[None],
+                                                                                          num_scales=2,
+                                                                                          min_scale=224,
+                                                                                          N_mask_samples=1)
+
+             BF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self.predictor,
+                                                                                          self.mask_generator,
+                                                                                          img2[None],
+                                                                                          img1[None],
+                                                                                          num_scales=2,
+                                                                                          min_scale=224,
+                                                                                          N_mask_samples=1)
+
+         # Zero out flow at pixels judged occluded by the forward-backward check
+         occ_mask = get_occ_masks(FF, BF)[0]
+
+         FF = FF * occ_mask
+
+         FF = FF[0]
+
+         return FF
+
+
+ class CWM_8x8(CWM):
+     def __init__(self):
+         super().__init__('vitb_8x8patch_3frames', 8,
+                          '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth')
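
A usage sketch for the class above, assuming the hard-coded checkpoint path is available and a GPU is present; the random tensors stand in for ImageNet-normalized frames:

```python
import torch
from cwm.eval.Flow.flow_extraction_classes import CWM_8x8

model = CWM_8x8().cuda().eval()
img1 = torch.randn(3, 512, 512).cuda()  # stand-in for a normalized frame
img2 = torch.randn(3, 512, 512).cuda()
flow = model(img1, img2)  # [2, H, W] forward flow, occlusion-masked
```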
cwm/eval/Flow/flow_utils.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from . import losses as bblosses
7
+ import kornia
8
+
9
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
11
+
12
+ def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
13
+ # Unroll the mask tensor and find the indices of the masked and unmasked values in the second frame
14
+ mask_unrolled = mask_tensor.view(-1)
15
+
16
+ second_frame_unmask_indices = torch.where(mask_unrolled[frame_size ** 2:] == False)[0]
17
+
18
+ # Divide the embedding tensor into two parts: corresponding to the first and the second frame
19
+ first_frame_embeddings = embedding_tensor[0, :frame_size ** 2, :]
20
+ second_frame_embeddings = embedding_tensor[0, frame_size ** 2:, :]
21
+
22
+ # print(first_frame_embeddings.shape, second_frame_embeddings.shape, embedding_tensor.shape)
23
+
24
+ # Compute the cosine similarity between the unmasked embeddings from the second frame and the embeddings from the first frame
25
+ dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
26
+ norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
27
+ cos_sim_matrix = dot_product / norms
28
+
29
+ # Find the indices of pixels in the first frame that are most similar to each unmasked pixel in the second frame
30
+ first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)
31
+
32
+ # Convert the 1D pixel indices into 2D coordinates
33
+ second_frame_y = second_frame_unmask_indices // frame_size
34
+ second_frame_x = second_frame_unmask_indices % frame_size
35
+ first_frame_y = first_frame_most_similar_indices // frame_size
36
+ first_frame_x = first_frame_most_similar_indices % frame_size
37
+
38
+ # Compute the x and y displacements and convert them to float
39
+ displacements_x = (second_frame_x - first_frame_x).float()
40
+ displacements_y = (second_frame_y - first_frame_y).float()
41
+
42
+ # Initialize optical flow tensor
43
+ optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
44
+
45
+ # Assign the computed displacements to the corresponding pixels in the optical flow tensor
46
+ optical_flow[0, second_frame_y, second_frame_x] = displacements_x
47
+ optical_flow[1, second_frame_y, second_frame_x] = displacements_y
48
+
49
+ return optical_flow
50
+
51
+
52
+ def get_minimal_224_crops_new_batched(video_tensor, N):
53
+ B, T, C, H, W = video_tensor.shape
54
+
55
+ # Calculate the number of crops needed in both the height and width dimensions
56
+ num_crops_h = math.ceil(H / 224) if H > 224 else 1
57
+ num_crops_w = math.ceil(W / 224) if W > 224 else 1
58
+
59
+ # Calculate the step size for the height and width dimensions
60
+ step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
61
+ step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))
62
+
63
+ # Create a list to store the cropped tensors and their start positions
64
+ cropped_tensors = []
65
+ crop_positions = []
66
+
67
+ # Iterate over the height and width dimensions, extract the 224x224 crops, and append to the cropped_tensors list
68
+ for i in range(num_crops_h):
69
+ for j in range(num_crops_w):
70
+ start_h = i * step_size_h
71
+ start_w = j * step_size_w
72
+ end_h = min(start_h + 224, H)
73
+ end_w = min(start_w + 224, W)
74
+ crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
75
+ cropped_tensors.append(crop)
76
+ crop_positions.append((start_h, start_w))
77
+
78
+ D = len(cropped_tensors)
79
+
80
+ # If N is greater than D, generate additional random crops
81
+ if N > D and H > 224 and W > 224: # check if H and W are greater than 224
82
+ for _ in range(N - D):
83
+ start_h = random.randint(0, H - 224)
84
+ start_w = random.randint(0, W - 224)
85
+ crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
86
+ cropped_tensors.append(crop)
87
+ crop_positions.append((start_h, start_w))
88
+
89
+ # Reshape the cropped tensors to fit the required output shape (B, T, C, 224, 224)
90
+ cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
91
+
92
+ return cropped_tensors, crop_positions
93
+
94
+
95
+ def create_weighted_mask_batched(h, w):
+     # Linear blend weights that peak at the center and go to zero at the borders
+     y_mask = np.linspace(0, 1, h)
+     y_mask = np.minimum(y_mask, 1 - y_mask)
+     x_mask = np.linspace(0, 1, w)
+     x_mask = np.minimum(x_mask, 1 - x_mask)
+     weighted_mask = np.outer(y_mask, x_mask)
+     return torch.from_numpy(weighted_mask).float()
+
+
+ def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
+     B, T, C, H, W = original_shape
+
+     # Initialize an empty tensor to store the reconstructed video
+     reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
+
+     # Create a tensor to store the sum of weighted masks
+     weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
+
+     # Create a weighted mask for the crops
+     weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
+     weighted_mask = weighted_mask[None, None, None, :, :]  # Extend dimensions to match the cropped tensor.
+
+     for idx, crop in enumerate(cropped_tensors):
+         start_h, start_w = crop_positions[idx]
+
+         # Multiply the crop with the weighted mask
+         weighted_crop = crop * weighted_mask
+
+         # Add the weighted crop to the corresponding location in the reconstructed_video tensor
+         reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop
+
+         # Update the weighted_masks_sum tensor
+         weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask
+
+     # Add a small epsilon value to avoid division by zero
+     epsilon = 1e-8
+
+     # Normalize the reconstructed video by dividing each pixel by its corresponding weighted_masks_sum value plus epsilon
+     reconstructed_video /= (weighted_masks_sum + epsilon)
+
+     return reconstructed_video
+
+
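(Round-trip sanity sketch, illustrative only: tiling a video into 224x224 crops with the helpers above and blending it back should approximately reproduce the input. The 224x224 input size is chosen so a single crop covers the frame.)

import torch

video = torch.rand(1, 2, 3, 224, 224)  # B, T, C, H, W
crops, positions = get_minimal_224_crops_new_batched(video, N=1)
recon = reconstruct_video_new_2_batched(crops, positions, video.shape)
# Interior matches; the outermost rows/cols wash out because the linear
# blend weight of create_weighted_mask_batched is zero at the border.
print((recon - video)[..., 1:-1, 1:-1].abs().max())  # small (~1e-3, from the epsilon in the blend)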
+ def l2_norm(x):
+     return x.square().sum(-3, True).sqrt()
+
+
+ resize = lambda x, a: F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear',
+                                     align_corners=False)
+
+ upsample = lambda x, H, W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)
+
+
+ def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
+     fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
+     flow_diff_fwd = flow_fwd + fwd_bck_cycle
+
+     bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
+     flow_diff_bck = flow_bck + bck_fwd_cycle
+
+     norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
+     norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2
+
+     occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
+     occ_thresh_bck = occ_thresh * norm_bck + 0.5
+
+     occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
+     occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()
+
+     return occ_mask_fwd, occ_mask_bck
+
+ def forward_backward_cycle_consistency(flow_fwd, flow_bck, niters=10):
+     # Make sure to be using axes-swapped, upsampled flows!
+     bck_flow_clone = flow_bck.clone().detach()
+     fwd_flow_clone = flow_fwd.clone().detach()
+
+     for i in range(niters):
+         fwd_bck_cycle_orig, _ = bblosses.backward_warp(img2=bck_flow_clone, flow=fwd_flow_clone)
+         flow_diff_fwd_orig = fwd_flow_clone + fwd_bck_cycle_orig
+
+         fwd_flow_clone = fwd_flow_clone - flow_diff_fwd_orig / 2
+
+         bck_fwd_cycle_orig, _ = bblosses.backward_warp(img2=fwd_flow_clone, flow=bck_flow_clone)
+         flow_diff_bck_orig = bck_flow_clone + bck_fwd_cycle_orig
+
+         bck_flow_clone = bck_flow_clone - flow_diff_bck_orig / 2
+
+     return fwd_flow_clone, bck_flow_clone
+
+
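(The two helpers above are meant to be chained: refine a forward/backward flow pair by cycle consistency, then flag pixels whose round trip still disagrees as occluded. A minimal illustrative check; the zero flows are a stand-in for real extractor output, and `bblosses.backward_warp` must be importable as in this file.)

import torch

flow_fwd = torch.zeros(1, 2, 64, 64)
flow_bck = torch.zeros(1, 2, 64, 64)
flow_fwd, flow_bck = forward_backward_cycle_consistency(flow_fwd, flow_bck, niters=10)
occ_fwd, occ_bck = get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5)
print(occ_fwd.mean())  # 1.0: zero flow is perfectly cycle-consistent, so nothing is marked occluded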
+ from PIL import Image
+
+
+ def resize_flow_map(flow_map, target_size):
+     """
+     Resize a flow map to a target size while adjusting the flow vectors.
+
+     Parameters:
+     flow_map (torch.Tensor): Input flow map of shape (1, 2, H, W) where each pixel contains a (dx, dy) flow vector.
+     target_size (tuple): Target size (height, width) for the resized flow map.
+
+     Returns:
+     torch.Tensor: Resized and scaled flow map of shape (1, 2, target_size[0], target_size[1]).
+     """
+     # Get the original size
+     flow_map = flow_map[0].detach().cpu().numpy()
+     flow_map = flow_map.transpose(1, 2, 0)
+     original_size = flow_map.shape[:2]
+
+     # Separate the flow map into two channels: dx and dy
+     flow_map_x = flow_map[:, :, 0]
+     flow_map_y = flow_map[:, :, 1]
+
+     # Convert each flow channel to a PIL image for resizing
+     # (note: PIL's resize expects (width, height), so this is only exact for square targets)
+     flow_map_x_img = Image.fromarray(flow_map_x)
+     flow_map_y_img = Image.fromarray(flow_map_y)
+
+     # Resize both channels to the target size using bilinear interpolation
+     flow_map_x_resized = flow_map_x_img.resize(target_size, Image.BILINEAR)
+     flow_map_y_resized = flow_map_y_img.resize(target_size, Image.BILINEAR)
+
+     # Convert resized PIL images back to NumPy arrays
+     flow_map_x_resized = np.array(flow_map_x_resized)
+     flow_map_y_resized = np.array(flow_map_y_resized)
+
+     # Compute the scaling factor based on the size change
+     # (a single factor is used for both dx and dy, assuming an aspect-preserving resize)
+     scale_factor = target_size[0] / original_size[0]
+
+     # Scale the flow vectors (dx and dy) accordingly
+     flow_map_x_resized *= scale_factor
+     flow_map_y_resized *= scale_factor
+
+     # Recombine the two channels into a resized flow map
+     flow_map_resized = np.stack([flow_map_x_resized, flow_map_y_resized], axis=-1)
+
+     flow_map_resized = torch.from_numpy(flow_map_resized)[None].permute(0, 3, 1, 2)
+
+     return flow_map_resized
+
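(Example, illustrative only: halving a flow map's resolution also halves the displacement vectors, keeping them in units of pixels at the new resolution.)

import torch

flow = torch.ones(1, 2, 224, 224) * 8.0   # uniform 8-pixel flow
flow_small = resize_flow_map(flow, target_size=(112, 112))
print(flow_small.shape)                   # torch.Size([1, 2, 112, 112])
print(flow_small[0, 0, 0, 0].item())      # 4.0: vectors shrink with the image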
+ ##### DEPRECATED
+ def get_vmae_optical_flow_crop_batched_smoothed(generator,
+                                                 mask_generator,
+                                                 img1,
+                                                 img2,
+                                                 neg_back_flow=True,
+                                                 num_scales=1,
+                                                 min_scale=400,
+                                                 N_mask_samples=100,
+                                                 mask_ratio=0.8,
+                                                 smoothing_factor=1):
+     print('Deprecated. Please use scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed')
+
+     # The replacement no longer takes neg_back_flow or mask_ratio; they are
+     # accepted here only for backward compatibility and ignored.
+     return scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
+                                                                      mask_generator,
+                                                                      img1,
+                                                                      img2,
+                                                                      num_scales=num_scales,
+                                                                      min_scale=min_scale,
+                                                                      N_mask_samples=N_mask_samples,
+                                                                      smoothing_factor=smoothing_factor)
+
+
+ def average_crops(tensor, D):
+     C, H, W = tensor.shape
+
+     # Create zero-filled tensors for the shifted crops
+     down_shifted = torch.zeros_like(tensor)
+     up_shifted = torch.zeros_like(tensor)
+     right_shifted = torch.zeros_like(tensor)
+     left_shifted = torch.zeros_like(tensor)
+
+     # Shift the tensor by D pixels in each direction and store the results in the zero-filled tensors
+     down_shifted[:, :H - D, :] = tensor[:, D:, :]
+     up_shifted[:, D:, :] = tensor[:, :H - D, :]
+     right_shifted[:, :, :W - D] = tensor[:, :, D:]
+     left_shifted[:, :, D:] = tensor[:, :, :W - D]
+
+     # Average the tensor with its four shifted copies
+     result = (tensor + down_shifted + up_shifted + right_shifted + left_shifted) / 5.0
+
+     return result
+
+
+ def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(predictor,
+                                                               mask_generator,
+                                                               img1,
+                                                               img2,
+                                                               conditioning_img=None,
+                                                               num_scales=1,
+                                                               min_scale=400,
+                                                               N_mask_samples=100,
+                                                               smoothing_factor=1):
+     B = img1.shape[0]
+     assert len(img1.shape) == 4
+     assert num_scales >= 1
+
+     # For scaling
+     h1 = img2.shape[-2]
+     w1 = img2.shape[-1]
+
+     alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1
+
+     frame_size = 224 // predictor.patch_size[-1]
+     patch_size = predictor.patch_size[-1]
+     num_frames = predictor.num_frames
+
+     all_fwd_flows_e2d = []
+     s_hs = []
+     s_ws = []
+
+     for aidx in range(num_scales):
+         img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+                                     mode='bicubic', align_corners=True)
+         img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+                                     mode='bicubic', align_corners=True)
+
+         if conditioning_img is not None:
+             conditioning_img_scaled = F.interpolate(conditioning_img.clone(),
+                                                     [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+                                                     mode='bilinear', align_corners=False)
+
+         h2 = img2_scaled.shape[-2]
+         w2 = img2_scaled.shape[-1]
+
+         # Per-scale factors that map flows back to the original resolution
+         s_h = h1 / h2
+         s_w = w1 / w2
+         s_hs.append(s_h)
+         s_ws.append(s_w)
+
+         if conditioning_img is not None:
+             video = torch.cat([conditioning_img_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
+         else:
+             video = torch.cat([img2_scaled.unsqueeze(1)] * (num_frames - 1) + [img1_scaled.unsqueeze(1)], 1)
+
+         # Should work, even if the incoming video is already 224x224
+         crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
+         num_crops = len(crops1)
+
+         N_samples = N_mask_samples
+         crop = torch.cat(crops1, 0).cuda()
+
+         optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
+         mask_counts = torch.zeros(frame_size, frame_size).cuda()
+
+         i = 0
+         # Keep sampling masks until every target patch has been unmasked at least once
+         while i < N_samples or (mask_counts == 0).any().item():
+             # Every sample in the batch shares the same mask
+             mask = mask_generator().bool().cuda()
+             mask_2f = ~mask[0, (frame_size * frame_size) * (num_frames - 1):]
+             mask_counts += mask_2f.reshape(frame_size, frame_size)
+
+             with torch.cuda.amp.autocast(enabled=True):
+                 processed_x = crop.transpose(1, 2)
+
+                 encoder_out = predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
+                 encoder_to_decoder = predictor.encoder_to_decoder(encoder_out)
+
+                 encoder_to_decoder = encoder_to_decoder[:, (frame_size * frame_size) * (num_frames - 2):, :]
+                 flow_mask = mask[:, (frame_size * frame_size) * (num_frames - 2):]
+
+                 optical_flow_e2d = []
+                 # one per batch element for now
+                 for b in range(B * num_crops):
+                     batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
+                     optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))
+
+                 optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
+                 optical_flows_enc2dec += optical_flow_e2d
+             i += 1
+
+         # Average the accumulated flows by how often each patch was unmasked
+         optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
+
+         # split the crops back up
+         crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
+
+         T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
+               crop_flows_enc2dec]
+         optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
+             B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
+
+         all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
+
+     # Resample each scale's flow to the base resolution and rescale the vectors
+     all_fwd_flows_e2d_new = []
+     for ridx, r in enumerate(all_fwd_flows_e2d):
+         _sh = s_hs[ridx]
+         _sw = s_ws[ridx]
+         _sfy = predictor.patch_size[-1]
+         _sfx = predictor.patch_size[-1]
+
+         new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])],
+                               mode='bicubic', align_corners=True)
+
+         scaled_new_r = torch.zeros_like(new_r)
+         scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
+         scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
+
+         all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
+
+     return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
+
+     # compute_optical_flow returns negative backward flow; negate to get forward flow
+     return_flow = -return_flow
+     all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
+
+     return return_flow, all_fwd_flows_e2d_new
+
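(Sketch of a call into the extractor above, under stated assumptions: `predictor` is a loaded CWM two-frame predictor with `patch_size`/`num_frames`/`encoder` attributes as used above, and `mask_generator` is a compatible generator from the masking utilities; neither is constructed here, so this is not runnable on its own.)

import torch

img1 = torch.rand(1, 3, 448, 448).cuda()   # frame t
img2 = torch.rand(1, 3, 448, 448).cuda()   # frame t+1
flow, per_scale_flows = scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
    predictor, mask_generator, img1, img2,
    num_scales=2, min_scale=400, N_mask_samples=100)
print(flow.shape)                          # [1, 2, 448, 448], displacements in pixels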
+ def extract_jacobians_and_flows(img1, img2,
+                                 flow_generator,
+                                 mask,
+                                 target_mask=None):
+     IMAGE_SIZE = img1.shape[-2:]
+
+     y = torch.cat([img2.unsqueeze(1), img1.unsqueeze(1)], 1)
+
+     jacobians, flows, _ = flow_generator(y, mask, target_mask)
+
+     # swap x,y flow dims
+     flows = torch.cat([flows[0, 1].unsqueeze(0), flows[0, 0].unsqueeze(0)])
+
+     # upsample to 224
+     flows = flows.unsqueeze(0).repeat_interleave(IMAGE_SIZE[0] // flows.shape[-1], -1).repeat_interleave(
+         IMAGE_SIZE[0] // flows.shape[-1], -2)
+
+     return jacobians, flows
+
+
+ import matplotlib.pyplot as plt
+
+
+ class FlowToRgb(object):
+
+     def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
+         self.max_speed = max_speed
+         self.from_image_coordinates = from_image_coordinates
+         self.from_sampling_grid = from_sampling_grid
+
+     def __call__(self, flow):
+         assert flow.size(-3) == 2, flow.shape
+         if self.from_sampling_grid:
+             flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
+             flow_y = -flow_y
+         elif not self.from_image_coordinates:
+             flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
+         else:
+             flow_h, flow_w = torch.split(flow, [1, 1], dim=-3)
+             flow_x, flow_y = [flow_w, -flow_h]
+
+         angle = torch.atan2(flow_y, flow_x)  # in radians from -pi to pi
+         speed = torch.sqrt(flow_x ** 2 + flow_y ** 2) / self.max_speed
+
+         hue = torch.fmod(angle, torch.tensor(2 * np.pi))
+         sat = torch.ones_like(hue)
+         val = speed
+
+         hsv = torch.cat([hue, sat, val], -3)
+         rgb = kornia.color.hsv_to_rgb(hsv)
+         return rgb
+
+     def make_colorwheel(self):
+         """
+         Generates a color wheel for optical flow visualization as presented in:
+         Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+         """
+         RY = 15
+         YG = 6
+         GC = 4
+         CB = 11
+         BM = 13
+         MR = 6
+
+         ncols = RY + YG + GC + CB + BM + MR
+         colorwheel = np.zeros((ncols, 3))
+         col = 0
+
+         # RY
+         colorwheel[0:RY, 0] = 255
+         colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+         col += RY
+         # YG
+         colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+         colorwheel[col:col + YG, 1] = 255
+         col += YG
+         # GC
+         colorwheel[col:col + GC, 1] = 255
+         colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+         col += GC
+         # CB
+         colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB)
+         colorwheel[col:col + CB, 2] = 255
+         col += CB
+         # BM
+         colorwheel[col:col + BM, 2] = 255
+         colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+         col += BM
+         # MR
+         colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR)
+         colorwheel[col:col + MR, 0] = 255
+         return colorwheel
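(Visualization sketch, illustrative only: map a synthetic flow field to RGB with the class above; hue encodes direction and brightness encodes speed. Assumes kornia is installed, as this module requires.)

import torch

flow2rgb = FlowToRgb(max_speed=5.0)
yy, xx = torch.meshgrid(torch.linspace(-1, 1, 64), torch.linspace(-1, 1, 64), indexing='ij')
flow = torch.stack([yy, xx], 0)[None] * 3.0   # (h, w) image-coordinate flow
rgb = flow2rgb(flow)
print(rgb.shape)                              # torch.Size([1, 3, 64, 64])
# plt.imshow(rgb[0].permute(1, 2, 0))         # hue = direction, brightness = speed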
cwm/eval/Flow/flow_utils_legacy.py ADDED
@@ -0,0 +1,152 @@
+ import torch
+ import torch.nn.functional as F
+
+ from cwm.eval.Flow.flow_utils import (average_crops, compute_optical_flow,
+                                       get_minimal_224_crops_new_batched,
+                                       reconstruct_video_new_2_batched)
+
+
+ def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
+                                                               mask_generator,
+                                                               img1,
+                                                               img2,
+                                                               neg_back_flow=True,
+                                                               num_scales=1,
+                                                               min_scale=400,
+                                                               N_mask_samples=100,
+                                                               mask_ratio=0.8,
+                                                               smoothing_factor=1):
+     B = img1.shape[0]
+     assert len(img1.shape) == 4
+     assert num_scales >= 1
+
+     # For scaling
+     h1 = img2.shape[-2]
+     w1 = img2.shape[-1]
+     assert min_scale < h1 and min_scale >= 360  # Below 360p, the flows look terrible
+
+     if neg_back_flow is False:
+         print('WARNING: Not calculating negative backward flow')
+
+     alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1
+
+     frame_size = 224 // generator.patch_size[-1]
+
+     all_fwd_flows_e2d = []
+     s_hs = []
+     s_ws = []
+
+     for aidx in range(num_scales):
+         img1_scaled = F.interpolate(img1.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+                                     mode='bicubic', align_corners=True)
+         img2_scaled = F.interpolate(img2.clone(), [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
+                                     mode='bicubic', align_corners=True)
+
+         h2 = img2_scaled.shape[-2]
+         w2 = img2_scaled.shape[-1]
+
+         s_h = h1 / h2
+         s_w = w1 / w2
+         s_hs.append(s_h)
+         s_ws.append(s_w)
+
+         # Because technically the compute_optical_flow function returns neg back flow
+         if neg_back_flow is True:
+             video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
+         else:
+             video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)
+
+         # Should work, even if the incoming video is already 224x224
+         crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
+         num_crops = len(crops1)
+
+         N_samples = N_mask_samples
+         crop = torch.cat(crops1, 0).cuda()
+
+         optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
+         mask_counts = torch.zeros(frame_size, frame_size).cuda()
+
+         i = 0
+         while i < N_samples or (mask_counts == 0).any().item():
+             mask_generator.mask_ratio = mask_ratio
+
+             # Every sample in the batch shares the same mask
+             mask = mask_generator()[None]
+             mask_2f = ~mask[0, frame_size * frame_size:]
+             mask_counts += mask_2f.reshape(frame_size, frame_size)
+
+             with torch.cuda.amp.autocast(enabled=True):
+                 processed_x = generator._preprocess(crop)
+
+                 encoder_out = generator.predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
+                 encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)
+
+                 optical_flow_e2d = []
+                 # one per batch element for now
+                 for b in range(B * num_crops):
+                     batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
+                     optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))
+
+                 optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
+                 optical_flows_enc2dec += optical_flow_e2d
+             i += 1
+
+         optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
+
+         # split the crops back up
+         crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
+
+         T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
+               crop_flows_enc2dec]
+
+         optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
+             B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
+
+         all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
+
+     all_fwd_flows_e2d_new = []
+     for ridx, r in enumerate(all_fwd_flows_e2d):
+         _sh = s_hs[ridx]
+         _sw = s_ws[ridx]
+         _sfy = generator.patch_size[-1]
+         _sfx = generator.patch_size[-1]
+
+         new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])],
+                               mode='bicubic', align_corners=True)
+
+         scaled_new_r = torch.zeros_like(new_r)
+         scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
+         scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
+
+         all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
+
+     return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
+
+     if neg_back_flow is True:
+         return_flow = -return_flow
+         all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
+
+     return return_flow, all_fwd_flows_e2d_new
cwm/eval/Flow/generator.py ADDED
@@ -0,0 +1,579 @@
+ import kornia
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch import nn
+
+ import cwm.eval.Flow.masking_flow as masking
+
+
+ def boltzmann(x, beta=1, eps=1e-9):
+     if beta is None:
+         return x
+     x = torch.exp(x * beta)
+     return x / x.amax((-1, -2), keepdim=True).clamp(min=eps)
+
+
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+ def imagenet_normalize(x, temporal_dim=1):
+     mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x.device)[None, None, :, None, None].to(x)
+     std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x.device)[None, None, :, None, None].to(x)
+     if temporal_dim == 2:
+         mean = mean.transpose(1, 2)
+         std = std.transpose(1, 2)
+     return (x - mean) / std
+
+
+ def imagenet_unnormalize(x, temporal_dim=2):
+     device = x.device
+     mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, None, :, None, None].to(x)
+     std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, None, :, None, None].to(x)
+     if temporal_dim == 2:
+         mean = mean.transpose(1, 2)
+         std = std.transpose(1, 2)
+     x = x * std + mean
+     return x
+
+
41
+ static = False
42
+ if seq_length == 0:
43
+ static = True
44
+ seq_length = 1
45
+ B = batch_size
46
+ T = seq_length
47
+ H,W = imsize
48
+ ones = torch.ones([B,H,W,1], dtype=dtype_out)
49
+ if normalize:
50
+ h = torch.divide(torch.arange(H).to(ones), torch.tensor(H-1, dtype=dtype_out))
51
+ h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
52
+ w = torch.divide(torch.arange(W).to(ones), torch.tensor(W-1, dtype=dtype_out))
53
+ w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
54
+ else:
55
+ h = torch.arange(H).to(ones).view(1,H,1,1) * ones
56
+ w = torch.arange(W).to(ones).view(1,1,W,1) * ones
57
+ h = torch.stack([h]*T, 1)
58
+ w = torch.stack([w]*T, 1)
59
+ hw_ims = torch.cat([h,w], -1)
60
+ if static:
61
+ hw_ims = hw_ims[:,0]
62
+ return hw_ims
63
+
64
+
65
+ def get_distribution_centroid(dist, eps=1e-9, normalize=False):
66
+
67
+ B,T,C,H,W = dist.shape
68
+ assert C == 1
69
+ dist_sum = dist.sum((-2, -1), keepdim=True).clamp(min=eps)
70
+ dist = dist / dist_sum
71
+
72
+ grid = coordinate_ims(B, T, [H,W], normalize=normalize).to(dist.device)
73
+ grid = grid.permute(0,1,4,2,3)
74
+ centroid = (grid * dist).sum((-2,-1))
75
+ return centroid
76
+
77
+
78
+
79
+ class FlowToRgb(object):
80
+
81
+ def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
82
+ self.max_speed = max_speed
83
+ self.from_image_coordinates = from_image_coordinates
84
+ self.from_sampling_grid = from_sampling_grid
85
+
86
+ def __call__(self, flow):
87
+ assert flow.size(-3) == 2, flow.shape
88
+ if self.from_sampling_grid:
89
+ flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
90
+ flow_y = -flow_y
91
+ elif not self.from_image_coordinates:
92
+ flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
93
+ else:
94
+ flow_h, flow_w = torch.split(flow, [1,1], dim=-3)
95
+ flow_x, flow_y = [flow_w, -flow_h]
96
+
97
+ angle = torch.atan2(flow_y, flow_x) # in radians from -pi to pi
98
+ speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed
99
+
100
+ hue = torch.fmod(angle, torch.tensor(2 * np.pi))
101
+ sat = torch.ones_like(hue)
102
+ val = speed
103
+
104
+ hsv = torch.cat([hue, sat, val], -3)
105
+ rgb = kornia.color.hsv_to_rgb(hsv)
106
+ return rgb
107
+
108
+ class Patchify(nn.Module):
109
+ """Convert a set of images or a movie into patch vectors"""
110
+
111
+ def __init__(self,
112
+ patch_size=(16, 16),
113
+ temporal_dim=1,
114
+ squeeze_channel_dim=True
115
+ ):
116
+ super().__init__()
117
+ self.set_patch_size(patch_size)
118
+ self.temporal_dim = temporal_dim
119
+ assert self.temporal_dim in [1, 2], self.temporal_dim
120
+ self._squeeze_channel_dim = squeeze_channel_dim
121
+
122
+ @property
123
+ def num_patches(self):
124
+ if (self.T is None) or (self.H is None) or (self.W is None):
125
+ return None
126
+ else:
127
+ return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)
128
+
129
+ def set_patch_size(self, patch_size):
130
+ self.patch_size = patch_size
131
+ if len(self.patch_size) == 2:
132
+ self.ph, self.pw = self.patch_size
133
+ self.pt = 1
134
+ self._patches_are_3d = False
135
+ elif len(self.patch_size) == 3:
136
+ self.pt, self.ph, self.pw = self.patch_size
137
+ self._patches_are_3d = True
138
+ else:
139
+ raise ValueError("patch_size must be a 2- or 3-tuple, but is %s" % self.patch_size)
140
+
141
+ self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
142
+ self.D = self.C = self.E = self.embed_dim = None
143
+
144
+ def _check_shape(self, x):
145
+ self.shape_inp = x.shape
146
+ self.rank_inp = len(self.shape_inp)
147
+ self.H, self.W = self.shape_inp[-2:]
148
+ assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
149
+ if (self.rank_inp == 5) and self._patches_are_3d:
150
+ self.T = self.shape_inp[self.temporal_dim]
151
+ assert (self.T % self.pt) == 0, (self.T, self.pt)
152
+ elif self.rank_inp == 5:
153
+ self.T = self.shape_inp[self.temporal_dim]
154
+ else:
155
+ self.T = 1
156
+
157
+ def split_by_time(self, x):
158
+ shape = x.shape
159
+ assert shape[1] % self.T == 0, (shape, self.T)
160
+ return x.view(shape[0], self.T, shape[1] // self.T, *shape[2:])
161
+
162
+ def merge_by_time(self, x):
163
+ shape = x.shape
164
+ return x.view(shape[0], shape[1] * shape[2], *shape[3:])
165
+
166
+ def video_to_patches(self, x):
167
+ if self.rank_inp == 4:
168
+ assert self.pt == 1, (self.pt, x.shape)
169
+ x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
170
+ else:
171
+ assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
172
+ dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
173
+ x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)
174
+
175
+ self.N, self.D, self.C = x.shape[-3:]
176
+ self.embed_dim = self.E = self.D * self.C
177
+ return x
178
+
179
+ def patches_to_video(self, x):
180
+ shape = x.shape
181
+ rank = len(shape)
182
+ if rank == 4:
183
+ B, _N, _D, _C = shape
184
+ else:
185
+ assert rank == 3, rank
186
+ B, _N, _E = shape
187
+ assert (_E % self.D == 0), (_E, self.D)
188
+ x = x.view(B, _N, self.D, -1)
189
+
190
+ if _N < self.num_patches:
191
+ masked_patches = self.get_masked_patches(
192
+ x,
193
+ num_patches=(self.num_patches - _N),
194
+ mask_mode=self.mask_mode)
195
+ x = torch.cat([x, masked_patches], 1)
196
+
197
+ x = rearrange(
198
+ x,
199
+ 'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
200
+ pt=self.pt, ph=self.ph, pw=self.pw,
201
+ t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))
202
+
203
+ if self.rank_inp == 5 and (self.temporal_dim == 1):
204
+ x = x.transpose(1, 2)
205
+ elif self.rank_inp == 4:
206
+ assert x.shape[2] == 1, x.shape
207
+ x = x[:, :, 0]
208
+ return x
209
+
210
+ @staticmethod
211
+ def get_masked_patches(x, num_patches, mask_mode='zeros'):
212
+ shape = x.shape
213
+ patches_shape = (shape[0], num_patches, *shape[2:])
214
+ if mask_mode == 'zeros':
215
+ return torch.zeros(patches_shape).to(x.device).to(x.dtype).detach()
216
+ elif mask_mode == 'gray':
217
+ return 0.5 * torch.ones(patches_shape).to(x.device).to(x.dtype).detach()
218
+ else:
219
+ raise NotImplementedError("Haven't implemented mask_mode == %s" % mask_mode)
220
+
221
+ def average_within_patches(self, z):
222
+ if len(z.shape) == 3:
223
+ z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
224
+ return z.mean(-2, True).repeat(1, 1, z.shape[-2], 1)
225
+
226
+ def forward(self, x, to_video=False, mask_mode='zeros'):
227
+ if not to_video:
228
+ self._check_shape(x)
229
+ x = self.video_to_patches(x)
230
+ return x if not self._squeeze_channel_dim else x.view(x.size(0), self.N, -1)
231
+
232
+ else: # x are patches
233
+ assert (self.shape_inp is not None) and (self.num_patches is not None)
234
+ self.mask_mode = mask_mode
235
+ x = self.patches_to_video(x)
236
+ return x
237
+
238
+
239
+ class DerivativeFlowGenerator(nn.Module):
240
+ """Estimate flow of a two-frame predictor using torch autograd"""
241
+
242
+ def __init__(self,
243
+ predictor,
244
+ perturbation_patch_size=None,
245
+ aggregation_patch_size=None,
246
+ agg_power=None,
247
+ agg_channel_func=None,
248
+ num_samples=1,
249
+ leave_one_out_sampling=False,
250
+ average_jacobian=True,
251
+ confidence_thresh=None,
252
+ temporal_dim=2,
253
+ imagenet_normalize_inputs=True):
254
+
255
+ super(DerivativeFlowGenerator, self).__init__()
256
+
257
+ self.predictor = predictor
258
+
259
+ self.patchify = Patchify(self.patch_size, temporal_dim=1, squeeze_channel_dim=True)
260
+
261
+ self.set_temporal_dim(temporal_dim)
262
+
263
+ self.imagenet_normalize_inputs = imagenet_normalize_inputs
264
+
265
+ self.perturbation_patch_size = self._get_patch_size(perturbation_patch_size) or self.patch_size
266
+ self.aggregation_patch_size = self._get_patch_size(aggregation_patch_size) or self.patch_size
267
+ self.agg_patchify = Patchify(self.aggregation_patch_size,
268
+ temporal_dim=1,
269
+ squeeze_channel_dim=False)
270
+ self.agg_channel_func = agg_channel_func or (lambda x: F.relu(x).sum(-3, True))
271
+ self.average_jacobian = average_jacobian
272
+ self.confidence_thresh = confidence_thresh
273
+
274
+ self.num_samples = num_samples
275
+ self.leave_one_out_sampling = leave_one_out_sampling
276
+ self.agg_power = agg_power
277
+ self.t_dim = temporal_dim
278
+
279
+ def _get_patch_size(self, p):
280
+ if p is None:
281
+ return None
282
+ elif isinstance(p, int):
283
+ return (1, p, p)
284
+ elif len(p) == 2:
285
+ return (1, p[0], p[1])
286
+ else:
287
+ assert len(p) == 3, p
288
+ return (p[0], p[1], p[2])
289
+
290
+ def set_temporal_dim(self, t_dim):
291
+ if t_dim == 1:
292
+ self.predictor.t_dim = 1
293
+ self.predictor.c_dim = 2
294
+ elif t_dim == 2:
295
+ self.predictor.c_dim = 1
296
+ self.predictor.t_dim = 2
297
+ else:
298
+ raise ValueError("temporal_dim must be 1 or 2")
299
+
300
+ @property
301
+ def c_dim(self):
302
+ if self.predictor is None:
303
+ return None
304
+ return self.predictor.c_dim
305
+
306
+ @property
307
+ def patch_size(self):
308
+ if self.predictor is None:
309
+ return None
310
+ elif hasattr(self.predictor, 'patch_size'):
311
+ return self.predictor.patch_size
312
+ elif hasattr(self.predictor.encoder.patch_embed, 'proj'):
313
+ return self.predictor.encoder.patch_embed.proj.kernel_size
314
+ else:
315
+ return None
316
+ @property
317
+ def S(self):
318
+ return self.num_samples
319
+
320
+ @property
321
+ def sequence_length(self):
322
+ if self.predictor is None:
323
+ return None
324
+ elif hasattr(self.predictor, 'sequence_length'):
325
+ return self.predictor.sequence_length
326
+ elif hasattr(self.predictor, 'num_frames'):
327
+ return self.predictor.num_frames
328
+ else:
329
+ return 2
330
+ @property
331
+ def mask_shape(self):
332
+ if self.predictor is None:
333
+ return None
334
+ elif hasattr(self.predictor, 'mask_shape'):
335
+ return self.predictor.mask_shape
336
+
337
+ assert self.patch_size is not None
338
+ pt, ph, pw = self.patch_size
339
+ return (self.sequence_length // pt,
340
+ self.inp_shape[-2] // ph,
341
+ self.inp_shape[-1] // pw)
342
+
343
+ @property
344
+ def perturbation_mask_shape(self):
345
+ return (
346
+ self.mask_shape[0],
347
+ self.inp_shape[-2] // self.perturbation_patch_size[-2],
348
+ self.inp_shape[-1] // self.perturbation_patch_size[-1]
349
+ )
350
+
351
+
352
+
353
+ @property
354
+ def p_mask_shape(self):
355
+ return self.perturbation_mask_shape
356
+
357
+ @property
358
+ def aggregation_mask_shape(self):
359
+ return (
360
+ 1,
361
+ self.inp_shape[-2] // self.aggregation_patch_size[-2],
362
+ self.inp_shape[-1] // self.aggregation_patch_size[-1]
363
+ )
364
+
365
+ @property
366
+ def a_mask_shape(self):
367
+ return self.aggregation_mask_shape
368
+
369
+ def get_perturbation_input(self, x):
370
+ self.set_input(x)
371
+ y = torch.zeros((self.B, *self.p_mask_shape), dtype=x.dtype, device=x.device, requires_grad=True)
372
+ y = y.unsqueeze(2).repeat(1, 1, x.shape[2], 1, 1)
373
+ return y
374
+
375
+ def pred_patches_to_video(self, y, x, mask):
376
+ """input at visible positions, preds at masked positions"""
377
+ B, C = y.shape[0], y.shape[-1]
378
+ self.patchify._check_shape(x)
379
+ self.patchify.D = np.prod(self.patch_size)
380
+ x = self.patchify(x)
381
+ y_out = torch.zeros_like(x)
382
+ x_vis = x[~mask]
383
+
384
+ y_out[~mask] = x_vis.view(-1, C)
385
+ try:
386
+ y_out[mask] = y.view(-1, C)
387
+ except:
388
+ y_out[mask] = y.reshape(-1, C)
389
+
390
+ return self.patchify(y_out, to_video=True)
391
+
392
+ def set_image_size(self, *args, **kwargs):
393
+ assert self.predictor is not None, "Can't set the image size without a predictor"
394
+ if hasattr(self.predictor, 'set_image_size'):
395
+ self.predictor.set_image_size(*args, **kwargs)
396
+ else:
397
+ self.predictor.image_size = args[0]
398
+
399
+ def predict(self, x=None, mask=None, forward_full=False):
400
+ if x is None:
401
+ x = self.x
402
+ if mask is None:
403
+ mask = self.generate_mask(x)
404
+
405
+ self.set_image_size(x.shape[-2:])
406
+ y = self.predictor(
407
+ self._preprocess(x),
408
+ mask if (x.size(0) == 1) else self.mask_rectangularizer(mask), forward_full=forward_full)
409
+
410
+ y = self.pred_patches_to_video(y, x, mask=mask)
411
+
412
+ frame = -1 % y.size(1)
413
+ y = y[:, frame:frame + 1]
414
+
415
+ return y
416
+
417
+ def _get_perturbation_func(self, x=None, mask=None):
418
+
419
+ if (x is not None):
420
+ self.set_input(x, mask)
421
+
422
+ def forward_mini_image(y):
423
+ y = y.repeat_interleave(self.perturbation_patch_size[-2], -2)
424
+ y = y.repeat_interleave(self.perturbation_patch_size[-1], -1)
425
+ x_pred = self.predict(self.x + y, self.mask)
426
+ x_pred = self.agg_patchify(x_pred).mean(-2).sum(-1).view(self.B, *self.a_mask_shape)
427
+ return x_pred[self.targets]
428
+
429
+ return forward_mini_image
430
+
431
+ def _postprocess_jacobian(self, jac):
432
+ _jac = torch.zeros((self.B, *self.a_mask_shape, *jac.shape[1:])).to(jac.device).to(jac.dtype)
433
+ _jac[self.targets] = jac
434
+ jac = self.agg_channel_func(_jac)
435
+ assert jac.size(-3) == 1, jac.shape
436
+ jac = jac.squeeze(-3)[..., 0, :, :] # derivative w.r.t. first frame and agg channels
437
+ jac = jac.view(self.B, self.a_mask_shape[-2], self.a_mask_shape[-1],
438
+ self.B, self.p_mask_shape[-2], self.p_mask_shape[-1])
439
+ bs = torch.arange(0, self.B).long().to(jac.device)
440
+ jac = jac[bs, :, :, bs, :, :] # take diagonal
441
+ return jac
442
+
443
+ def _confident_jacobian(self, jac):
444
+ if self.confidence_thresh is None:
445
+ return torch.ones_like(jac[:, None, ..., 0, 0])
446
+ conf = (jac.amax((-2, -1)) > self.confidence_thresh).float()[:, None]
447
+ return conf
448
+
449
+ def set_input(self, x, mask=None, timestamps=None):
450
+ shape = x.shape
451
+ if len(shape) == 4:
452
+ x = x.unsqueeze(1)
453
+ else:
454
+ assert len(shape) == 5, \
455
+ "Input must be a movie of shape [B,T,C,H,W]" + \
456
+ "or a single frame of shape [B,C,H,W]"
457
+
458
+ self.inp_shape = x.shape
459
+ self.x = x
460
+ self.B = self.inp_shape[0]
461
+ self.T = self.inp_shape[1]
462
+ self.C = self.inp_shape[2]
463
+ if mask is not None:
464
+ self.mask = mask
465
+
466
+ if timestamps is not None:
467
+ self.timestamps = timestamps
468
+
469
+ def _preprocess(self, x):
470
+ if self.imagenet_normalize_inputs:
471
+ x = imagenet_normalize(x)
472
+ if self.t_dim != 1:
473
+ x = x.transpose(self.t_dim, self.c_dim)
474
+ return x
475
+
476
+ def _jacobian_to_flows(self, jac):
477
+ if self.agg_power is None:
478
+ jac = (jac == jac.amax((-2, -1), True)).float()
479
+ else:
480
+ jac = torch.pow(jac, self.agg_power)
481
+
482
+ jac = jac.view(self.B * np.prod(self.a_mask_shape[-2:]), 1, 1, *self.p_mask_shape[-2:])
483
+ centroids = get_distribution_centroid(jac, normalize=False).view(
484
+ self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], 2)
485
+ rescale = [self.a_mask_shape[-2] / self.p_mask_shape[-2],
486
+ self.a_mask_shape[-1] / self.p_mask_shape[-1]]
487
+ centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2)
488
+
489
+ flows = centroids - \
490
+ coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device)
491
+ flows = flows.permute(0, 3, 1, 2)
492
+ px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1)
493
+ flows *= px_scale
494
+
495
+ return flows
496
+
497
+ def set_targets(self, targets=None, frame=-1):
498
+ frame = frame % self.mask_shape[0]
499
+ if targets is None:
500
+ targets = self.get_mask_image(self.mask)[:, frame:frame + 1]
501
+ else:
502
+ assert len(targets.shape) == 4, targets.shape
503
+ targets = targets[:, frame:frame + 1]
504
+ self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:])
505
+
506
+ def _get_mask_partition(self, mask):
507
+ mask = self.get_mask_image(mask)
508
+ mask_list = masking.partition_masks(
509
+ mask[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling)
510
+ return [torch.cat([mask[:, 0:1].view(m.size(0), -1), m], -1)
511
+ for m in mask_list]
512
+
513
+ def _compute_jacobian(self, y):
514
+ perturbation_func = self._get_perturbation_func()
515
+ jac = torch.autograd.functional.jacobian(
516
+ perturbation_func,
517
+ y,
518
+ vectorize=False)
519
+ jac = self._postprocess_jacobian(jac)
520
+ return jac
521
+
522
+ def _upsample_mask(self, mask):
523
+ return masking.upsample_masks(
524
+ mask.view(mask.size(0), -1, *self.mask_shape[-2:]).float(), self.inp_shape[-2:])
525
+
526
+ def get_mask_image(self, mask, upsample=False, invert=False, shape=None):
527
+ if shape is None:
528
+ shape = self.mask_shape
529
+ mask = mask.view(-1, *shape)
530
+ if upsample:
531
+ mask = self._upsample_mask(mask)
532
+ if invert:
533
+ mask = 1 - mask
534
+ return mask
535
+
536
+ def forward(self, x, mask, targets=None):
537
+ self.set_input(x, mask)
538
+ y = self.get_perturbation_input(x)
539
+ mask_list = self._get_mask_partition(mask)
540
+
541
+ jacobian, flows, confident = [], [], []
542
+ for s, mask_sample in enumerate(mask_list):
543
+ self.set_input(x, mask_sample)
544
+ self.set_targets(targets)
545
+
546
+ import time
547
+ t1 = time.time()
548
+ jac = self._compute_jacobian(y)
549
+ conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:])
550
+ jacobian.append(jac)
551
+ confident.append(conf_jac)
552
+ if not self.average_jacobian:
553
+ flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \
554
+ masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:])
555
+ flows.append(flow)
556
+ t2 = time.time()
557
+ print(t2 - t1)
558
+
559
+ jacobian = torch.stack(jacobian, -1)
560
+ confident = torch.stack(confident, -1)
561
+ valid = torch.stack([masking.upsample_masks(
562
+ self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1)
563
+ valid = valid * confident
564
+
565
+ if self.average_jacobian:
566
+ _valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2)
567
+ jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1)
568
+ flows = self._jacobian_to_flows(jac) * \
569
+ masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:])
570
+ if targets is not None:
571
+ self.set_targets(targets)
572
+ flows *= self.targets
573
+ else:
574
+ flows = torch.stack(flows, -1)
575
+ flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1)
576
+
577
+ valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1)
578
+
579
+ return (jacobian, flows, valid)
cwm/eval/Flow/losses.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+
+
+ def sampling_grid(height, width):
+     H, W = height, width
+     grid = torch.stack([
+         torch.arange(W).view(1, -1).repeat(H, 1),
+         torch.arange(H).view(-1, 1).repeat(1, W)
+     ], -1)
+     grid = grid.view(1, H, W, 2)
+     return grid
+
+
+ def normalize_sampling_grid(coords):
+     assert len(coords.shape) == 4, coords.shape
+     assert coords.size(-1) == 2, coords.shape
+     H, W = coords.shape[-3:-1]
+     xs, ys = coords.split([1, 1], -1)
+     xs = 2 * xs / (W - 1) - 1
+     ys = 2 * ys / (H - 1) - 1
+     return torch.cat([xs, ys], -1)
+
+
+ def backward_warp(img2, flow, do_mask=False):
+     """
+     Grid sample from img2 using the flow from img1->img2 to get a prediction of img1.
+
+     flow: [B,2,H',W'] in units of pixels at its current resolution. The two channels
+     should be (x,y) where larger y values correspond to lower parts of the image.
+     """
+
+     ## resize the flow to the image size.
+     ## since flow has units of pixels, its values need to be rescaled accordingly.
+     if list(img2.shape[-2:]) != list(flow.shape[-2:]):
+         scale = [img2.size(-1) / flow.size(-1),  # x
+                  img2.size(-2) / flow.size(-2)]  # y
+         scale = torch.tensor(scale).view(1, 2, 1, 1).to(flow.device)
+         flow = scale * transforms.Resize(img2.shape[-2:])(flow)  # defaults to bilinear
+
+     B, C, H, W = img2.shape
+
+     ## use flow to warp sampling grid
+     grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1)
+
+     ## put grid in normalized image coordinates
+     grid = normalize_sampling_grid(grid)
+
+     ## backward warp, i.e. sample pixel (x,y) from (x+flow_x, y+flow_y)
+     img1_pred = F.grid_sample(img2, grid, align_corners=True)
+
+     if do_mask:
+         mask = (grid[..., 0] > -1) & (grid[..., 0] < 1) & \
+                (grid[..., 1] > -1) & (grid[..., 1] < 1)
+         mask = mask[:, None].to(img2.dtype)
+         return (img1_pred, mask)
+
+     else:
+         return (img1_pred, torch.ones_like(grid[..., 0][:, None]).float())
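(Illustrative check of backward_warp, relying on this module's imports: zero flow is the identity warp, so the warped image should equal the input.)

img2 = torch.rand(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)
img1_pred, _ = backward_warp(img2, flow)
print(torch.allclose(img1_pred, img2, atol=1e-5))  # True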
cwm/eval/Flow/masking_flow.py ADDED
@@ -0,0 +1,375 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+
+
+ def upsample_masks(masks, size, thresh=0.5):
+     shape = masks.shape
+     dtype = masks.dtype
+     h, w = shape[-2:]
+     H, W = size
+     if (H == h) and (W == w):
+         return masks
+     elif (H < h) and (W < w):
+         s = (h // H, w // W)
+         return masks[..., ::s[0], ::s[1]]
+
+     masks = masks.unsqueeze(-2).unsqueeze(-1)
+     masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
+     if ((H % h) == 0) and ((W % w) == 0):
+         masks = masks.view(*shape[:-2], H, W)
+     else:
+         _H = np.prod(masks.shape[-4:-2])
+         _W = np.prod(masks.shape[-2:])
+         masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh
+         masks = masks.view(*shape[:2], H, W).to(dtype)  # restore the input dtype
+     return masks
+
+
+ def partition_masks(masks, num_samples=2, leave_one_out=False):
+     B = masks.shape[0]
+     S = num_samples
+     masks = masks.view(B, -1)
+     partitioned = [torch.ones_like(masks) for _ in range(S)]
+     for b in range(B):
+         vis_inds = torch.where(~masks[b])[0]
+         vis_inds = vis_inds[torch.randperm(vis_inds.size(0))]
+         if leave_one_out:
+             for s in range(S):
+                 partitioned[s][b][vis_inds] = 0
+                 partitioned[s][b][vis_inds[s::S]] = 1
+         else:
+             for s in range(S):
+                 partitioned[s][b][vis_inds[s::S]] = 0
+     return partitioned
+
+
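(Illustrative check of partition_masks, relying on this module's imports: the visible (False) positions of a mask are split disjointly across the returned partitions.)

masks = torch.rand(1, 196) > 0.9   # True = masked; roughly 10% masked
parts = partition_masks(masks, num_samples=2)
n_vis = (~masks).sum()
print((~parts[0]).sum() + (~parts[1]).sum() == n_vis)  # True: each visible token appears in exactly one partition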
+ class RectangularizeMasks(nn.Module):
+     """Make sure all masks in a batch have same number of 1s and 0s"""
+
+     def __init__(self, truncation_mode='min'):
+         super().__init__()
+         self._mode = truncation_mode
+         assert self._mode in ['min', 'max', 'mean', 'full', 'none', None], (self._mode)
+
+     def set_mode(self, mode):
+         self._mode = mode
+
+     def __call__(self, masks):
+
+         if self._mode in ['none', None]:
+             return masks
+
+         assert isinstance(masks, torch.Tensor), type(masks)
+         if self._mode == 'full':
+             return torch.ones_like(masks)
+
+         shape = masks.shape
+         masks = masks.flatten(1)
+         B, N = masks.shape
+         num_masked = masks.float().sum(-1)
+         M = {
+             'min': torch.amin, 'max': torch.amax, 'mean': torch.mean
+         }[self._mode](num_masked).long()
+
+         num_changes = num_masked.long() - M
+
+         for b in range(B):
+             nc = num_changes[b]
+             if nc > 0:
+                 inds = torch.where(masks[b])[0]
+                 inds = inds[torch.randperm(inds.size(0))[:nc].to(inds.device)]
+                 masks[b, inds] = 0
+             elif nc < 0:
+                 inds = torch.where(~masks[b])[0]
+                 inds = inds[torch.randperm(inds.size(0))[:-nc].to(inds.device)]
+                 masks[b, inds] = 1
+         if list(masks.shape) != list(shape):
+             masks = masks.view(*shape)
+
+         return masks
+
+
+ class UniformMaskingGenerator(object):
+     def __init__(self, input_size, mask_ratio, seed=None, clumping_factor=1, randomize_num_visible=False):
+         self.frames = None
+         if len(input_size) == 3:
+             self.frames, self.height, self.width = input_size
+         elif len(input_size) == 2:
+             self.height, self.width = input_size
+         elif len(input_size) == 1 or isinstance(input_size, int):
+             self.height = self.width = input_size
+
+         self.clumping_factor = clumping_factor
+         self.pad_h = self.height % self.c[0]
+         self.pad_w = self.width % self.c[1]
+         self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
+         self.mask_ratio = mask_ratio
+
+         self.rng = np.random.RandomState(seed=seed)
+         self.randomize_num_visible = randomize_num_visible
+
+     @property
+     def num_masks_per_frame(self):
+         if not hasattr(self, '_num_masks_per_frame'):
+             self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
+         return self._num_masks_per_frame
+
+     @num_masks_per_frame.setter
+     def num_masks_per_frame(self, val):
+         self._num_masks_per_frame = val
+         self._mask_ratio = (val / self.num_patches_per_frame)
+
+     @property
+     def c(self):
+         if isinstance(self.clumping_factor, int):
+             return (self.clumping_factor, self.clumping_factor)
+         else:
+             return self.clumping_factor[:2]
+
+     @property
+     def mask_ratio(self):
+         return self._mask_ratio
+
+     @mask_ratio.setter
+     def mask_ratio(self, val):
+         self._mask_ratio = val
+         self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)
+
+     @property
+     def num_visible(self):
+         return self.num_patches_per_frame - self.num_masks_per_frame
+
+     @num_visible.setter
+     def num_visible(self, val):
+         self.num_masks_per_frame = self.num_patches_per_frame - val
+
+     def __repr__(self):
+         repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, randomize num visible? {}".format(
+             self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible
+         )
+         return repr_str
+
+     def sample_mask_per_frame(self):
+         num_masks = self.num_masks_per_frame
+         if self.randomize_num_visible:
+             num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
+         mask = np.hstack([
+             np.zeros(self.num_patches_per_frame - num_masks),
+             np.ones(num_masks)])
+         self.rng.shuffle(mask)
+         if max(*self.c) > 1:
+             mask = mask.reshape(self.height // self.c[0],
+                                 1,
+                                 self.width // self.c[1],
+                                 1)
+             mask = np.tile(mask, (1, self.c[0], 1, self.c[1]))
+             mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w))
+             _pad_h = self.rng.choice(range(self.pad_h + 1))
+             pad_h = (self.pad_h - _pad_h, _pad_h)
+             _pad_w = self.rng.choice(range(self.pad_w + 1))
+             pad_w = (self.pad_w - _pad_w, _pad_w)
+             mask = np.pad(mask,
+                           (pad_h, pad_w),
+                           constant_values=1
+                           ).reshape((self.height, self.width))
+         return mask
+
+     def __call__(self, num_frames=None):
+         num_frames = (num_frames or self.frames) or 1
+         masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten()
+         return masks
+
+
+ class TubeMaskingGenerator(UniformMaskingGenerator):
+
+     def __call__(self, num_frames=None):
+         num_frames = (num_frames or self.frames) or 1
+         masks = np.tile(self.sample_mask_per_frame(), (num_frames, 1)).flatten()
+         return masks
+
+
+ class RotatedTableMaskingGenerator(TubeMaskingGenerator):
+
+     def __init__(self, tube_length=None, *args, **kwargs):
+         super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs)
+         self.tube_length = tube_length
+
+     def __call__(self, num_frames=None):
+         num_frames = (num_frames or self.frames) or 2
+         tube_length = self.tube_length or (num_frames - 1)
+         table_thickness = num_frames - tube_length
+         assert tube_length < num_frames, (tube_length, num_frames)
+
+         tubes = super().__call__(num_frames=tube_length)
+         top = np.zeros(table_thickness * self.height * self.width).astype(tubes.dtype).flatten()
+         masks = np.concatenate([top, tubes], 0)
+         return masks
+
+
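(Illustrative sketch of the "rotated table" masking above, relying on this module's imports: early frames stay fully visible and one shared tube mask covers the last `tube_length` frames. The 14x14 grid here corresponds to a 224px frame with 16px patches.)

gen = RotatedTableMaskingGenerator(tube_length=1, input_size=(2, 14, 14), mask_ratio=0.9)
mask = gen()
print(mask.shape, mask[:196].sum(), mask[196:].sum())  # (392,) 0.0 176.0: frame 1 visible, frame 2 ~90% masked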
+ class PytorchMaskGeneratorWrapper(nn.Module):
+     """Pytorch wrapper for numpy masking generators"""
+
+     def __init__(self,
+                  mask_generator=TubeMaskingGenerator,
+                  *args, **kwargs):
+         super().__init__()
+         self.mask_generator = mask_generator(*args, **kwargs)
+
+     @property
+     def mask_ratio(self):
+         return self.mask_generator.mask_ratio
+
+     @mask_ratio.setter
+     def mask_ratio(self, value):
+         self.mask_generator.mask_ratio = value
+
+     def forward(self, device='cuda', dtype_out=torch.bool, **kwargs):
+         masks = self.mask_generator(**kwargs)
+         masks = torch.tensor(masks).to(device).to(dtype_out)
+         return masks
+
+
+ class MaskingGenerator(nn.Module):
237
+ """Pytorch base class for masking generators"""
238
+
239
+ def __init__(self,
240
+ input_size,
241
+ mask_ratio,
242
+ seed=0,
243
+ visible_frames=0,
244
+ clumping_factor=1,
245
+ randomize_num_visible=False,
246
+ create_on_cpu=True,
247
+ always_batch=False):
248
+ super().__init__()
249
+ self.frames = None
250
+
251
+ if len(input_size) == 3:
252
+ self.frames, self.height, self.width = input_size
253
+ elif len(input_size) == 2:
254
+ self.height, self.width = input_size
255
+ elif len(input_size) == 1 or isinstance(input_size, int):
256
+ self.height = self.width = input_size
257
+
258
+ self.clumping_factor = clumping_factor
259
+ self.pad_h = self.height % self.c[0]
260
+ self.pad_w = self.width % self.c[1]
261
+ self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
262
+
263
+ self.mask_ratio = mask_ratio
264
+ self.visible_frames = visible_frames
265
+ self.always_batch = always_batch
266
+ self.create_on_cpu = create_on_cpu
267
+
268
+ self.rng = np.random.RandomState(seed=seed)
269
+ self._set_torch_seed(seed)
270
+
271
+ self.randomize_num_visible = randomize_num_visible
272
+
273
+ @property
274
+ def num_masks_per_frame(self):
275
+ if not hasattr(self, '_num_masks_per_frame'):
276
+ self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
277
+ return self._num_masks_per_frame
278
+
279
+ @num_masks_per_frame.setter
280
+ def num_masks_per_frame(self, val):
281
+ self._num_masks_per_frame = val
282
+ self._mask_ratio = (val / self.num_patches_per_frame)
283
+
284
+ @property
285
+ def c(self):
286
+ if isinstance(self.clumping_factor, int):
287
+ return (self.clumping_factor,) * 2
288
+ else:
289
+ return self.clumping_factor[:2]
290
+
291
+ @property
292
+ def mask_ratio(self):
293
+ return self._mask_ratio
294
+
295
+ @mask_ratio.setter
296
+ def mask_ratio(self, val):
297
+ self._mask_ratio = val
298
+ self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)
299
+
300
+ @property
301
+ def num_visible(self):
302
+ return self.num_patches_per_frame - self.num_masks_per_frame
303
+
304
+ @num_visible.setter
305
+ def num_visible(self, val):
306
+ self.num_masks_per_frame = self.num_patches_per_frame - val
307
+
308
+ def _set_torch_seed(self, seed):
309
+ self.seed = seed
310
+ torch.manual_seed(self.seed)
311
+
312
+ def __repr__(self):
313
+ repr_str = ("Class: {}\nMask: total patches per mask {},\n" + \
314
+ "mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n" + \
315
+ "randomize num visible? {}").format(
316
+ type(self).__name__, self.num_patches_per_frame,
317
+ self.num_masks_per_frame, self.num_visible, self.mask_ratio,
318
+ self.randomize_num_visible
319
+ )
320
+ return repr_str
321
+
322
+ def sample_mask_per_frame(self, *args, **kwargs):
323
+ num_masks = self.num_masks_per_frame
324
+ if self.randomize_num_visible:
325
+ num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
326
+
327
+ mask = torch.cat([
328
+ torch.zeros([self.num_patches_per_frame - num_masks]),
329
+ torch.ones([num_masks])], 0).bool()
330
+ inds = torch.randperm(mask.size(0)).long()
331
+ mask = mask[inds]
332
+
333
+ if max(*self.c) > 1:
334
+ mask = mask.view(self.height // self.c[0],
335
+ 1,
336
+ self.width // self.c[1],
337
+ 1)
338
+ mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
339
+ mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
340
+ _pad_h = self.rng.choice(range(self.pad_h + 1))
341
+ pad_h = (self.pad_h - _pad_h, _pad_h)
342
+ _pad_w = self.rng.choice(range(self.pad_w + 1))
343
+ pad_w = (self.pad_w - _pad_w, _pad_w)
344
+ mask = F.pad(mask,
345
+ pad_w + pad_h,
346
+ mode='constant',
347
+ value=1)
348
+ mask = mask.reshape(self.height, self.width)
349
+
350
+ return mask
351
+
352
+ def forward(self, x=None, num_frames=None):
353
+
354
+ num_frames = (num_frames or self.frames) or 1
355
+ if isinstance(x, torch.Tensor):
356
+ batch_size = x.size(0)
357
+ masks = torch.stack([
358
+ torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
359
+ for b in range(batch_size)], 0)
360
+ if not self.create_on_cpu:
361
+ masks = masks.to(x.device)
362
+ if batch_size == 1 and not self.always_batch:
363
+ masks = masks.squeeze(0)
364
+ else:
365
+ batch_size = 1
366
+ masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
367
+ if self.always_batch:
368
+ masks = masks[None]
369
+
370
+ if self.visible_frames > 0:
371
+ vis = torch.zeros((batch_size, 1, self.height, self.width), dtype=torch.bool)
372
+ vis = vis.view(masks.shape).to(masks.device)
373
+ masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)
374
+
375
+ return masks
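As a quick orientation for the generators above, here is a minimal usage sketch. It is illustrative only: it assumes these classes live in cwm/eval/Flow/masking_flow.py (as the surrounding diff suggests) and that the base generator takes input_size as (frames, height, width) in patch units plus a mask_ratio, matching the MaskingGenerator signature shown above.

    from cwm.eval.Flow.masking_flow import RotatedTableMaskingGenerator

    # 28x28 patch tokens per frame, 2 frames; the first (table) frame stays
    # fully visible, 90% of the remaining tube frame is masked
    gen = RotatedTableMaskingGenerator(tube_length=1, input_size=(2, 28, 28), mask_ratio=0.9)
    mask = gen(num_frames=2)          # flat array of length 2*28*28
    print(mask.shape, mask.mean())    # overall masked fraction ~0.45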
cwm/eval/Flow/vis_utils.py ADDED
@@ -0,0 +1,150 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def imshow(ims, ax=None, t=0, vmin=None, vmax=None, title=None, cmap=None, fontsize=20):
7
+ if ax is None:
8
+ fig, ax = plt.subplots(1,1)
9
+ with torch.no_grad():
10
+ im = ims[t].float().cpu().numpy().transpose((1,2,0))
11
+ if (vmin is not None) and (vmax is not None):
12
+ im = ax.imshow(im, vmin=vmin, vmax=vmax, cmap=(cmap or 'viridis'))
13
+ else:
14
+ im = ax.imshow(im)
15
+
16
+ if title is not None:
17
+ ax.set_title(title, fontsize=fontsize)
18
+
19
+ return (im, ax)
20
+
21
+ def make_colorwheel():
22
+ """
23
+ Generates a color wheel for optical flow visualization as presented in:
24
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26
+
27
+ Code follows the original C++ source code of Daniel Scharstein.
28
+ Code follows the Matlab source code of Deqing Sun.
29
+
30
+ Returns:
31
+ np.ndarray: Color wheel
32
+ """
33
+
34
+ RY = 15
35
+ YG = 6
36
+ GC = 4
37
+ CB = 11
38
+ BM = 13
39
+ MR = 6
40
+
41
+ ncols = RY + YG + GC + CB + BM + MR
42
+ colorwheel = np.zeros((ncols, 3))
43
+ col = 0
44
+
45
+ # RY
46
+ colorwheel[0:RY, 0] = 255
47
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
48
+ col = col+RY
49
+ # YG
50
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
51
+ colorwheel[col:col+YG, 1] = 255
52
+ col = col+YG
53
+ # GC
54
+ colorwheel[col:col+GC, 1] = 255
55
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
56
+ col = col+GC
57
+ # CB
58
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
59
+ colorwheel[col:col+CB, 2] = 255
60
+ col = col+CB
61
+ # BM
62
+ colorwheel[col:col+BM, 2] = 255
63
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
64
+ col = col+BM
65
+ # MR
66
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
67
+ colorwheel[col:col+MR, 0] = 255
68
+ return colorwheel
69
+
70
+
71
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
72
+ """
73
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
74
+
75
+ According to the C++ source code of Daniel Scharstein
76
+ According to the Matlab source code of Deqing Sun
77
+
78
+ Args:
79
+ u (np.ndarray): Input horizontal flow of shape [H,W]
80
+ v (np.ndarray): Input vertical flow of shape [H,W]
81
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82
+
83
+ Returns:
84
+ np.ndarray: Flow visualization image of shape [H,W,3]
85
+ """
86
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87
+ colorwheel = make_colorwheel() # shape [55x3]
88
+ ncols = colorwheel.shape[0]
89
+ rad = np.sqrt(np.square(u) + np.square(v))
90
+ a = np.arctan2(-v, -u)/np.pi
91
+ fk = (a+1) / 2*(ncols-1)
92
+ k0 = np.floor(fk).astype(np.int32)
93
+ k1 = k0 + 1
94
+ k1[k1 == ncols] = 0
95
+ f = fk - k0
96
+ for i in range(colorwheel.shape[1]):
97
+ tmp = colorwheel[:,i]
98
+ col0 = tmp[k0] / 255.0
99
+ col1 = tmp[k1] / 255.0
100
+ col = (1-f)*col0 + f*col1
101
+ idx = (rad <= 1)
102
+ col[idx] = 1 - rad[idx] * (1-col[idx])
103
+ col[~idx] = col[~idx] * 0.75 # out of range
104
+ # Note the 2-i => BGR instead of RGB
105
+ ch_idx = 2-i if convert_to_bgr else i
106
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
107
+ return flow_image
108
+
109
+
110
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
111
+ """
112
+ Expects a two-dimensional flow image of shape [H,W,2].
113
+
114
+ Args:
115
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118
+
119
+ Returns:
120
+ np.ndarray: Flow visualization image of shape [H,W,3]
121
+ """
122
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
123
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
124
+ if clip_flow is not None:
125
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
126
+ u = flow_uv[:,:,0]
127
+ v = flow_uv[:,:,1]
128
+ rad = np.sqrt(np.square(u) + np.square(v))
129
+ rad_max = np.max(rad)
130
+ epsilon = 1e-5
131
+ u = u / (rad_max + epsilon)
132
+ v = v / (rad_max + epsilon)
133
+ return flow_uv_to_colors(u, v, convert_to_bgr)
134
+
135
+ from decord import VideoReader, cpu
136
+ from PIL import Image
137
+ from torchvision import transforms
138
+ def get_video(video_name, num_frames=2, delta_time=4, frame=None):
139
+ decord_vr = VideoReader(video_name, num_threads=1, ctx=cpu(0))
140
+ max_end_ind = len(decord_vr) - num_frames*delta_time - 1
141
+ start_frame = frame if frame is not None else np.random.randint(1, max_end_ind)  # fixed: 'rng' was undefined here
142
+ print("fps", decord_vr.get_avg_fps())
143
+ print("start frame = %d" % start_frame)
144
+ frame_id_list = list(range(start_frame, start_frame + num_frames*delta_time, delta_time))
145
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
146
+ video_data = [Image.fromarray(video_data[t]).convert('RGB') for t, _ in enumerate(frame_id_list)]
147
+ return (torch.stack([transforms.ToTensor()(im) for im in video_data], 0), start_frame)
148
+
149
+
150
+
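Taken together, the helpers above map a raw (u, v) flow field to the standard Middlebury color coding. A minimal sketch with a synthetic flow field (illustrative only; the data is made up):

    import numpy as np
    from cwm.eval.Flow.vis_utils import flow_to_image

    H, W = 64, 64
    u = np.tile(np.linspace(0.0, 5.0, W), (H, 1))   # rightward motion growing with x
    v = np.zeros((H, W))
    flow_uv = np.stack([u, v], axis=-1)             # [H, W, 2]
    rgb = flow_to_image(flow_uv)                    # uint8 [H, W, 3]; hue encodes direction
    print(rgb.shape, rgb.dtype)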
cwm/eval/IntPhys/__init__.py ADDED
File without changes
cwm/eval/Physion/__init__.py ADDED
File without changes
cwm/eval/Physion/feature_extractor.py ADDED
@@ -0,0 +1,317 @@
1
+ import numpy as np
2
+ from physion_evaluator.feature_extract_interface import PhysionFeatureExtractor
3
+ from physion_evaluator.utils import DataAugmentationForVideoMAE
4
+
5
+ from torch.functional import F
6
+
7
+ from cwm.eval.Flow.flow_utils import get_occ_masks
8
+
9
+ from cwm.model.model_factory import model_factory
10
+ import torch
11
+
12
+ def load_predictor(
13
+ model_func_,
14
+ load_path_,
15
+ **kwargs):
16
+ predictor = model_func_(**kwargs).eval().requires_grad_(False)
17
+
18
+ did_load = predictor.load_state_dict(
19
+ torch.load(load_path_, map_location=torch.device("cpu"))['model'])
20
+ predictor._predictor_load_path = load_path_
21
+ print(did_load, load_path_)
22
+ return predictor
23
+
24
+
25
+ class CWM(PhysionFeatureExtractor):
26
+ def __init__(self, model_name, aggregate_embeddings=False):
27
+ super().__init__()
28
+
29
+ self.model = model_factory.load_model(model_name).cuda().half()
30
+
31
+ self.num_frames = self.model.num_frames
32
+
33
+ self.timestamps = np.arange(self.num_frames)
34
+
35
+ ps = (224 // self.model.patch_size[1]) ** 2
36
+
37
+ self.bool_masked_pos = np.zeros([ps * self.num_frames])
38
+ self.bool_masked_pos[ps * (self.num_frames - 1):] = 1
39
+
40
+ self.ps = ps
41
+
42
+ self.aggregate_embeddings = aggregate_embeddings
43
+
44
+ def transform(self):
45
+
46
+ return DataAugmentationForVideoMAE(
47
+ imagenet_normalize=True,
48
+ rescale_size=224,
49
+ ), 150, 4
50
+
51
+ def fwd(self, videos):
52
+ bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
53
+ bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
54
+ x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
55
+ return_features=True)
56
+ return x_encoded
57
+
58
+ def extract_features(self, videos, for_flow=False):
59
+ '''
60
+ videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
61
+ returns: [B, T, D] extracted features
62
+ '''
63
+
64
+ videos = videos.transpose(1, 2)
65
+
66
+ all_features = []
67
+
68
+ # repeat the last frame of the video
69
+ videos = torch.cat([videos, videos[:, :, -1:]], dim=2)
70
+
71
+ for x in range(0, 4, self.num_frames - 1):
72
+ vid = videos[:, :, x:x + self.num_frames, :, :]
73
+ all_features.append(self.fwd(vid))
74
+ if self.aggregate_embeddings:
75
+ feats = all_features[-1].mean(dim=1, keepdim=True)
76
+ all_features[-1] = feats
77
+ # feats = feats.view(feats.shape[0], -1, self.model.num_patches_per_frame, feats.shape[-1])
78
+ # feats = feats.mean(dim=2)
79
+ # all_features[-1] = feats
80
+
81
+ x_encoded = torch.cat(all_features, dim=1)
82
+
83
+ return x_encoded
84
+
85
+
86
+ class CWM_Keypoints(PhysionFeatureExtractor):
87
+ def __init__(self, model_name):
88
+ super().__init__()
89
+
90
+ self.model = model_factory.load_model(model_name).cuda().half()
91
+
92
+ self.frames = [[0, 1, 2], [1, 2, 3]]
93
+
94
+ self.num_frames = self.model.num_frames
95
+
96
+ self.ps = (224 // self.model.patch_size[1]) ** 2
97
+
98
+ self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
99
+ self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1
100
+
101
+ self.frame_gap = 150
102
+
103
+ self.num_frames_dataset = 4
104
+
105
+ self.res = 224
106
+
107
+
108
+ def transform(self):
109
+
110
+ return DataAugmentationForVideoMAE(
111
+ imagenet_normalize=True,
112
+ rescale_size=self.res,
113
+ ), self.frame_gap, self.num_frames_dataset
114
+
115
+ def fwd(self, videos):
116
+ bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
117
+ bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
118
+ _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
119
+ return_features=True)
120
+ return x_encoded
121
+
122
+ def extract_features(self, videos, segments=None):
123
+ '''
124
+ videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
125
+ returns: [B, T, D] extracted features
126
+ '''
127
+
128
+ videos = videos.transpose(1, 2)
129
+
130
+ all_features = []
131
+
132
+ for x, arr in enumerate(self.frames):
133
+
134
+ #use the downsampled videos for keypoints
135
+ vid = videos[:, :, arr, :, :].half()
136
+ frame0 = vid[:, :, 0]
137
+ frame1 = vid[:, :, 1]
138
+ frame2 = vid[:, :, 2]
139
+
140
+ #extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
141
+ mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)
142
+
143
+ #reshape the features to [batch size, num_features]
144
+ k_feat = k_feat.view(k_feat.shape[0], -1)
145
+
146
+ all_features.append(k_feat)
147
+
148
+ x_encoded = torch.cat(all_features, dim=1)
149
+
150
+ return x_encoded
151
+
152
+
153
+ class CWM_KeypointsFlow(PhysionFeatureExtractor):
154
+ def __init__(self, model_name):
155
+ super().__init__()
156
+
157
+ self.model = model_factory.load_model(model_name).cuda().half()
158
+
159
+ self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]
160
+
161
+ self.num_frames = self.model.num_frames
162
+
163
+ self.timestamps = np.arange(self.num_frames)
164
+
165
+ self.ps = (224 // self.model.patch_size[1]) ** 2
166
+
167
+ self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
168
+ self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1
169
+
170
+ self.frame_gap = 50
171
+
172
+ self.num_frames_dataset = 9
173
+
174
+ self.res = 512
175
+
176
+ def transform(self):
177
+
178
+ return DataAugmentationForVideoMAE(
179
+ imagenet_normalize=True,
180
+ rescale_size=self.res,
181
+ ), self.frame_gap, self.num_frames_dataset
182
+
183
+ def fwd(self, videos):
184
+ bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
185
+ bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
186
+ _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
187
+ return_features=True)
188
+ return x_encoded
189
+
190
+ def get_forward_flow(self, videos):
191
+
192
+ fid = 6
193
+
194
+ forward_flow = self.model.get_flow(videos[:, :, fid], videos[:, :, fid + 1], conditioning_img=videos[:, :, fid + 2], mode='cosine')
195
+
196
+ backward_flow = self.model.get_flow(videos[:, :, fid + 1], videos[:, :, fid], conditioning_img=videos[:, :, fid - 1], mode='cosine')
197
+
198
+ occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]
199
+
200
+ forward_flow = forward_flow * occlusion_mask
201
+
202
+ forward_flow = torch.stack([forward_flow, forward_flow, forward_flow], dim=1)
203
+
204
+ forward_flow = forward_flow.to(videos.device)
205
+
206
+ forward_flow = F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')
207
+
208
+ return forward_flow
209
+
210
+ def extract_features(self, videos, segments=None):
211
+ '''
212
+ videos: [B, T, C, H, W], T is usually 4 and videos are normalized with imagenet norm
213
+ returns: [B, T, D] extracted features
214
+ Note:
215
+ For efficiency, the optical flow is computed and added for a single frame (300ms) as we found this to be sufficient
216
+ for capturing temporal dynamics in our experiments. This approach can be extended to multiple frames if needed,
217
+ depending on the complexity of the task.
218
+ '''
219
+
220
+
221
+ #resize to 224 to get keypoints and features
222
+ videos_downsampled = F.interpolate(videos.flatten(0, 1), size=(224, 224), mode='bilinear', align_corners=False)
223
+ videos_downsampled = videos_downsampled.view(videos.shape[0], videos.shape[1], videos.shape[2], 224, 224)
224
+
225
+ #for computing flow at higher resolution
226
+ videos_ = F.interpolate(videos.flatten(0, 1), size=(1024, 1024), mode='bilinear', align_corners=False)
227
+ videos = videos_.view(videos.shape[0], videos.shape[1], videos.shape[2], 1024, 1024)
228
+
229
+ videos = videos.transpose(1, 2).half()
230
+ videos_downsampled = videos_downsampled.transpose(1, 2).half()
231
+
232
+ # Get the forward flow for the frame at 300ms
233
+ forward_flow = self.get_forward_flow(videos)
234
+
235
+ # Verify that there are no NaNs in the forward flow
236
+ assert not torch.isnan(forward_flow).any(), "Forward flow is nan"
237
+
238
+ all_features = []
239
+
240
+ for x, arr in enumerate(self.frames):
241
+
242
+ #use the downsampled videos for keypoints
243
+ vid = videos_downsampled[:, :, arr, :, :]
244
+ frame0 = vid[:, :, 0]
245
+ frame1 = vid[:, :, 1]
246
+ frame2 = vid[:, :, 2]
247
+
248
+ #extract features from the video frames frame0 and frame1 and include features at keypoint regions of frame2
249
+ mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)
250
+
251
+ #for the last set of frames only use features at keypoint regions of frame2
252
+ if (x == 2):
253
+ k_feat = k_feat[:, -10:, :]
254
+
255
+ #reshape the features to [batch size, num_features]
256
+ k_feat = k_feat.view(k_feat.shape[0], -1)
257
+
258
+ choices_image_resolution = choices * self.model.patch_size[1]
259
+
260
+ # At 300ms, add optical flow patches at the detected keypoint locations
261
+ # For the first frame (x == 0)
262
+ if x == 0:
263
+ # Extract the optical flow information from the forward flow matrix for the second channel (index 2)
264
+ flow_keyp = forward_flow[:, 2]
265
+
266
+ # Initialize a result tensor to store the flow patches
267
+ # Tensor shape: [batch_size, 8x8 patch (flattened to 64) * 2 channels, 10 keypoints]
268
+ flow = torch.zeros(vid.shape[0], 8 * 8 * 2, 10).to(videos.device)
269
+
270
+ # Patch size shift (since 8x8 patches are being extracted)
271
+ shift = 8
272
+
273
+ # Loop over each element in the batch to process individual video frames
274
+ for b in range(flow_keyp.size(0)):
275
+ # Extract the x and y coordinates of the keypoint locations for this batch element
276
+ x_indices = choices_image_resolution[b, :, 0]
277
+ y_indices = choices_image_resolution[b, :, 1]
278
+
279
+ # For each keypoint (10 total keypoints in this case)
280
+ for ind in range(10):
281
+ # Extract the 8x8 patch of optical flow at each keypoint's (x, y) location
282
+ # Flatten the patch and assign it to the corresponding slice in the result tensor
283
+ flow[b, :, ind] = flow_keyp[b, :, y_indices[ind]:y_indices[ind] + shift,
284
+ x_indices[ind]:x_indices[ind] + shift].flatten()
285
+
286
+ # Reshape the flow tensor for easier concatenation (flatten across all patches)
287
+ flow = flow.view(flow.shape[0], -1)
288
+
289
+ # Concatenate the extracted optical flow features with the existing feature tensor (k_feat)
290
+ k_feat = torch.cat([k_feat, flow], dim=1)
291
+
292
+ all_features.append(k_feat)
293
+
294
+ x_encoded = torch.cat(all_features, dim=1)
295
+
296
+ return x_encoded
297
+
298
+
299
+ class CWM_base_8x8_3frame(CWM):
300
+ def __init__(self,):
301
+ super().__init__('vitb_8x8patch_3frames')
302
+
303
+ class CWM_base_8x8_3frame_mean_embed(CWM):
304
+ def __init__(self,):
305
+ super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)
306
+
307
+ # CWM* (keypoints only) 74.7
308
+ class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
309
+ def __init__(self,):
310
+ super().__init__('vitb_8x8patch_3frames')
311
+
312
+
313
+ # CWM* (keypoints + Flow) 75.4
314
+ class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
315
+ def __init__(self,):
316
+ super().__init__('vitb_8x8patch_3frames')
317
+
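For concreteness, here is a hedged sketch of driving one of these extractors directly, outside the physion_evaluator harness. The input shapes follow the docstrings above; the random tensor is a stand-in for real ImageNet-normalized frames, and the sketch assumes a CUDA device and that model_factory can resolve the 'vitb_8x8patch_3frames' checkpoint (the constructors call .cuda().half()):

    import torch
    from cwm.eval.Physion.feature_extractor import CWM_base_8x8_3frame

    extractor = CWM_base_8x8_3frame()
    transform, frame_gap, n_frames = extractor.transform()
    videos = torch.randn(2, 4, 3, 224, 224).cuda()   # [B, T, C, H, W]
    feats = extractor.extract_features(videos)       # [B, T', D] per the docstring
    print(feats.shape)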
cwm/eval/Physion/flow_utils.py ADDED
@@ -0,0 +1,279 @@
1
+
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+ import math
6
+
7
+ def create_weighted_mask_batched(h, w):
8
+ y_mask = np.linspace(0, 1, h)
9
+ y_mask = np.minimum(y_mask, 1 - y_mask)
10
+ x_mask = np.linspace(0, 1, w)
11
+ x_mask = np.minimum(x_mask, 1 - x_mask)
12
+ weighted_mask = np.outer(y_mask, x_mask)
13
+ return torch.from_numpy(weighted_mask).float()
14
+
15
+ def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
16
+ B, T, C, H, W = original_shape
17
+
18
+ # Initialize an empty tensor to store the reconstructed video
19
+ reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
20
+
21
+ # Create a tensor to store the sum of weighted masks
22
+ weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
23
+
24
+ # Create a weighted mask for the crops
25
+ weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
26
+ weighted_mask = weighted_mask[None, None, None, :, :] # Extend dimensions to match the cropped tensor.
27
+
28
+ for idx, crop in enumerate(cropped_tensors):
29
+ start_h, start_w = crop_positions[idx]
30
+
31
+ # Multiply the crop with the weighted mask
32
+ weighted_crop = crop * weighted_mask
33
+
34
+ # Add the weighted crop to the corresponding location in the reconstructed_video tensor
35
+ reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop
36
+
37
+ # Update the weighted_masks_sum tensor
38
+ weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask
39
+
40
+ # Add a small epsilon value to avoid division by zero
41
+ epsilon = 1e-8
42
+
43
+ # Normalize the reconstructed video by dividing each pixel by its corresponding weighted_masks_sum value plus epsilon
44
+ reconstructed_video /= (weighted_masks_sum + epsilon)
45
+
46
+ return reconstructed_video
47
+
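+ # Usage sketch (illustrative only, not part of the module): blend two overlapping
+ # 224x224 crops back into a (1, 1, 3, 224, 400) canvas. The bilinear weight mask
+ # down-weights crop borders, so pixels in the overlap interpolate smoothly:
+ #
+ #   crops = [torch.zeros(1, 1, 3, 224, 224), torch.ones(1, 1, 3, 224, 224)]
+ #   video = reconstruct_video_new_2_batched(crops, [(0, 0), (0, 176)],
+ #                                           original_shape=(1, 1, 3, 224, 400))
+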
48
+ import torch.nn.functional as F
49
+
50
+ resize = lambda x,a: F.interpolate(x, [int(a*x.shape[-2]), int(a*x.shape[-1])], mode='bilinear', align_corners=False)
51
+
52
+ upsample = lambda x,H,W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)
53
+
54
+
55
+
56
+
57
+ def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
58
+ # Unroll the mask tensor and find the indices of the masked and unmasked values in the second frame
59
+ mask_unrolled = mask_tensor.view(-1)
60
+
61
+ second_frame_unmask_indices = torch.where(mask_unrolled[frame_size**2:] == False)[0]
62
+
63
+ # Divide the embedding tensor into two parts: corresponding to the first and the second frame
64
+ first_frame_embeddings = embedding_tensor[0, :frame_size**2, :]
65
+ second_frame_embeddings = embedding_tensor[0, frame_size**2:, :]
66
+
67
+ # Compute the cosine similarity between the unmasked embeddings from the second frame and the embeddings from the first frame
68
+ dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
69
+ norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
70
+ cos_sim_matrix = dot_product / norms
71
+
72
+ # Find the indices of pixels in the first frame that are most similar to each unmasked pixel in the second frame
73
+ first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)
74
+
75
+ # Convert the 1D pixel indices into 2D coordinates
76
+ second_frame_y = second_frame_unmask_indices // frame_size
77
+ second_frame_x = second_frame_unmask_indices % frame_size
78
+ first_frame_y = first_frame_most_similar_indices // frame_size
79
+ first_frame_x = first_frame_most_similar_indices % frame_size
80
+
81
+ # Compute the x and y displacements and convert them to float
82
+ displacements_x = (second_frame_x - first_frame_x).float()
83
+ displacements_y = (second_frame_y - first_frame_y).float()
84
+
85
+ # Initialize optical flow tensor
86
+ optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
87
+
88
+ # Assign the computed displacements to the corresponding pixels in the optical flow tensor
89
+ optical_flow[0, second_frame_y, second_frame_x] = displacements_x
90
+ optical_flow[1, second_frame_y, second_frame_x] = displacements_y
91
+
92
+ return optical_flow
93
+
94
+ def get_minimal_224_crops_new_batched(video_tensor, N):
95
+ B, T, C, H, W = video_tensor.shape
96
+
97
+ # Calculate the number of crops needed in both the height and width dimensions
98
+ num_crops_h = math.ceil(H / 224) if H > 224 else 1
99
+ num_crops_w = math.ceil(W / 224) if W > 224 else 1
100
+
101
+ # Calculate the step size for the height and width dimensions
102
+ step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
103
+ step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))
104
+
105
+ # Create a list to store the cropped tensors and their start positions
106
+ cropped_tensors = []
107
+ crop_positions = []
108
+
109
+ # Iterate over the height and width dimensions, extract the 224x224 crops, and append to the cropped_tensors list
110
+ for i in range(num_crops_h):
111
+ for j in range(num_crops_w):
112
+ start_h = i * step_size_h
113
+ start_w = j * step_size_w
114
+ end_h = min(start_h + 224, H)
115
+ end_w = min(start_w + 224, W)
116
+ crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
117
+ cropped_tensors.append(crop)
118
+ crop_positions.append((start_h, start_w))
119
+
120
+ D = len(cropped_tensors)
121
+
122
+ # If N is greater than D, generate additional random crops
123
+ if N > D and H > 224 and W > 224: # check if H and W are greater than 224
124
+ for _ in range(N - D):
125
+ start_h = random.randint(0, H - 224)
126
+ start_w = random.randint(0, W - 224)
127
+ crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
128
+ cropped_tensors.append(crop)
129
+ crop_positions.append((start_h, start_w))
130
+
131
+ # Reshape the cropped tensors to fit the required output shape (B, T, C, 224, 224)
132
+ cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
133
+
134
+ return cropped_tensors, crop_positions
135
+
136
+ def get_honglin_3frame_vmae_optical_flow_crop_batched(generator,
137
+ mask_generator,
138
+ img1,
139
+ img2,
140
+ img3,
141
+ neg_back_flow=True,
142
+ num_scales=1,
143
+ min_scale=400,
144
+ N_mask_samples=100,
145
+ mask_ratio=0.8,
146
+ flow_frames='23'):
147
+ B = img1.shape[0]
148
+ assert len(img1.shape) == 4
149
+ assert num_scales >= 1
150
+
151
+ # For scaling
152
+ h1 = img2.shape[-2]
153
+ w1 = img2.shape[-1]
154
+ assert min_scale < h1
155
+
156
+ if neg_back_flow is False:
157
+ print('WARNING: Not calculating negative backward flow')
158
+
159
+ alpha = (min_scale / img1.shape[-2]) ** (1 / 4)
160
+
161
+ frame_size = 224 // generator.patch_size[-1]
162
+
163
+ patch_size = generator.patch_size[-1]
164
+
165
+ all_fwd_flows_e2d = []
166
+
167
+ for aidx in range(num_scales):
168
+
169
+ # print('aidx: ', aidx)
170
+
171
+ img1_scaled = resize(img1.clone(), alpha ** aidx)
172
+ img2_scaled = resize(img2.clone(), alpha ** aidx)
173
+ img3_scaled = resize(img3.clone(), alpha ** aidx)
174
+
175
+ h2 = img2_scaled.shape[-2]
176
+ w2 = img2_scaled.shape[-1]
177
+
178
+ s_h = h1 / h2
179
+ s_w = w1 / w2
180
+
181
+ # Because technically the compute_optical_flow function returns neg back flow
182
+ if neg_back_flow is True:
183
+ video = torch.cat([img3_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
184
+ else:
185
+ video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img3_scaled.unsqueeze(1)], 1)
186
+
187
+ # Should work, even if the incoming video is already 224x224
188
+ crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
189
+
190
+ # print(len(crops1), crops1[0].shape)
191
+
192
+ num_crops = len(crops1)
193
+
194
+ crop_flows_enc = []
195
+ crop_flows_enc2dec = []
196
+ N_samples = N_mask_samples
197
+
198
+ crop = torch.cat(crops1, 0).cuda()
199
+ # print(crop.shape)
200
+
201
+ optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
202
+ mask_counts = torch.zeros(frame_size, frame_size).cuda()
203
+
204
+ i = 0
205
+ while i < N_samples or (mask_counts == 0).any().item():
206
+ if i % 100 == 0:
207
+ pass # print(i)
208
+ mask_generator.mask_ratio = mask_ratio
209
+
210
+
211
+ # Note: every element in the batch shares the same mask sample, which is acceptable here
212
+ mask = mask_generator(num_frames=3)[None]
213
+ mask_2f = ~mask[0, frame_size * frame_size * 2:]
214
+ mask_counts += mask_2f.reshape(frame_size, frame_size)
215
+
216
+ with torch.cuda.amp.autocast(enabled=True):
217
+
218
+ processed_x = crop.transpose(1, 2)
219
+
220
+ # print("crop", processed_x.max())
221
+
222
+ encoder_out = generator.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
223
+ encoder_to_decoder = generator.encoder_to_decoder(encoder_out)
224
+ # print(encoder_to_decoder.shape)
225
+
226
+ if flow_frames == '23':
227
+ encoder_to_decoder = encoder_to_decoder[:, frame_size * frame_size:, :]
228
+ flow_mask = mask[:, frame_size * frame_size:]
229
+ # print(encoder_to_decoder.shape)
230
+ elif flow_frames == '12':
231
+ encoder_to_decoder = encoder_to_decoder[:, :frame_size * frame_size * 2, :]
232
+ # print(encoder_to_decoder.shape)
233
+ flow_mask = mask[:, :frame_size * frame_size * 2]
234
+ # print(mask.shape)
235
+ # print(flow_mask.shape)
236
+ # print()
237
+
238
+ optical_flow_e2d = []
239
+ # one per batch element for now
240
+ for b in range(B * num_crops):
241
+ batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
242
+ optical_flow_e2d.append(batch_flow.unsqueeze(0))
243
+
244
+ optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
245
+ optical_flows_enc2dec += optical_flow_e2d
246
+ i += 1
247
+
248
+ optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
249
+
250
+ scale_factor_y = video.shape[-2] / 224
251
+ scale_factor_x = video.shape[-1] / 224
252
+
253
+ scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
254
+ scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
255
+ scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h
256
+
257
+ # split the crops back up
258
+ crop_flows_enc2dec = scaled_optical_flow.split(B, 0)
259
+ # print(len(crop_flows_enc2dec))
260
+
261
+ optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
262
+ [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
263
+ crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
264
+
265
+ all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
266
+
267
+ all_fwd_flows_e2d_new = []
268
+
269
+ for r in all_fwd_flows_e2d:
270
+ new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
271
+ all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
272
+ return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
273
+
274
+ if neg_back_flow is True:
275
+ return_flow = -return_flow
276
+ all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
277
+
278
+ return return_flow, all_fwd_flows_e2d_new
279
+
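At the core of the estimator above is compute_optical_flow, which turns token correspondences into a patch-level flow map. Below is a shape-level sketch of its contract; the dimensions are assumptions (224x224 inputs with 16-pixel patches, so frame_size = 14, and a fully visible second frame so token counts line up), and the embedding dim D is made up:

    import torch
    from cwm.eval.Physion.flow_utils import compute_optical_flow

    frame_size, D = 14, 768                               # D (embedding dim) is an assumption
    emb = torch.randn(1, 2 * frame_size ** 2, D)          # frame-1 tokens then frame-2 tokens
    mask = torch.zeros(2 * frame_size ** 2, dtype=torch.bool)
    flow = compute_optical_flow(emb, mask, frame_size)    # [2, 14, 14] patch displacements
    print(flow.shape)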
cwm/eval/Physion/run_eval.sh ADDED
@@ -0,0 +1,17 @@
1
+ #physion_feature_extract \
2
+ #--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
3
+ #--data_root_path /ccn2/u/rmvenkat/data/testing_physion/regenerate_from_old_commit/ \
4
+ #--model_class feature_extractor.CWM_base_8x8_3frame \
5
+ #--gpu 1 \
6
+ #--batch_size 8 \
7
+ #--dir_for_saving /ccn2/u/rmvenkat/data/physion_release/ \
8
+ #--mode ocp
9
+
10
+ physion_train_readout \
11
+ --train-path /ccn2/u/rmvenkat/data/physion_release/ocp/train_features.hdf5 \
12
+ --test-path /ccn2/u/rmvenkat/data/physion_release/ocp/test_features.hdf5 \
13
+ --model-name CWM_base_8x8_3frame \
14
+ --train-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/train_json.json \
15
+ --test-scenario-indices /ccn2/u/rmvenkat/data/physion_release/ocp/test_json.json \
16
+ --test-scenario-map /ccn2/u/rmvenkat/data/physion_release/ocp/test_scenario_map.json \
17
+ --save_path /ccn2/u/rmvenkat/data/physion_release/
cwm/eval/Physion/run_eval_kfflow.sh ADDED
@@ -0,0 +1,18 @@
1
+ #physion_feature_extract \
2
+ #--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
3
+ #--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
4
+ #--model_class feature_extractor.CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \
5
+ #--gpu 1 \
6
+ #--batch_size 8 \
7
+ #--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_kfflow/ \
8
+ #--mode ocp
9
+
10
+
11
+ physion_train_readout \
12
+ --train-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_features.hdf5 \
13
+ --test-path /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_features.hdf5 \
14
+ --model-name CWM_base_8x8_3frame_Keypoints_KFFlowPatched_noF1_cwm_50_occ_mask \
15
+ --train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/train_json.json \
16
+ --test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_json.json \
17
+ --test-scenario-map /ccn2/u/rmvenkat/data/physion_release_kfflow/ocp/test_scenario_map.json \
18
+ --save_path /ccn2/u/rmvenkat/data/physion_release_kfflow/
cwm/eval/Physion/run_eval_mp4s.sh ADDED
@@ -0,0 +1,19 @@
1
+ dir_for_saving=/ccn2/u/rmvenkat/data/physion_release/
2
+ model_name=CWM_base_8x8_3frame
3
+
4
+ physion_feature_extract \
5
+ --data_root_path /ccn2/u/rmvenkat/data/download_test/physion_mp4s/ \
6
+ --model_class feature_extractor.$model_name \
7
+ --gpu 1 \
8
+ --batch_size 8 \
9
+ --dir_for_saving $dir_for_saving \
10
+ --mode ocd
11
+
12
+ physion_train_readout \
13
+ --train-path ${dir_for_saving}ocd/train_features.hdf5 \
14
+ --test-path ${dir_for_saving}ocd/test_features.hdf5 \
15
+ --model-name $model_name \
16
+ --train-scenario-indices ${dir_for_saving}ocd/train_json.json \
17
+ --test-scenario-indices ${dir_for_saving}ocd/test_json.json \
18
+ --test-scenario-map ${dir_for_saving}ocd/test_scenario_map.json \
19
+ --save_path $dir_for_saving
cwm/eval/Physion/run_eval_mp4s_keyp.sh ADDED
@@ -0,0 +1,17 @@
1
+ #physion_feature_extract \
2
+ #--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
3
+ #--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
4
+ #--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints \
5
+ #--gpu 3 \
6
+ #--batch_size 8 \
7
+ #--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp/ \
8
+ #--mode ocp
9
+
10
+ physion_train_readout \
11
+ --train-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_features.hdf5 \
12
+ --test-path /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_features.hdf5 \
13
+ --model-name CWM_base_8x8_3frame \
14
+ --train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/train_json.json \
15
+ --test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_json.json \
16
+ --test-scenario-map /ccn2/u/rmvenkat/data/physion_release_keyp/ocp/test_scenario_map.json \
17
+ --save_path /ccn2/u/rmvenkat/data/physion_release_keyp/
cwm/eval/Physion/run_eval_mp4s_keyp_flow.sh ADDED
@@ -0,0 +1,17 @@
1
+ #physion_feature_extract \
2
+ #--model_path /ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth \
3
+ #--data_root_path /ccn2/u/rmvenkat/data/physion_mp4s/ \
4
+ #--model_class feature_extractor_cleaned.CWM_base_8x8_3frame_keypoints_flow \
5
+ #--gpu 7 \
6
+ #--batch_size 8 \
7
+ #--dir_for_saving /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ \
8
+ #--mode ocp
9
+
10
+ physion_train_readout \
11
+ --train-path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/train_features.hdf5 \
12
+ --test-path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_features.hdf5 \
13
+ --model-name CWM_base_8x8_3frame \
14
+ --train-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/train_json.json \
15
+ --test-scenario-indices /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_json.json \
16
+ --test-scenario-map /ccn2/u/rmvenkat/data/physion_release_keyp_flow/ocp/test_scenario_map.json \
17
+ --save_path /ccn2/u/rmvenkat/data/physion_release_keyp_flow/
cwm/eval/Segmentation/__init__.py ADDED
File without changes
cwm/eval/Segmentation/archive/__init__.py ADDED
File without changes
cwm/eval/Segmentation/archive/common/__init__.py ADDED
File without changes
cwm/eval/Segmentation/archive/common/coco_loader_lsj.py ADDED
@@ -0,0 +1,222 @@
1
+ import detectron2.data.transforms as T
2
+ from detectron2 import model_zoo
3
+ from detectron2.config import LazyCall as L
4
+
5
+ # Data using LSJ
6
+ image_size = 512
7
+ dataloader = model_zoo.get_config("common/data/coco.py").dataloader
8
+ dataloader.train.mapper.augmentations = [
9
+ L(T.RandomFlip)(horizontal=True), # flip first
10
+ L(T.ResizeScale)(
11
+ min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
12
+ ),
13
+ L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
14
+ ]
15
+ dataloader.train.mapper.image_format = "RGB"
16
+ dataloader.train.total_batch_size = 64
17
+ dataloader.train.num_workers = 0
18
+ # recompute boxes due to cropping
19
+ dataloader.train.mapper.recompute_boxes = True
20
+
21
+ dataloader.test.mapper.augmentations = [
22
+ L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
23
+ ]
24
+
25
+
26
+
27
+
28
+ import copy
29
+ import logging
30
+ import numpy as np
31
+ from typing import List, Optional, Union
32
+ import torch
33
+
34
+ from detectron2.config import configurable
35
+
36
+ from detectron2.data import detection_utils as utils
37
+ from detectron2.data import transforms as T
38
+
39
+ """
40
+ This file contains the default mapping that's applied to "dataset dicts".
41
+ """
42
+
43
+ __all__ = ["DatasetMapper"]
44
+
45
+
46
+ class DatasetMapper:
47
+ """
48
+ A callable which takes a dataset dict in Detectron2 Dataset format,
49
+ and map it into a format used by the model.
50
+
51
+ This is the default callable to be used to map your dataset dict into training data.
52
+ You may need to follow it to implement your own one for customized logic,
53
+ such as a different way to read or transform images.
54
+ See :doc:`/tutorials/data_loading` for details.
55
+
56
+ The callable currently does the following:
57
+
58
+ 1. Read the image from "file_name"
59
+ 2. Applies cropping/geometric transforms to the image and annotations
60
+ 3. Prepare data and annotations to Tensor and :class:`Instances`
61
+ """
62
+
63
+ @configurable
64
+ def __init__(
65
+ self,
66
+ is_train: bool,
67
+ *,
68
+ augmentations: List[Union[T.Augmentation, T.Transform]],
69
+ image_format: str,
70
+ use_instance_mask: bool = False,
71
+ use_keypoint: bool = False,
72
+ instance_mask_format: str = "polygon",
73
+ keypoint_hflip_indices: Optional[np.ndarray] = None,
74
+ precomputed_proposal_topk: Optional[int] = None,
75
+ recompute_boxes: bool = False,
76
+ ):
77
+ """
78
+ NOTE: this interface is experimental.
79
+
80
+ Args:
81
+ is_train: whether it's used in training or inference
82
+ augmentations: a list of augmentations or deterministic transforms to apply
83
+ image_format: an image format supported by :func:`detection_utils.read_image`.
84
+ use_instance_mask: whether to process instance segmentation annotations, if available
85
+ use_keypoint: whether to process keypoint annotations if available
86
+ instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
87
+ masks into this format.
88
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
89
+ precomputed_proposal_topk: if given, will load pre-computed
90
+ proposals from dataset_dict and keep the top k proposals for each image.
91
+ recompute_boxes: whether to overwrite bounding box annotations
92
+ by computing tight bounding boxes from instance mask annotations.
93
+ """
94
+ if recompute_boxes:
95
+ assert use_instance_mask, "recompute_boxes requires instance masks"
96
+ # fmt: off
97
+ self.is_train = is_train
98
+ self.augmentations = T.AugmentationList(augmentations)
99
+ self.image_format = image_format
100
+ self.use_instance_mask = use_instance_mask
101
+ self.instance_mask_format = instance_mask_format
102
+ self.use_keypoint = use_keypoint
103
+ self.keypoint_hflip_indices = keypoint_hflip_indices
104
+ self.proposal_topk = precomputed_proposal_topk
105
+ self.recompute_boxes = recompute_boxes
106
+ # fmt: on
107
+ logger = logging.getLogger(__name__)
108
+ mode = "training" if is_train else "inference"
109
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
110
+
111
+ @classmethod
112
+ def from_config(cls, cfg, is_train: bool = True):
113
+ augs = utils.build_augmentation(cfg, is_train)
114
+ if cfg.INPUT.CROP.ENABLED and is_train:
115
+ augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
116
+ recompute_boxes = cfg.MODEL.MASK_ON
117
+ else:
118
+ recompute_boxes = False
119
+
120
+ ret = {
121
+ "is_train": is_train,
122
+ "augmentations": augs,
123
+ "image_format": cfg.INPUT.FORMAT,
124
+ "use_instance_mask": cfg.MODEL.MASK_ON,
125
+ "instance_mask_format": cfg.INPUT.MASK_FORMAT,
126
+ "use_keypoint": cfg.MODEL.KEYPOINT_ON,
127
+ "recompute_boxes": recompute_boxes,
128
+ }
129
+
130
+ if cfg.MODEL.KEYPOINT_ON:
131
+ ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
132
+
133
+ if cfg.MODEL.LOAD_PROPOSALS:
134
+ ret["precomputed_proposal_topk"] = (
135
+ cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
136
+ if is_train
137
+ else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
138
+ )
139
+ return ret
140
+
141
+ def _transform_annotations(self, dataset_dict, transforms, image_shape):
142
+ # USER: Modify this if you want to keep them for some reason.
143
+ for anno in dataset_dict["annotations"]:
144
+ if not self.use_instance_mask:
145
+ anno.pop("segmentation", None)
146
+ if not self.use_keypoint:
147
+ anno.pop("keypoints", None)
148
+
149
+ # USER: Implement additional transformations if you have other types of data
150
+ annos = [
151
+ utils.transform_instance_annotations(
152
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
153
+ )
154
+ for obj in dataset_dict.pop("annotations")
155
+ if obj.get("iscrowd", 0) == 0
156
+ ]
157
+ instances = utils.annotations_to_instances(
158
+ annos, image_shape, mask_format=self.instance_mask_format
159
+ )
160
+
161
+ # After transforms such as cropping are applied, the bounding box may no longer
162
+ # tightly bound the object. As an example, imagine a triangle object
163
+ # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
164
+ # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
165
+ # the intersection of original bounding box and the cropping box.
166
+ if self.recompute_boxes:
167
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
168
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
169
+
170
+ def __call__(self, dataset_dict):
171
+ """
172
+ Args:
173
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
174
+
175
+ Returns:
176
+ dict: a format that builtin models in detectron2 accept
177
+ """
178
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
179
+ # USER: Write your own image loading if it's not from a file
180
+ image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
181
+ utils.check_image_size(dataset_dict, image)
182
+
183
+ # USER: Remove if you don't do semantic/panoptic segmentation.
184
+ if "sem_seg_file_name" in dataset_dict:
185
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
186
+ else:
187
+ sem_seg_gt = None
188
+
189
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
190
+ transforms = self.augmentations(aug_input)
191
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
192
+
193
+ image_shape = image.shape[:2] # h, w
194
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
195
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
196
+ # Therefore it's important to use torch.Tensor.
197
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
198
+ if sem_seg_gt is not None:
199
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
200
+
201
+ # USER: Remove if you don't use pre-computed proposals.
202
+ # Most users would not need this feature.
203
+ if self.proposal_topk is not None:
204
+ utils.transform_proposals(
205
+ dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
206
+ )
207
+
208
+ if not self.is_train:
209
+ # USER: Modify this if you want to keep them for some reason.
210
+ dataset_dict.pop("annotations", None)
211
+ dataset_dict.pop("sem_seg_file_name", None)
212
+ return dataset_dict
213
+
214
+ if "annotations" in dataset_dict:
215
+ self._transform_annotations(dataset_dict, transforms, image_shape)
216
+
217
+ # Modified by Honglin Chen: change it to class-agnostic instance labels
218
+ dataset_dict['instances'].gt_classes *= 0
219
+ return dataset_dict
220
+
221
+
222
+ dataloader.train.mapper._target_ = DatasetMapper
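The one functional change relative to the stock detectron2 mapper is the final class-agnostic relabeling (gt_classes *= 0). A minimal sketch of the effect on a toy Instances object (illustrative; the boxes and labels are made up):

    import torch
    from detectron2.structures import Boxes, Instances

    inst = Instances((512, 512))
    inst.gt_boxes = Boxes(torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]]))
    inst.gt_classes = torch.tensor([3, 7])
    inst.gt_classes *= 0      # every instance collapses to class 0
    print(inst.gt_classes)    # tensor([0, 0]) -> training becomes class-agnostic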
cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+ import argparse
3
+
4
+ parser = argparse.ArgumentParser()
5
+ parser.add_argument('--input', type=str, help='The path to the checkpoint.')
6
+ parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
7
+ args = parser.parse_args()
8
+
9
+ state_dict = torch.load(args.input, map_location='cpu')['model']
10
+
11
+ new_state_dict = {}
12
+ for k, v in state_dict.items():
13
+ if 'encoder' in k and not 'decoder' in k:
14
+ new_k = 'backbone.net.model.' + k
15
+ new_state_dict[new_k] = v
16
+
17
+ output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
18
+ torch.save(new_state_dict, output_path)
19
+ print('Save model to', output_path)
cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import argparse
3
+ import sys
4
+ sys.path.append('../../../')
5
+ from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
6
+
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument('--input', type=str, help='The path to the checkpoint.')
9
+ parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
10
+ args = parser.parse_args()
11
+
12
+ state_dict = torch.load(args.input, map_location='cpu')['model']
13
+ mae = True
14
+ # C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
15
+ C = 768
16
+ pos_embed = get_sinusoid_encoding_table(14*14, C)
17
+ cls_token = torch.zeros(1, 1, C)
18
+ pos_embed = torch.cat([cls_token, pos_embed], dim=1)
19
+
20
+
21
+ new_state_dict = {'backbone.net.pos_embed': pos_embed}
22
+ for k, v in state_dict.items():
23
+
24
+ if mae or ('encoder' in k and not 'decoder' in k or 'patch_embed' in k):
25
+
26
+ if 'patch_embed.proj.weight' in k:
27
+
28
+ if len(v.shape) == 5:
29
+ if v.shape[2] == 1:
30
+ v = v.squeeze(2) # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
31
+ else:
32
+ v = v[:, :, 0]
33
+
34
+ old_k = k
35
+ k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.'+k
36
+
37
+ if 'attn' in k and '_bias' in k:
38
+ old_attn = '.'.join(old_k.split('.')[:-1])
39
+ attn = '.'.join(k.split('.')[:-1])
40
+ k = attn + '.qkv.bias'
41
+ if k in new_state_dict:
42
+ continue
43
+
44
+ v = torch.cat([
45
+ state_dict[old_attn + '.q_bias'],
46
+ state_dict[old_attn + '.k_bias'] if (old_attn + '.k_bias') in state_dict else torch.zeros_like(state_dict[old_attn + '.q_bias']),
47
+ state_dict[old_attn + '.v_bias'],
48
+ ], dim=0)
49
+ print(k, v.shape)
50
+ new_state_dict[k] = v
51
+
52
+ output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
53
+ torch.save(new_state_dict, output_path)
54
+ print('Save model to', output_path)
cwm/eval/Segmentation/archive/common/convert_dino_checkpoint_detectron_format.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ import argparse
3
+ import sys
4
+ sys.path.append('../../../')
5
+ from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
6
+
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument('--input', type=str, help='The path to the checkpoint.')
9
+ parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
10
+ args = parser.parse_args()
11
+ # (leftover breakpoint() removed so the script runs non-interactively)
12
+ state_dict = torch.load(args.input, map_location='cpu')
13
+
14
+ new_state_dict = {}
15
+
16
+ for k, v in state_dict.items():
17
+ if 'pos_embed' in k:
18
+ pass  # leftover debug stub ('breakpoint()') removed; pos_embed keys are renamed like the rest
19
+ else:
20
+ pass
21
+ k = 'backbone.net.' + k
22
+ new_state_dict[k] = v
23
+
24
+ output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
25
+ torch.save(new_state_dict, output_path)
26
+ print('Save model to', output_path)
cwm/eval/Segmentation/archive/competition.py ADDED
@@ -0,0 +1,673 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from torch.distributions.categorical import Categorical
7
+
8
+ from kornia.filters.kernels import (get_spatial_gradient_kernel2d,
9
+ normalize_kernel2d)
10
+
11
+ def l2_normalize(x):
12
+ return F.normalize(x, p=2.0, dim=-1, eps=1e-6)
13
+
14
+ def reduce_max(x, dim, keepdim=True):
15
+ return torch.max(x, dim=dim, keepdim=keepdim)[0]
16
+
17
+ def coordinate_ims(batch_size, seq_length, imsize):
18
+ static = False
19
+ if seq_length == 0:
20
+ static = True
21
+ seq_length = 1
22
+ B = batch_size
23
+ T = seq_length
24
+ H,W = imsize
25
+ ones = torch.ones([B,H,W,1], dtype=torch.float32)
26
+ h = torch.divide(torch.arange(H).to(ones), torch.tensor(H-1, dtype=torch.float32))
27
+ h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
28
+ w = torch.divide(torch.arange(W).to(ones), torch.tensor(W-1, dtype=torch.float32))
29
+ w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
30
+ h = torch.stack([h]*T, 1)
31
+ w = torch.stack([w]*T, 1)
32
+ hw_ims = torch.cat([h,w], -1)
33
+ if static:
34
+ hw_ims = hw_ims[:,0]
35
+ return hw_ims
36
+
37
+ def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
38
+ """
39
+ Compute the normalized dot product between two PyTorch tensors
40
+ """
41
+ B,N,D_q = queries.size()
42
+ _B,N_k,D_k = keys.size()
43
+ assert D_q == D_k, (queries.shape, keys.shape)
44
+ if normalize:
45
+ queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
46
+ keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)
47
+
48
+ outputs = torch.matmul(queries, torch.transpose(keys, 1, 2)) # [B, N, N_k]
49
+ attention = torch.transpose(outputs, 1, 2)  # [B, N_k, N] (computed but unused; the raw outputs are returned)
50
+
51
+ return outputs
52
+
53
+ def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
54
+
55
+ B,H,W = probs.shape
56
+ P = num_points
57
+ N = H*W
58
+
59
+ probs = probs.reshape(B,N)
60
+ probs = torch.maximum(probs + eps, torch.tensor(0., device=probs.device)) / (probs.sum(dim=-1, keepdim=True) + eps)
61
+ dist = Categorical(probs=probs, validate_args=False)
62
+
63
+ indices = dist.sample([P]).permute(1,0).to(torch.int32) # [B,P]
64
+
65
+ indices_h = torch.minimum(torch.maximum(torch.div(indices, W, rounding_mode='floor'), torch.tensor(0)), torch.tensor(H-1))
66
+ indices_w = torch.minimum(torch.maximum(torch.fmod(indices, W), torch.tensor(0)), torch.tensor(W-1))
67
+ indices = torch.stack([indices_h, indices_w], dim=-1) # [B,P,2]
68
+ return indices
69
+
70
+ def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
71
+
72
+ B,C,H,W = list(image.size())
73
+
74
+ # prepare kernel
75
+ kernel = get_spatial_gradient_kernel2d(mode, order)
76
+ if normalize_kernel:
77
+ kernel = normalize_kernel2d(kernel)
78
+ tmp_kernel = kernel.to(image).detach()
79
+ tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
80
+ kernel_flip = tmp_kernel.flip(-3)
81
+
82
+ # pad spatial dims of image
83
+ padding = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
84
+ out_channels = 3 if (order == 2) else 2
85
+ padded_image = F.pad(image.reshape(B*C, 1, H, W), padding, 'replicate')[:, :, None] # [B*C,1,1,H+p,W+p]
86
+ gradient_image = F.conv3d(padded_image, kernel_flip, padding=0).view(B, C, out_channels, H, W)
87
+ return gradient_image
88
+
89
+ def sample_coordinates_at_borders(image, num_points=16, mask=None, sum_edges=True, normalized_coordinates=True):
+     """
+     Sample num_points in normalized (h,w) coordinates from the borders (edges) of the input image.
+     """
+     B, C, H, W = list(image.size())
+     if mask is not None:
+         assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())
+     else:
+         mask = torch.ones(size=(B, 1, H, W)).to(image)
+
+     gradient_image = get_gradient_image(image * mask, mode='sobel', order=1)  # [B,C,2,H,W]
+     gradient_magnitude = torch.sqrt(torch.square(gradient_image).sum(dim=2))
+     if sum_edges:
+         edges = gradient_magnitude.sum(1)  # [B,H,W]
+     else:
+         edges = gradient_magnitude.max(1)[0]
+
+     edges = edges * mask[:, 0]  # mask is always set above, so no further guard is needed
+
+     coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
+     if normalized_coordinates:
+         coordinates = coordinates.to(torch.float32)
+         coordinates /= torch.tensor([H - 1, W - 1], dtype=torch.float32)[None, None].to(coordinates.device)
+         coordinates = 2.0 * coordinates - 1.0
+     return coordinates
+
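This composes the two utilities above: Sobel edge magnitude becomes the sampling distribution, and the drawn indices are mapped into [-1, 1]. Illustrative (not part of the commit):

    img = torch.zeros(1, 3, 16, 16)
    img[..., 8:] = 1.0                                         # one strong vertical edge
    coords = sample_coordinates_at_borders(img, num_points=4)  # [1, 4, 2] in [-1, 1]
    assert coords[..., 1].abs().max() < 0.2                    # all samples hug the edge column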
+ def index_into_images(images, indices, channels_last=False):
+     """
+     Index into an image at P points to get its values.
+
+     images: [B,C,H,W]
+     indices: [B,P,2]
+     """
+     assert indices.size(-1) == 2, indices.size()
+     if channels_last:
+         images = images.permute(0, 3, 1, 2)  # [B,C,H,W]
+     B, C, H, W = images.shape
+     _, P, _ = indices.shape
+     inds_h, inds_w = list(indices.to(torch.long).permute(2, 0, 1))  # [B,P] each
+     inds_b = torch.arange(B, dtype=torch.long).unsqueeze(-1).expand(-1, P).to(indices)
+     inds = torch.stack([inds_b, inds_h, inds_w], 0).to(torch.long)
+     values = images.permute(0, 2, 3, 1)[list(inds)]  # [B,P,C]
+     return values
+
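index_into_images is plain advanced indexing over the batch and both spatial dims at once; illustrative (not part of the commit):

    images = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4)
    pts = torch.tensor([[[0, 0], [3, 3]],
                        [[1, 2], [2, 1]]])     # [B=2, P=2, 2]
    vals = index_into_images(images, pts)      # [2, 2, 3]
    assert torch.equal(vals[0, 0], images[0, :, 0, 0])
    assert torch.equal(vals[1, 1], images[1, :, 2, 1])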
+ def soft_index(images, indices, scale_by_imsize=True):
+     """
+     Bilinearly sample images ([B,C,H,W]) at continuous (h,w) indices ([B,P,2]); returns [B,P,C].
+     If scale_by_imsize, indices are interpreted as normalized coordinates in [-1,1].
+     """
+     assert indices.shape[-1] == 2, indices.shape
+     B, C, H, W = images.shape
+     _, P, _ = indices.shape
+
+     h_inds, w_inds = list(indices.permute(2, 0, 1))
+     if scale_by_imsize:
+         h_inds = (h_inds + 1.0) * torch.tensor(H).to(h_inds) * 0.5
+         w_inds = (w_inds + 1.0) * torch.tensor(W).to(w_inds) * 0.5
+
+     h_inds = torch.maximum(torch.minimum(h_inds, torch.tensor(H - 1).to(h_inds)), torch.tensor(0.).to(h_inds))
+     w_inds = torch.maximum(torch.minimum(w_inds, torch.tensor(W - 1).to(w_inds)), torch.tensor(0.).to(w_inds))
+
+     h_floor = torch.floor(h_inds)
+     w_floor = torch.floor(w_inds)
+     h_ceil = torch.ceil(h_inds)
+     w_ceil = torch.ceil(w_inds)
+
+     bot_right_weight = (h_inds - h_floor) * (w_inds - w_floor)
+     bot_left_weight = (h_inds - h_floor) * (w_ceil - w_inds)
+     top_right_weight = (h_ceil - h_inds) * (w_inds - w_floor)
+     top_left_weight = (h_ceil - h_inds) * (w_ceil - w_inds)
+
+     # the four weights sum to 1 except where h_inds or w_inds is exactly integral
+     in_bounds = (bot_right_weight + bot_left_weight + top_right_weight + top_left_weight) > 0.95
+     in_bounds = in_bounds.to(torch.float32)  # (currently unused)
+
+     top_left_vals = index_into_images(images, torch.stack([h_floor, w_floor], -1))
+     top_right_vals = index_into_images(images, torch.stack([h_floor, w_ceil], -1))
+     bot_left_vals = index_into_images(images, torch.stack([h_ceil, w_floor], -1))
+     bot_right_vals = index_into_images(images, torch.stack([h_ceil, w_ceil], -1))
+
+     im_vals = top_left_vals * top_left_weight[..., None]
+     im_vals += top_right_vals * top_right_weight[..., None]
+     im_vals += bot_left_vals * bot_left_weight[..., None]
+     im_vals += bot_right_vals * bot_right_weight[..., None]
+
+     im_vals = im_vals.view(B, P, C)
+
+     return im_vals
+
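The sampling is standard bilinear interpolation over the four neighboring pixels, with one caveat worth knowing: at exactly integral indices, floor == ceil, all four weights vanish, and the result is 0 rather than the pixel value. Illustrative (not part of the commit):

    images = torch.zeros(1, 1, 2, 2)
    images[0, 0, 0, 1] = 4.0                              # value 4 at corner (h=0, w=1)
    pt = torch.tensor([[[0.5, 0.5]]])                     # midpoint of the four pixels
    val = soft_index(images, pt, scale_by_imsize=False)   # [1, 1, 1]
    assert torch.allclose(val, torch.tensor(1.0))         # (0 + 4 + 0 + 0) / 4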
+ def compute_compatibility(positions, plateau, phenotypes=None, availability=None, noise=0.1):
+     """
+     Compute how well "fit" each agent is for the position it occupies on the plateau,
+     according to its "phenotype".
+
+     positions: [B,P,2]
+     plateau: [B,H,W,Q]
+     phenotypes: [B,P,D] or None
+     availability: [B,H,W,A]
+     """
+     B, H, W, Q = plateau.shape
+     P = positions.shape[1]
+     if phenotypes is None:
+         # soft_index expects channels-first input
+         phenotypes = soft_index(plateau.permute(0, 3, 1, 2), positions)
+
+     if availability is not None:
+         assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], (availability.shape, plateau.shape)
+         A = availability.size(-1)
+         assert P % A == 0, (P, A)
+         S = P // A  # population size
+         plateau = availability[..., None] * plateau[..., None, :]  # [B,H,W,A,Q]
+         plateau = plateau.view(B, H, W, A * Q)
+
+     plateau_values = soft_index(plateau.permute(0, 3, 1, 2), positions, scale_by_imsize=True)
+     if noise > 0:
+         plateau_values += noise * torch.rand(size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)
+
+     if availability is not None:
+         plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
+         inds = torch.tile(torch.eye(A)[None].expand(B, -1, -1), (1, S, 1))[..., None]  # [B,P,A,1]
+         plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2)  # [B,P,Q]
+     else:
+         plateau_values = l2_normalize(plateau_values)
+
+     compatibility = torch.sum(
+         l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True)  # [B,P,1]
+
+     return compatibility
+
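On a perfectly uniform plateau, any agent's phenotype matches its local plateau vector exactly, so compatibility is 1. Illustrative, using the module's own l2_normalize helper defined earlier in the file (not part of the commit):

    plateau = l2_normalize(torch.ones(1, 8, 8, 4))         # uniform unit vectors everywhere
    pos = torch.full((1, 1, 2), -0.1)                      # one agent, off the integer pixel grid
    fit = compute_compatibility(pos, plateau, noise=0.0)   # [1, 1, 1]
    assert torch.allclose(fit, torch.tensor(1.0), atol=1e-4)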
+ def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
+     """Compute the pairwise IoU between two sets of (soft or thresholded) masks, each flattened to [B,N,P]."""
+     B, N, P = masks.shape
+     if masks_target is None:
+         masks_target = masks
+     if mask_thresh is not None:
+         masks = (masks > mask_thresh).to(torch.float32)
+         masks_target = (masks_target > mask_thresh).to(torch.float32)
+
+     ## union and intersection
+     overlaps = masks[..., None] * masks_target[..., None, :]  # [B,N,P,P]
+     I = overlaps.sum(dim=1)
+     U = torch.maximum(masks[..., None], masks_target[..., None, :]).sum(dim=1)
+     iou = I / torch.maximum(U, torch.tensor(eps, dtype=torch.float32))  # [B,P,P]
+
+     return iou
+
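A hand-checkable case: two binary masks of two pixels each, sharing one pixel, have IoU 1/3. Illustrative (not part of the commit):

    m1 = torch.tensor([1., 1., 0., 0.])
    m2 = torch.tensor([0., 1., 1., 0.])
    masks = torch.stack([m1, m2], -1)[None]                  # [B=1, N=4, P=2]
    iou = compute_pairwise_overlaps(masks, mask_thresh=0.5)  # [1, 2, 2]
    assert torch.allclose(iou[0], torch.tensor([[1.0, 1/3], [1/3, 1.0]]))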
+ def compete_agents(masks, fitnesses, alive,
+                    mask_thresh=0.5, compete_thresh=0.2,
+                    sticky_winners=True):
+     """
+     Kill off agents (i.e., decide which mask channels remain "alive") based on
+     mask overlap and the fitness of each agent.
+
+     args:
+         masks: [B,N,P]
+         fitnesses: [B,P,1]
+         alive: [B,P,1]
+
+     returns:
+         still_alive: [B,P,1]
+     """
+     B, N, P = masks.shape
+     assert list(alive.shape) == [B, P, 1], alive.shape
+     assert list(fitnesses.shape) == [B, P, 1], fitnesses.shape
+
+     ## find territorial disputes
+     overlaps = compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=mask_thresh)
+     disputes = overlaps > compete_thresh  # [B,P,P] <bool>
+
+     ## agents don't fight themselves
+     disputes = torch.logical_and(
+         disputes, torch.logical_not(
+             torch.eye(P, dtype=torch.bool, device=disputes.device).unsqueeze(0).expand(B, -1, -1)))
+
+     ## kill off the agents with lower fitness in each dispute
+     killed = torch.logical_and(disputes, fitnesses < torch.transpose(fitnesses, 1, 2))
+
+     ## once an agent wins, it always wins again
+     if sticky_winners:
+         winners = (alive > 0.5)
+         losers = torch.logical_not(winners)
+
+         ## winners can't lose to last round's losers
+         winners_vs_losers = torch.logical_and(winners, torch.transpose(losers, 1, 2))  # [B,P,P]
+         killed = torch.logical_and(killed, torch.logical_not(winners_vs_losers))
+
+         ## losers can't overtake last round's winners
+         losers_vs_winners = torch.logical_and(losers, torch.transpose(winners, 1, 2))
+         losers_vs_winners_disputes = torch.logical_and(losers_vs_winners, disputes)
+         killed = torch.logical_or(killed, losers_vs_winners_disputes)
+
+     ## if an agent was killed by *any* competitor, it's dead
+     killed = torch.any(killed, dim=2, keepdim=True)
+     alive = torch.logical_not(killed).to(torch.float32)
+
+     return alive
+
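With two agents claiming the same territory, only the fitter one survives the dispute. Illustrative (not part of the commit):

    mask = torch.ones(1, 16, 1)
    masks = torch.cat([mask, mask], -1)          # [1, 16, 2]: identical masks -> IoU 1
    fitnesses = torch.tensor([[[0.9], [0.4]]])   # agent 0 is fitter
    alive = torch.ones(1, 2, 1)
    still_alive = compete_agents(masks, fitnesses, alive)
    assert still_alive.squeeze().tolist() == [1.0, 0.0]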
+ def compute_distance_weighted_vectors(vector_map, positions, mask=None, beta=1.0, eps=1e-8):
+     """
+     Compute vectors whose values are a weighted mean of vector_map, where the weights
+     are a softmax over inverse distance to each position.
+     """
+     B, H, W, D = vector_map.shape
+     assert positions.size(-1) == 2, positions.size()
+     B, P, _ = positions.shape
+     N = H * W
+
+     if mask is None:
+         mask = torch.ones_like(vector_map[..., 0:1]).to(vector_map.device)
+     else:
+         assert list(mask.shape) == [B, H, W, 1]
+
+     hw_grid = coordinate_ims(B, 0, [H, W]).view(B, N, 2).to(vector_map.device)
+     delta_positions = hw_grid[:, None] - positions[:, :, None]  # [B,P,N,2]
+     distances = torch.sqrt(delta_positions[..., 0] ** 2 + delta_positions[..., 1] ** 2 + eps)  # [B,P,N]
+
+     ## max distance in normalized coordinates is 2*sqrt(2)
+     inv_distances = (2.0 * np.sqrt(2.0)) / (distances + eps)
+     inv_distances = F.softmax(beta * inv_distances * mask.view(B, 1, N), dim=-1)  # [B,P,N]
+     distance_weighted_vectors = torch.sum(
+         vector_map.view(B, 1, N, D) * inv_distances[..., None], dim=2, keepdim=False)  # [B,P,D]
+     return distance_weighted_vectors
+
+ def masks_from_phenotypes(plateau, phenotypes, normalize=True):
+     """Read out one [B,N,M] mask per phenotype as its (rectified) cosine similarity with the plateau."""
+     B, H, W, Q = plateau.shape
+     N = H * W
+     masks = dot_product_attention(
+         queries=plateau.view(B, N, Q),
+         keys=phenotypes,
+         normalize=normalize)
+     masks = F.relu(masks)
+     return masks
+
+ class Competition(nn.Module):
+
+     def __init__(
+             self,
+             size=None,
+             num_masks=16,
+             num_competition_rounds=5,
+             mask_beta=10.0,
+             reduce_func=reduce_max,
+             stop_gradient=True,
+             stop_gradient_phenotypes=True,
+             normalization_func=l2_normalize,
+             sum_edges=True,
+             mask_thresh=0.5,
+             compete_thresh=0.2,
+             sticky_winners=True,
+             selection_strength=100.0,
+             homing_strength=10.0,
+             mask_dead_segments=True
+     ):
+         super().__init__()
+         self.num_masks = self.M = num_masks
+         self.num_competition_rounds = num_competition_rounds
+         self.mask_beta = mask_beta
+         self.reduce_func = reduce_func
+         self.normalization_func = normalization_func
+
+         ## stop gradients
+         self.sg_func = lambda x: (x.detach() if stop_gradient else x)
+         self.sg_phenotypes_func = lambda x: (x.detach() if stop_gradient_phenotypes else x)
+
+         ## agent sampling kwargs
+         self.sum_edges = sum_edges
+
+         ## competition kwargs
+         self.mask_thresh = mask_thresh
+         self.compete_thresh = compete_thresh
+         self.sticky_winners = sticky_winners
+         self.selection_strength = selection_strength
+         self.homing_strength = homing_strength
+         self.mask_dead_segments = mask_dead_segments
+
+         ## shapes
+         self.B = self.T = self.BT = self.N = self.Q = None
+         self.size = size  # [H,W]
+         if self.size:
+             assert len(self.size) == 2, self.size
+
+     def reshape_batch_time(self, x, merge=True):
+
+         if merge:
+             self.is_temporal = True
+             B, T = x.size()[0:2]
+             if self.B:
+                 assert (B == self.B), (B, self.B)
+             else:
+                 self.B = B
+
+             if self.T:
+                 assert (T == self.T), (T, self.T)
+             else:
+                 self.T = T
+
+             assert B * T == (self.B * self.T), (B * T, self.B * self.T)
+             if self.BT is None:
+                 self.BT = self.B * self.T
+
+             return torch.reshape(x, [self.BT] + list(x.size())[2:])
+
+         else:  # split
+             BT = x.size()[0]
+             assert self.B and self.T, (self.B, self.T)
+             if self.BT is not None:
+                 assert BT == self.BT, (BT, self.BT)
+             else:
+                 self.BT = BT
+
+             return torch.reshape(x, [self.B, self.T] + list(x.size())[1:])
+
+     def process_plateau_input(self, plateau):
+
+         shape = plateau.size()
+         if len(shape) == 5:
+             self.is_temporal = True
+             self.B, self.T, self.H, self.W, self.Q = shape
+             self.N = self.H * self.W
+             self.BT = self.B * self.T
+             plateau = self.reshape_batch_time(plateau)
+         elif (len(shape) == 4) and (self.size is None):
+             self.is_temporal = False
+             self.B, self.H, self.W, self.Q = shape
+             self.N = self.H * self.W
+             self.T = 1
+             self.BT = self.B * self.T
+         elif (len(shape) == 4) and (self.size is not None):
+             self.is_temporal = True
+             self.B, self.T, self.N, self.Q = shape
+             self.BT = self.B * self.T
+             self.H, self.W = self.size
+             plateau = self.reshape_batch_time(plateau)
+             plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
+         elif len(shape) == 3:
+             assert self.size is not None, \
+                 "You need to specify an image size to reshape the plateau of shape %s" % shape
+             self.is_temporal = False
+             self.B, self.N, self.Q = shape
+             self.T = 1
+             self.BT = self.B
+             self.H, self.W = self.size
+             plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
+         else:
+             raise ValueError("input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % shape)
+
+         return plateau
+
+     def forward(self,
+                 plateau,
+                 agents=None,
+                 alive=None,
+                 phenotypes=None,
+                 compete=True,
+                 update_pointers=True,
+                 yoke_phenotypes_to_agents=True,
+                 noise=0.1
+                 ):
+         """
+         Find the uniform regions within the plateau map
+         by competition between visual "indices."
+
+         args:
+             plateau: [B,[T],H,W,Q] feature map with smooth "plateaus"
+
+         returns:
+             masks: [B,[T],H,W,M] <float> one mask in each of M channels
+             agents: [B,[T],M,2] <float> positions of agents in normalized coordinates
+             alive: [B,[T],M] <float> binary vector indicating which masks are valid
+             phenotypes: [B,[T],M,Q]
+             unharvested: [B,[T],H,W] <float> map of regions that weren't covered
+         """
+
+         ## preprocess
+         plateau = self.process_plateau_input(plateau)  # [BT,H,W,Q]
+         plateau = self.normalization_func(plateau)
+
+         ## sample initial indices ("agents") from borders of the plateau map
+         if agents is None:
+             agents = sample_coordinates_at_borders(
+                 plateau.permute(0, 3, 1, 2),
+                 num_points=self.M,
+                 mask=None,
+                 sum_edges=self.sum_edges)
+         else:
+             if self.is_temporal:
+                 agents = agents.view(self.BT, *agents.shape[2:])
+
+         ## the agents have "phenotypes" depending on where they're situated on the plateau map
+         if phenotypes is None:
+             phenotypes = self.sg_phenotypes_func(
+                 self.normalization_func(
+                     soft_index(plateau.permute(0, 3, 1, 2),
+                                agents, scale_by_imsize=True)))
+         elif self.is_temporal:
+             phenotypes = phenotypes.view(self.BT, *phenotypes.shape[2:])
+
+         ## the "fitness" of an agent -- how likely it is to survive competition --
+         ## is how well its phenotype matches the plateau vector at its current position.
+         ## initially all of these agents are "alive"
+         if alive is None:
+             alive = torch.ones_like(agents[..., -1:])  # [BT,M,1]
+             fitnesses = compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise)
+             alive_mask = None
+         else:
+             if self.is_temporal:
+                 alive = alive.view(self.BT, *alive.shape[2:])
+             alive_mask = (alive > 0.5).float()
+             fitnesses = alive_mask + compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise) * (1 - alive_mask)
+
+         alive_t = torch.transpose(alive, 1, 2)  # [BT,1,M]
+
+         ## compute the masks at initialization
+         masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)
+
+         ## find the "unharvested" regions of the plateau map not covered by agents
+         unharvested = torch.minimum(self.reduce_func(masks_pred, dim=-1, keepdim=True), torch.tensor(1.0))
+         unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)
+
+         if alive_mask is not None:
+             new_agents = sample_coordinates_at_borders(
+                 plateau.permute(0, 3, 1, 2), num_points=self.M,
+                 mask=unharvested.permute(0, 3, 1, 2),
+                 sum_edges=self.sum_edges)
+             agents = agents * alive_mask + new_agents * (1.0 - alive_mask)
+
+             new_phenotypes = self.sg_phenotypes_func(
+                 self.normalization_func(
+                     soft_index(plateau.permute(0, 3, 1, 2),
+                                new_agents, scale_by_imsize=True)))
+             phenotypes = phenotypes * alive_mask + new_phenotypes * (1.0 - alive_mask)
+
+         for r in range(self.num_competition_rounds):
+
+             ## compute the "availability" of the plateau map for each agent (i.e. where it can harvest from)
+             alive_t = torch.transpose(alive, 1, 2)  # [BT,1,M]
+             # availability = alive_t * masks_pred + (1.0 - alive_t) * unharvested.view(self.BT, self.N, 1)
+             # availability = availability.view(self.BT, self.H, self.W, self.M)
+
+             ## update the fitnesses
+             if update_pointers and compete:
+                 fitnesses = compute_compatibility(
+                     positions=agents,
+                     plateau=plateau,
+                     phenotypes=phenotypes,
+                     # availability=availability,
+                     availability=None,
+                     noise=noise
+                 )
+
+             ## kill agents that have wandered off the map
+             in_bounds = torch.all(
+                 torch.logical_and(agents < 1.0, agents > -1.0),
+                 dim=-1, keepdim=True)  # [BT,M,1]
+             fitnesses *= in_bounds.to(fitnesses)
+
+             ## break ties in fitness
+             fitnesses -= 0.001 * torch.arange(self.M, dtype=torch.float32)[None, :, None].expand(self.BT, -1, -1).to(fitnesses.device)
+
+             ## recompute the masks from the plateau vectors at the agents' current positions
+             if yoke_phenotypes_to_agents:
+                 occupied_regions = self.sg_phenotypes_func(
+                     soft_index(plateau.permute(0, 3, 1, 2), agents, scale_by_imsize=True))
+                 masks_pred = masks_from_phenotypes(plateau, occupied_regions, normalize=True)  # [BT,N,M]
+
+             ## have each pair of agents compete.
+             ## if their masks overlap, the winner is the one with higher fitness
+             if compete:
+                 alive = compete_agents(masks_pred, fitnesses, alive,
+                                        mask_thresh=self.mask_thresh,
+                                        compete_thresh=self.compete_thresh,
+                                        sticky_winners=self.sticky_winners)
+
+             alive *= in_bounds.to(alive)
+             alive_t = torch.transpose(alive, 1, 2)
+
+             if not yoke_phenotypes_to_agents:
+                 masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)
+
+             ## update which parts of the plateau are "unharvested"
+             unharvested = torch.minimum(self.reduce_func(masks_pred * alive_t, dim=-1, keepdim=True),
+                                         torch.tensor(1.0, dtype=torch.float32))
+             unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)
+
+             ## update phenotypes of the winners
+             if update_pointers:
+                 if self.mask_thresh is not None:
+                     winner_phenotypes = (masks_pred[..., None] > self.mask_thresh).to(plateau)
+                 if self.selection_strength > 0:
+                     winner_phenotypes = winner_phenotypes * plateau.view(self.BT, self.N, 1, self.Q)
+                 winner_phenotypes = self.normalization_func(winner_phenotypes.mean(dim=1))  # [BT,M,Q]
+                 phenotypes += (alive * winner_phenotypes) * self.selection_strength
+
+             ## reinitialize losing agent positions
+             alive_mask = (alive > 0.5).to(torch.float32)
+             loser_agents = sample_coordinates_at_borders(
+                 plateau.permute(0, 3, 1, 2), num_points=self.M,
+                 mask=unharvested.permute(0, 3, 1, 2),
+                 sum_edges=self.sum_edges)
+             agents = agents * alive_mask + loser_agents * (1.0 - alive_mask)
+
+             ## reinitialize loser agent phenotypes
+             loser_phenotypes = self.normalization_func(
+                 compute_distance_weighted_vectors(plateau, agents, mask=unharvested, beta=self.homing_strength))
+             phenotypes = alive_mask * phenotypes + (1.0 - alive_mask) * loser_phenotypes
+             phenotypes = self.normalization_func(phenotypes)
+
+         ## run a final competition between the surviving masks
+         if self.mask_beta is not None:
+             masks_pred = F.softmax(
+                 self.mask_beta * masks_pred * alive_t -
+                 self.mask_beta * (1.0 - alive_t), dim=-1)
+         if self.mask_dead_segments:
+             masks_pred *= alive_t
+
+         masks_pred = masks_pred.view(self.BT, self.H, self.W, self.M)
+         if self.is_temporal:
+             masks_pred = self.reshape_batch_time(masks_pred, merge=False)
+             agents = self.reshape_batch_time(agents, merge=False)
+             alive = self.reshape_batch_time(alive, merge=False)
+             phenotypes = self.reshape_batch_time(phenotypes, merge=False)
+             unharvested = self.reshape_batch_time(unharvested, merge=False)
+
+         return (masks_pred, agents, alive, phenotypes, unharvested)
+
+     @staticmethod
+     def masks_to_segments(masks):
+         return masks.argmax(-1)
+
+     @staticmethod
+     def flatten_plateau_with_masks(plateau, masks, alive, flatten_masks=True):
+         B, M, _ = alive.shape
+         Q = plateau.shape[-1]
+         if flatten_masks:
+             masks = F.one_hot((alive[..., None, None, :, 0] * masks).argmax(-1), num_classes=M).float()
+
+         flat_plateau = torch.zeros_like(plateau)
+         phenotypes = torch.zeros((B, M, Q), device=plateau.device).float()
+         for b in range(B):
+             m_inds = torch.where(alive[b, :, 0])[0]
+             masks_b = masks[b, ..., m_inds]
+             num_px = masks_b.sum((0, 1)).clamp(min=1)[:, None]  # [K,1]
+             phenos_b = torch.einsum('hwk,hwq->kq', masks_b, plateau[b]) / num_px  # [K,Q]
+             flat_plateau_b = (masks_b[..., None] * phenos_b[None, None]).sum(-2)  # [H,W,Q]
+
+             phenotypes[b, m_inds, :] = phenos_b
+             flat_plateau[b] = flat_plateau_b
+
+         _norm = lambda x: F.normalize(x, p=2, dim=-1)
+         return (_norm(flat_plateau), _norm(phenotypes))
+
+     @staticmethod
+     def plot_agents(agents, alive, size=[128, 128]):
+         B, M, _ = alive.shape
+         agent_map = -1 * torch.ones((B, *size), device=alive.device, dtype=torch.long)
+         for b in range(B):
+             inds = torch.where(alive[b, :, 0])
+             for i in inds[0]:
+                 pos = agents[b, i] * 0.5 + 0.5
+                 pos = pos * torch.tensor(size, device=pos.device)
+                 hmin, wmin = list(torch.floor(pos).long())
+                 hmax, wmax = list(torch.ceil(pos).long())
+                 # clamp so agents at the extreme border don't index out of bounds
+                 hmin, wmin = max(int(hmin), 0), max(int(wmin), 0)
+                 hmax, wmax = min(int(hmax), size[0] - 1), min(int(wmax), size[1] - 1)
+                 agent_map[b, [hmin, hmin, hmax, hmax], [wmin, wmax, wmin, wmax]] = i
+
+         return agent_map
+
+ if __name__ == '__main__':
+
+     Comp = Competition(num_masks=32, num_competition_rounds=5)
+
+     left = torch.ones(size=(32, 8)).unsqueeze(-1) * torch.tensor([1., 0.2, 0.])
+     middle = torch.ones(size=(32, 16)).unsqueeze(-1) * torch.tensor([0., 1., 0.2])
+     right = torch.ones(size=(32, 8)).unsqueeze(-1) * torch.tensor([0.1, 0., 1.])
+     plateau = torch.cat([left, middle, right], dim=-2).unsqueeze(0)
+     masks, agents, alive, phenotypes, unharvested = Comp(plateau)
+     mask_inds = np.where(alive[0, :, 0].numpy())[0]
+     print(np.argmax(masks[0, ...], axis=-1))
+     for ind in mask_inds:
+         print("num pixels in mask %d ---> %d" % (ind, (np.argmax(masks[0], -1) == ind).sum()))
cwm/eval/Segmentation/archive/configs/__init__.py ADDED
File without changes
cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py ADDED
@@ -0,0 +1,56 @@
+ from functools import partial
+ from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+ from detectron2 import model_zoo
+ from detectron2.config import LazyCall as L
+ from detectron2.config import CfgNode, LazyConfig
+ from detectron2.solver import WarmupParamScheduler
+ from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
+
+ from ..common.coco_loader_lsj import dataloader
+
+ cfg_file = "./models/mask_rcnn_cwm.py"
+ model = LazyConfig.load(cfg_file).model
+
+ # Initialization and trainer settings
+ train = model_zoo.get_config("common/train.py").train
+ train.amp.enabled = True
+ train.ddp.fp16_compression = True
+ train.init_checkpoint = (
+     # "/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping/checkpoint-799-encoder.pth"
+     './output/model_0004999.pth'
+ )
+ train.eval_period = 1e9
+
+ # Schedule
+ # 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
+ # train.max_iter = 184375
+ # milestones = [163889, 177546]
+
+ # 50 ep = 61458 iters * 96 images/iter / 118000 images/ep
+ train.max_iter = 61458
+ milestones = [54629, 59182]
+
+ lr_multiplier = L(WarmupParamScheduler)(
+     scheduler=L(MultiStepParamScheduler)(
+         values=[1.0, 0.1, 0.01],
+         milestones=milestones,
+         num_updates=train.max_iter,
+     ),
+     warmup_length=250 / train.max_iter,
+     warmup_factor=0.001,
+ )
+
+ # Optimizer
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
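These LazyConfig files are meant to be consumed by detectron2's lazy-config trainer. A typical launch, assuming the standard tools/lazyconfig_train_net.py script from the detectron2 repository (the script path and GPU count here are site-specific assumptions, not part of this commit):

    python tools/lazyconfig_train_net.py \
        --config-file cwm/eval/Segmentation/archive/configs/mask_rcnn_cwm_vitdet_b_100ep.py \
        --num-gpus 8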
cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep.py ADDED
@@ -0,0 +1,59 @@
+ from functools import partial
+ from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+ from detectron2 import model_zoo
+ from detectron2.config import LazyCall as L
+ from detectron2.solver import WarmupParamScheduler
+ from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
+ import os
+ from ..common.coco_loader_lsj import dataloader
+ from detectron2.data.datasets import register_coco_instances
+ from detectron2.config import CfgNode, LazyConfig
+
+ cfg_file = "./models/mask_rcnn_vitdet_v2.py"
+ model = LazyConfig.load(cfg_file).model
+ model.backbone.square_pad = 512  # change input size to 512x512
+
+ # Initialization and trainer settings
+ train = model_zoo.get_config("common/train.py").train
+ train.amp.enabled = True
+ train.ddp.fp16_compression = True
+ train.init_checkpoint = (
+     # "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True"
+     # "/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_16x16_no_clumping_mr0.98/checkpoint-799-encoder.pth"
+     "/ccn2/u/honglinc/cwm_checkpoints/mae_vitb/mae_pretrain_vit_base-encoder.pth"
+ )
+ train.output_dir = os.path.dirname(train.init_checkpoint) + "/coco_finetune_512_v3"
+
+ root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
+ register_coco_instances("cls_agnostic_coco", {},
+                         os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"),
+                         os.path.join(root, "coco/val2017")
+                         )
+ dataloader.test.dataset.names = 'cls_agnostic_coco'
+
+ # Schedule
+ # 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
+ # train.max_iter = 184375
+ # milestones = [163889, 177546]
+
+ # 50 ep = 61458 iters * 96 images/iter / 118000 images/ep
+ train.max_iter = 61458
+ milestones = [54629, 59182]
+
+ lr_multiplier = L(WarmupParamScheduler)(
+     scheduler=L(MultiStepParamScheduler)(
+         values=[1.0, 0.1, 0.01],
+         milestones=milestones,
+         num_updates=train.max_iter,
+     ),
+     warmup_length=250 / train.max_iter,
+     warmup_factor=0.001,
+ )
+
+ # Optimizer
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_dino.py ADDED
@@ -0,0 +1,56 @@
+ from functools import partial
+ from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+ from detectron2 import model_zoo
+ from detectron2.config import LazyCall as L
+ from detectron2.config import CfgNode, LazyConfig
+ from detectron2.solver import WarmupParamScheduler
+
+ from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
+ import os
+ from ..common.coco_loader_lsj import dataloader
+
+ cfg_file = "./models/mask_rcnn_cwm.py"
+ model = LazyConfig.load(cfg_file).model
+
+ # Initialization and trainer settings
+ train = model_zoo.get_config("common/train.py").train
+ train.amp.enabled = True
+ train.ddp.fp16_compression = True
+ train.init_checkpoint = (
+     '/home/honglinc/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth'
+ )
+ train.output_dir = '/ccn2/u/honglinc/cwm_checkpoints/dinov2_coco_finetune_512'
+
+ # model.backbone.net.window_size = 0
+ # model.backbone.net.window_block_indexes = []
+ # model.backbone.net.use_rel_pos = False
+ # model.backbone.net.drop_path_rate = 0.
+
+ # Schedule
+ # 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
+ # train.max_iter = 184375
+ # milestones = [163889, 177546]
+
+ # 50 ep = 61458 iters * 96 images/iter / 118000 images/ep
+ train.max_iter = 61458
+ milestones = [54629, 59182]
+
+ lr_multiplier = L(WarmupParamScheduler)(
+     scheduler=L(MultiStepParamScheduler)(
+         values=[1.0, 0.1, 0.01],
+         milestones=milestones,
+         num_updates=train.max_iter,
+     ),
+     warmup_length=250 / train.max_iter,
+     warmup_factor=0.001,
+ )
+
+ # Optimizer
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
cwm/eval/Segmentation/archive/configs/mask_rcnn_vitdet_b_100ep_v2.py ADDED
@@ -0,0 +1,59 @@
+ from functools import partial
+ from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+ from detectron2 import model_zoo
+ from detectron2.config import LazyCall as L
+ from detectron2.config import CfgNode, LazyConfig
+ from detectron2.solver import WarmupParamScheduler
+
+ from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
+ import os
+ from ..common.coco_loader_lsj import dataloader
+ from detectron2.data.datasets import register_coco_instances
+
+ cfg_file = "./models/mask_rcnn_cwm.py"
+ model = LazyConfig.load(cfg_file).model
+
+ # Initialization and trainer settings
+ train = model_zoo.get_config("common/train.py").train
+ train.amp.enabled = True
+ train.ddp.fp16_compression = True
+ train.init_checkpoint = (
+     # "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True"
+     '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_16x16_no_clumping_mr0.90/checkpoint-799-encoder.pth'
+ )
+ train.output_dir = os.path.dirname(train.init_checkpoint) + "/coco_finetune_512_v3"
+ train.eval_period = 1e9
+
+ root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
+ register_coco_instances("cls_agnostic_coco", {},
+                         os.path.join(root, "coco/annotations/coco_cls_agnostic_instances_val2017.json"),
+                         os.path.join(root, "coco/val2017")
+                         )
+ dataloader.test.dataset.names = 'cls_agnostic_coco'
+
+ # Schedule
+ # 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
+ # train.max_iter = 184375
+ # milestones = [163889, 177546]
+
+ # 50 ep = 61458 iters * 96 images/iter / 118000 images/ep
+ train.max_iter = 61458
+ milestones = [54629, 59182]
+
+ lr_multiplier = L(WarmupParamScheduler)(
+     scheduler=L(MultiStepParamScheduler)(
+         values=[1.0, 0.1, 0.01],
+         milestones=milestones,
+         num_updates=train.max_iter,
+     ),
+     warmup_length=250 / train.max_iter,
+     warmup_factor=0.001,
+ )
+
+ # Optimizer
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}