ZehanWang commited on
Commit
864ec44
1 Parent(s): ebca029

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/__init__.py +0 -0
  2. src/__pycache__/__init__.cpython-310.pyc +0 -0
  3. src/dataset/__init__.py +97 -0
  4. src/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  5. src/dataset/__pycache__/base_depth_dataset.cpython-310.pyc +0 -0
  6. src/dataset/__pycache__/base_inpaint_dataset.cpython-310.pyc +0 -0
  7. src/dataset/__pycache__/depthanything_dataset.cpython-310.pyc +0 -0
  8. src/dataset/__pycache__/diode_dataset.cpython-310.pyc +0 -0
  9. src/dataset/__pycache__/eth3d_dataset.cpython-310.pyc +0 -0
  10. src/dataset/__pycache__/eval_base_dataset.cpython-310.pyc +0 -0
  11. src/dataset/__pycache__/hypersim_dataset.cpython-310.pyc +0 -0
  12. src/dataset/__pycache__/kitti_dataset.cpython-310.pyc +0 -0
  13. src/dataset/__pycache__/mixed_sampler.cpython-310.pyc +0 -0
  14. src/dataset/__pycache__/nyu_dataset.cpython-310.pyc +0 -0
  15. src/dataset/__pycache__/scannet_dataset.cpython-310.pyc +0 -0
  16. src/dataset/__pycache__/vkitti_dataset.cpython-310.pyc +0 -0
  17. src/dataset/base_depth_dataset.py +286 -0
  18. src/dataset/base_inpaint_dataset.py +280 -0
  19. src/dataset/depthanything_dataset.py +91 -0
  20. src/dataset/diode_dataset.py +91 -0
  21. src/dataset/eth3d_dataset.py +65 -0
  22. src/dataset/eval_base_dataset.py +283 -0
  23. src/dataset/hypersim_dataset.py +44 -0
  24. src/dataset/inpaint_dataset.py +286 -0
  25. src/dataset/kitti_dataset.py +124 -0
  26. src/dataset/mixed_sampler.py +149 -0
  27. src/dataset/nyu_dataset.py +61 -0
  28. src/dataset/scannet_dataset.py +44 -0
  29. src/dataset/vkitti_dataset.py +97 -0
  30. src/trainer/__init__.py +16 -0
  31. src/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  32. src/trainer/__pycache__/marigold_inpaint_trainer.cpython-310.pyc +0 -0
  33. src/trainer/__pycache__/marigold_trainer.cpython-310.pyc +0 -0
  34. src/trainer/__pycache__/marigold_xl_trainer.cpython-310.pyc +0 -0
  35. src/trainer/marigold_inpaint_trainer.py +665 -0
  36. src/trainer/marigold_trainer.py +968 -0
  37. src/trainer/marigold_xl_trainer.py +948 -0
  38. src/util/__pycache__/alignment.cpython-310.pyc +0 -0
  39. src/util/__pycache__/config_util.cpython-310.pyc +0 -0
  40. src/util/__pycache__/data_loader.cpython-310.pyc +0 -0
  41. src/util/__pycache__/depth_transform.cpython-310.pyc +0 -0
  42. src/util/__pycache__/logging_util.cpython-310.pyc +0 -0
  43. src/util/__pycache__/loss.cpython-310.pyc +0 -0
  44. src/util/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
  45. src/util/__pycache__/metric.cpython-310.pyc +0 -0
  46. src/util/__pycache__/multi_res_noise.cpython-310.pyc +0 -0
  47. src/util/__pycache__/seeding.cpython-310.pyc +0 -0
  48. src/util/__pycache__/slurm_util.cpython-310.pyc +0 -0
  49. src/util/alignment.py +72 -0
  50. src/util/config_util.py +49 -0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes). View file
 
src/dataset/__init__.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-04-16
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import os
24
+ import pdb
25
+
26
+ from .base_depth_dataset import BaseDepthDataset # noqa: F401
27
+ from .eval_base_dataset import EvaluateBaseDataset, DatasetMode, get_pred_name
28
+ from .diode_dataset import DIODEDataset
29
+ from .eth3d_dataset import ETH3DDataset
30
+ from .hypersim_dataset import HypersimDataset
31
+ from .kitti_dataset import KITTIDataset
32
+ from .nyu_dataset import NYUDataset
33
+ from .scannet_dataset import ScanNetDataset
34
+ from .vkitti_dataset import VirtualKITTIDataset
35
+ from .depthanything_dataset import DepthAnythingDataset
36
+ from .base_inpaint_dataset import BaseInpaintDataset
37
+
38
+ dataset_name_class_dict = {
39
+ "hypersim": HypersimDataset,
40
+ "vkitti": VirtualKITTIDataset,
41
+ "nyu_v2": NYUDataset,
42
+ "kitti": KITTIDataset,
43
+ "eth3d": ETH3DDataset,
44
+ "diode": DIODEDataset,
45
+ "scannet": ScanNetDataset,
46
+ 'depthanything': DepthAnythingDataset,
47
+ 'inpainting': BaseInpaintDataset
48
+ }
49
+
50
+
51
+ def get_dataset(
52
+ cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
53
+ ):
54
+ if "mixed" == cfg_data_split.name:
55
+ # assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."
56
+ dataset_ls = [
57
+ get_dataset(_cfg, base_data_dir, mode, **kwargs)
58
+ for _cfg in cfg_data_split.dataset_list
59
+ ]
60
+ return dataset_ls
61
+ elif cfg_data_split.name in dataset_name_class_dict.keys():
62
+ dataset_class = dataset_name_class_dict[cfg_data_split.name]
63
+ dataset = dataset_class(
64
+ mode=mode,
65
+ filename_ls_path=cfg_data_split.filenames,
66
+ dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
67
+ **cfg_data_split,
68
+ **kwargs,
69
+ )
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ return dataset
74
+
75
+ def get_eval_dataset(
76
+ cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
77
+ ) -> EvaluateBaseDataset:
78
+ if "mixed" == cfg_data_split.name:
79
+ assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."
80
+ dataset_ls = [
81
+ get_dataset(_cfg, base_data_dir, mode, **kwargs)
82
+ for _cfg in cfg_data_split.dataset_list
83
+ ]
84
+ return dataset_ls
85
+ elif cfg_data_split.name in dataset_name_class_dict.keys():
86
+ dataset_class = dataset_name_class_dict[cfg_data_split.name]
87
+ dataset = dataset_class(
88
+ mode=mode,
89
+ filename_ls_path=cfg_data_split.filenames,
90
+ dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
91
+ **cfg_data_split,
92
+ **kwargs,
93
+ )
94
+ else:
95
+ raise NotImplementedError
96
+
97
+ return dataset
src/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.19 kB). View file
 
src/dataset/__pycache__/base_depth_dataset.cpython-310.pyc ADDED
Binary file (7.92 kB). View file
 
src/dataset/__pycache__/base_inpaint_dataset.cpython-310.pyc ADDED
Binary file (7.7 kB). View file
 
src/dataset/__pycache__/depthanything_dataset.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
src/dataset/__pycache__/diode_dataset.cpython-310.pyc ADDED
Binary file (2.19 kB). View file
 
src/dataset/__pycache__/eth3d_dataset.cpython-310.pyc ADDED
Binary file (1.4 kB). View file
 
src/dataset/__pycache__/eval_base_dataset.cpython-310.pyc ADDED
Binary file (7.64 kB). View file
 
src/dataset/__pycache__/hypersim_dataset.cpython-310.pyc ADDED
Binary file (957 Bytes). View file
 
src/dataset/__pycache__/kitti_dataset.cpython-310.pyc ADDED
Binary file (3.35 kB). View file
 
src/dataset/__pycache__/mixed_sampler.cpython-310.pyc ADDED
Binary file (3.95 kB). View file
 
src/dataset/__pycache__/nyu_dataset.cpython-310.pyc ADDED
Binary file (1.39 kB). View file
 
src/dataset/__pycache__/scannet_dataset.cpython-310.pyc ADDED
Binary file (946 Bytes). View file
 
src/dataset/__pycache__/vkitti_dataset.cpython-310.pyc ADDED
Binary file (2.63 kB). View file
 
src/dataset/base_depth_dataset.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-04-30
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+ import glob
23
+ import io
24
+ import json
25
+ import os
26
+ import pdb
27
+ import random
28
+ import tarfile
29
+ from enum import Enum
30
+ from typing import Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from torch.utils.data import Dataset
36
+ from torchvision.transforms import InterpolationMode, Resize, CenterCrop
37
+ import torchvision.transforms as transforms
38
+ from transformers import CLIPTextModel, CLIPTokenizer
39
+ from src.util.depth_transform import DepthNormalizerBase
40
+ import random
41
+
42
+ from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode
43
+
44
+
45
+ def read_image_from_tar(tar_obj, img_rel_path):
46
+ image = tar_obj.extractfile("./" + img_rel_path)
47
+ image = image.read()
48
+ image = Image.open(io.BytesIO(image))
49
+
50
+
51
+ class BaseDepthDataset(Dataset):
52
+ def __init__(
53
+ self,
54
+ mode: DatasetMode,
55
+ filename_ls_path: str,
56
+ dataset_dir: str,
57
+ disp_name: str,
58
+ min_depth: float,
59
+ max_depth: float,
60
+ has_filled_depth: bool,
61
+ name_mode: DepthFileNameMode,
62
+ depth_transform: Union[DepthNormalizerBase, None] = None,
63
+ tokenizer: CLIPTokenizer = None,
64
+ augmentation_args: dict = None,
65
+ resize_to_hw=None,
66
+ move_invalid_to_far_plane: bool = True,
67
+ rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1],
68
+ **kwargs,
69
+ ) -> None:
70
+ super().__init__()
71
+ self.mode = mode
72
+ # dataset info
73
+ self.filename_ls_path = filename_ls_path
74
+ self.disp_name = disp_name
75
+ self.has_filled_depth = has_filled_depth
76
+ self.name_mode: DepthFileNameMode = name_mode
77
+ self.min_depth = min_depth
78
+ self.max_depth = max_depth
79
+ # training arguments
80
+ self.depth_transform: DepthNormalizerBase = depth_transform
81
+ self.augm_args = augmentation_args
82
+ self.resize_to_hw = resize_to_hw
83
+ self.rgb_transform = rgb_transform
84
+ self.move_invalid_to_far_plane = move_invalid_to_far_plane
85
+ self.tokenizer = tokenizer
86
+ # Load filenames
87
+ self.filenames = []
88
+ filename_paths = glob.glob(self.filename_ls_path)
89
+ for path in filename_paths:
90
+ with open(path, "r") as f:
91
+ self.filenames += json.load(f)
92
+ # Tar dataset
93
+ self.tar_obj = None
94
+ self.is_tar = (
95
+ True
96
+ if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
97
+ else False
98
+ )
99
+
100
+ def __len__(self):
101
+ return len(self.filenames)
102
+
103
+ def __getitem__(self, index):
104
+ rasters, other = self._get_data_item(index)
105
+ if DatasetMode.TRAIN == self.mode:
106
+ rasters = self._training_preprocess(rasters)
107
+ # merge
108
+ outputs = rasters
109
+ outputs.update(other)
110
+ return outputs
111
+
112
+ def _get_data_item(self, index):
113
+ rgb_path = self.filenames[index]['rgb_path']
114
+ depth_path = self.filenames[index]['depth_path']
115
+ mask_path = None
116
+ if 'valid_mask' in self.filenames[index]:
117
+ mask_path = self.filenames[index]['valid_mask']
118
+ if self.filenames[index]['caption'] is not None:
119
+ coca_caption = self.filenames[index]['caption']['coca_caption']
120
+ spatial_caption = self.filenames[index]['caption']['spatial_caption']
121
+ empty_caption = ''
122
+ caption_choices = [coca_caption, spatial_caption, empty_caption]
123
+ probabilities = [0.4, 0.4, 0.2]
124
+ caption = random.choices(caption_choices, probabilities)[0]
125
+ else:
126
+ caption = ''
127
+
128
+ rasters = {}
129
+ # RGB data
130
+ rasters.update(self._load_rgb_data(rgb_path))
131
+
132
+ # Depth data
133
+ if DatasetMode.RGB_ONLY != self.mode and depth_path is not None:
134
+ # load data
135
+ depth_data = self._load_depth_data(depth_path)
136
+ rasters.update(depth_data)
137
+ # valid mask
138
+ if mask_path is not None:
139
+ valid_mask_raw = Image.open(mask_path)
140
+ valid_mask_filled = Image.open(mask_path)
141
+ rasters["valid_mask_raw"] = torch.from_numpy(np.asarray(valid_mask_raw)).unsqueeze(0).bool()
142
+ rasters["valid_mask_filled"] = torch.from_numpy(np.asarray(valid_mask_filled)).unsqueeze(0).bool()
143
+ else:
144
+ rasters["valid_mask_raw"] = self._get_valid_mask(
145
+ rasters["depth_raw_linear"]
146
+ ).clone()
147
+ rasters["valid_mask_filled"] = self._get_valid_mask(
148
+ rasters["depth_filled_linear"]
149
+ ).clone()
150
+
151
+ other = {"index": index, "rgb_path": rgb_path, 'text': caption}
152
+
153
+ if self.resize_to_hw is not None:
154
+ resize_transform = transforms.Compose([
155
+ Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
156
+ CenterCrop(size=self.resize_to_hw)])
157
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
158
+
159
+ return rasters, other
160
+
161
+ def _load_rgb_data(self, rgb_path):
162
+ # Read RGB data
163
+ rgb = self._read_rgb_file(rgb_path)
164
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
165
+
166
+ outputs = {
167
+ "rgb_int": torch.from_numpy(rgb).int(),
168
+ "rgb_norm": torch.from_numpy(rgb_norm).float(),
169
+ }
170
+ return outputs
171
+
172
+ def _load_depth_data(self, depth_path, filled_rel_path=None):
173
+ # Read depth data
174
+ outputs = {}
175
+ depth_raw = self._read_depth_file(depth_path).squeeze()
176
+ depth_raw_linear = torch.from_numpy(depth_raw.copy()).float().unsqueeze(0) # [1, H, W]
177
+ outputs["depth_raw_linear"] = depth_raw_linear.clone()
178
+
179
+ if self.has_filled_depth:
180
+ depth_filled = self._read_depth_file(filled_rel_path).squeeze()
181
+ depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
182
+ outputs["depth_filled_linear"] = depth_filled_linear
183
+ else:
184
+ outputs["depth_filled_linear"] = depth_raw_linear.clone()
185
+
186
+ return outputs
187
+
188
+ def _get_data_path(self, index):
189
+ filename_line = self.filenames[index]
190
+
191
+ # Get data path
192
+ rgb_rel_path = filename_line[0]
193
+
194
+ depth_rel_path, text_rel_path = None, None
195
+ if DatasetMode.RGB_ONLY != self.mode:
196
+ depth_rel_path = filename_line[1]
197
+ if len(filename_line) > 2:
198
+ text_rel_path = filename_line[2]
199
+ return rgb_rel_path, depth_rel_path, text_rel_path
200
+
201
+ def _read_image(self, img_path) -> np.ndarray:
202
+ image_to_read = img_path
203
+ image = Image.open(image_to_read) # [H, W, rgb]
204
+ image = np.asarray(image)
205
+ return image
206
+
207
+ def _read_rgb_file(self, path) -> np.ndarray:
208
+ rgb = self._read_image(path)
209
+ rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W]
210
+ return rgb
211
+
212
+ def _read_depth_file(self, path):
213
+ depth_in = self._read_image(path)
214
+ # Replace code below to decode depth according to dataset definition
215
+ depth_decoded = depth_in
216
+ return depth_decoded
217
+
218
+ def _get_valid_mask(self, depth: torch.Tensor):
219
+ valid_mask = torch.logical_and(
220
+ (depth > self.min_depth), (depth < self.max_depth)
221
+ ).bool()
222
+ return valid_mask
223
+
224
+ def _training_preprocess(self, rasters):
225
+ # Augmentation
226
+ if self.augm_args is not None:
227
+ rasters = self._augment_data(rasters)
228
+
229
+ # Normalization
230
+ # rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0
231
+ # rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0
232
+
233
+ rasters["depth_raw_norm"] = self.depth_transform(
234
+ rasters["depth_raw_linear"], rasters["valid_mask_raw"]
235
+ ).clone()
236
+ rasters["depth_filled_norm"] = self.depth_transform(
237
+ rasters["depth_filled_linear"], rasters["valid_mask_filled"]
238
+ ).clone()
239
+
240
+ # Set invalid pixel to far plane
241
+ if self.move_invalid_to_far_plane:
242
+ if self.depth_transform.far_plane_at_max:
243
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
244
+ self.depth_transform.norm_max
245
+ )
246
+ else:
247
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
248
+ self.depth_transform.norm_min
249
+ )
250
+
251
+ # Resize
252
+ if self.resize_to_hw is not None:
253
+ resize_transform = transforms.Compose([
254
+ Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
255
+ CenterCrop(size=self.resize_to_hw)])
256
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
257
+ return rasters
258
+
259
+ def _augment_data(self, rasters_dict):
260
+ # lr flipping
261
+ lr_flip_p = self.augm_args.lr_flip_p
262
+ if random.random() < lr_flip_p:
263
+ rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}
264
+
265
+ return rasters_dict
266
+
267
+ def __del__(self):
268
+ if hasattr(self, "tar_obj") and self.tar_obj is not None:
269
+ self.tar_obj.close()
270
+ self.tar_obj = None
271
+
272
+ def get_pred_name(rgb_basename, name_mode, suffix=".png"):
273
+ if DepthFileNameMode.rgb_id == name_mode:
274
+ pred_basename = "pred_" + rgb_basename.split("_")[1]
275
+ elif DepthFileNameMode.i_d_rgb == name_mode:
276
+ pred_basename = rgb_basename.replace("_rgb.", "_pred.")
277
+ elif DepthFileNameMode.id == name_mode:
278
+ pred_basename = "pred_" + rgb_basename
279
+ elif DepthFileNameMode.rgb_i_d == name_mode:
280
+ pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:])
281
+ else:
282
+ raise NotImplementedError
283
+ # change suffix
284
+ pred_basename = os.path.splitext(pred_basename)[0] + suffix
285
+
286
+ return pred_basename
src/dataset/base_inpaint_dataset.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-04-30
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+ import glob
23
+ import io
24
+ import json
25
+ import os
26
+ import pdb
27
+ import random
28
+ import tarfile
29
+ from enum import Enum
30
+ from typing import Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from torch.utils.data import Dataset
36
+ from torchvision.transforms import InterpolationMode, Resize, CenterCrop
37
+ import torchvision.transforms as transforms
38
+ from transformers import CLIPTextModel, CLIPTokenizer
39
+ from src.util.depth_transform import DepthNormalizerBase
40
+ import random
41
+
42
+ from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode
43
+ from pycocotools import mask as coco_mask
44
+ from scipy.ndimage import gaussian_filter
45
+
46
+ def read_image_from_tar(tar_obj, img_rel_path):
47
+ image = tar_obj.extractfile("./" + img_rel_path)
48
+ image = image.read()
49
+ image = Image.open(io.BytesIO(image))
50
+
51
+
52
+ class BaseInpaintDataset(Dataset):
53
+ def __init__(
54
+ self,
55
+ mode: DatasetMode,
56
+ filename_ls_path: str,
57
+ dataset_dir: str,
58
+ disp_name: str,
59
+ depth_transform: Union[DepthNormalizerBase, None] = None,
60
+ tokenizer: CLIPTokenizer = None,
61
+ augmentation_args: dict = None,
62
+ resize_to_hw=None,
63
+ move_invalid_to_far_plane: bool = True,
64
+ rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1],
65
+ **kwargs,
66
+ ) -> None:
67
+ super().__init__()
68
+ self.mode = mode
69
+ # dataset info
70
+ self.filename_ls_path = filename_ls_path
71
+ self.disp_name = disp_name
72
+ # training arguments
73
+ self.depth_transform: DepthNormalizerBase = depth_transform
74
+ self.augm_args = augmentation_args
75
+ self.resize_to_hw = resize_to_hw
76
+ self.rgb_transform = rgb_transform
77
+ self.move_invalid_to_far_plane = move_invalid_to_far_plane
78
+ self.tokenizer = tokenizer
79
+ # Load filenames
80
+ self.filenames = []
81
+ filename_paths = glob.glob(self.filename_ls_path)
82
+ for path in filename_paths:
83
+ with open(path, "r") as f:
84
+ self.filenames += json.load(f)
85
+ # Tar dataset
86
+ self.tar_obj = None
87
+ self.is_tar = (
88
+ True
89
+ if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
90
+ else False
91
+ )
92
+
93
+ def __len__(self):
94
+ return len(self.filenames)
95
+
96
+ def __getitem__(self, index):
97
+ rasters, other = self._get_data_item(index)
98
+ if DatasetMode.TRAIN == self.mode:
99
+ rasters = self._training_preprocess(rasters)
100
+ # merge
101
+ outputs = rasters
102
+ outputs.update(other)
103
+ return outputs
104
+
105
+ def _get_data_item(self, index):
106
+ rgb_path = self.filenames[index]['rgb_path']
107
+ mask_path = None
108
+ if 'valid_mask' in self.filenames[index]:
109
+ mask_path = self.filenames[index]['valid_mask']
110
+ if self.filenames[index]['caption'] is not None:
111
+ coca_caption = self.filenames[index]['caption']['coca_caption']
112
+ spatial_caption = self.filenames[index]['caption']['spatial_caption']
113
+ empty_caption = ''
114
+ caption_choices = [coca_caption, spatial_caption, empty_caption]
115
+ probabilities = [0.4, 0.4, 0.2]
116
+ caption = random.choices(caption_choices, probabilities)[0]
117
+ else:
118
+ caption = ''
119
+
120
+ rasters = {}
121
+ # RGB data
122
+ rasters.update(self._load_rgb_data(rgb_path))
123
+
124
+ try:
125
+ anno = json.load(open(rgb_path.replace('.jpg', '.json')))['annotations']
126
+ random.shuffle(anno)
127
+ object_num = random.randint(5, 10)
128
+ mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
129
+ for single_anno in (anno[0:object_num] if len(anno)>object_num else anno):
130
+ mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
131
+ except:
132
+ mask = None
133
+
134
+ a = random.random()
135
+ if a < 0.1 or mask is None:
136
+ mask = np.zeros(rasters['rgb_int'].shape[-2:])
137
+ rows, cols = mask.shape
138
+ grid_size = random.randint(5, 14)
139
+ grid_rows, grid_cols = rows // grid_size, cols // grid_size
140
+ for i in range(grid_rows):
141
+ for j in range(grid_cols):
142
+ random_prob = np.random.rand()
143
+ if random_prob < 0.2:
144
+ row_start = i * grid_size
145
+ row_end = (i + 1) * grid_size
146
+ col_start = j * grid_size
147
+ col_end = (j + 1) * grid_size
148
+ mask[row_start:row_end, col_start:col_end] = 1
149
+
150
+ rasters['mask'] = torch.from_numpy(mask).unsqueeze(0).to(torch.float32)
151
+
152
+ if self.resize_to_hw is not None:
153
+ resize_transform = transforms.Compose([
154
+ Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
155
+ CenterCrop(size=self.resize_to_hw)])
156
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
157
+
158
+ # mask = torch.zeros(rasters['rgb_int'].shape[-2:])
159
+ # rows, cols = mask.shape
160
+ # grid_size = random.randint(3, 10)
161
+ # grid_rows, grid_cols = rows // grid_size, cols // grid_size
162
+ # for i in range(grid_rows):
163
+ # for j in range(grid_cols):
164
+ # random_prob = np.random.rand()
165
+ # if random_prob < 0.5:
166
+ # row_start = i * grid_size
167
+ # row_end = (i + 1) * grid_size
168
+ # col_start = j * grid_size
169
+ # col_end = (j + 1) * grid_size
170
+ # mask[row_start:row_end, col_start:col_end] = 1
171
+
172
+ # rasters['mask'] = mask.unsqueeze(0)
173
+
174
+ other = {"index": index, "rgb_path": rgb_path, 'text': caption}
175
+ return rasters, other
176
+
177
+ def _load_rgb_data(self, rgb_path):
178
+ # Read RGB data
179
+ rgb = self._read_rgb_file(rgb_path)
180
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
181
+
182
+ outputs = {
183
+ "rgb_int": torch.from_numpy(rgb).int(),
184
+ "rgb_norm": torch.from_numpy(rgb_norm).float(),
185
+ }
186
+ return outputs
187
+
188
+ def _get_data_path(self, index):
189
+ filename_line = self.filenames[index]
190
+
191
+ # Get data path
192
+ rgb_rel_path = filename_line[0]
193
+
194
+ depth_rel_path, text_rel_path = None, None
195
+ if DatasetMode.RGB_ONLY != self.mode:
196
+ depth_rel_path = filename_line[1]
197
+ if len(filename_line) > 2:
198
+ text_rel_path = filename_line[2]
199
+ return rgb_rel_path, depth_rel_path, text_rel_path
200
+
201
+ def _read_image(self, img_path) -> np.ndarray:
202
+ image_to_read = img_path
203
+ image = Image.open(image_to_read) # [H, W, rgb]
204
+ image = np.asarray(image)
205
+ return image
206
+
207
+ def _read_rgb_file(self, path) -> np.ndarray:
208
+ rgb = self._read_image(path)
209
+ rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W]
210
+ return rgb
211
+
212
+ def _read_depth_file(self, path):
213
+ depth_in = self._read_image(path)
214
+ # Replace code below to decode depth according to dataset definition
215
+ depth_decoded = depth_in
216
+ return depth_decoded
217
+
218
+ def _training_preprocess(self, rasters):
219
+ # Augmentation
220
+ if self.augm_args is not None:
221
+ rasters = self._augment_data(rasters)
222
+
223
+ # Normalization
224
+ # rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0
225
+ # rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0
226
+
227
+ rasters["depth_raw_norm"] = self.depth_transform(
228
+ rasters["depth_raw_linear"], rasters["valid_mask_raw"]
229
+ ).clone()
230
+ rasters["depth_filled_norm"] = self.depth_transform(
231
+ rasters["depth_filled_linear"], rasters["valid_mask_filled"]
232
+ ).clone()
233
+
234
+ # Set invalid pixel to far plane
235
+ if self.move_invalid_to_far_plane:
236
+ if self.depth_transform.far_plane_at_max:
237
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
238
+ self.depth_transform.norm_max
239
+ )
240
+ else:
241
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
242
+ self.depth_transform.norm_min
243
+ )
244
+
245
+ # Resize
246
+ if self.resize_to_hw is not None:
247
+ resize_transform = transforms.Compose([
248
+ Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
249
+ CenterCrop(size=self.resize_to_hw)])
250
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
251
+ return rasters
252
+
253
+ def _augment_data(self, rasters_dict):
254
+ # lr flipping
255
+ lr_flip_p = self.augm_args.lr_flip_p
256
+ if random.random() < lr_flip_p:
257
+ rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}
258
+
259
+ return rasters_dict
260
+
261
+ def __del__(self):
262
+ if hasattr(self, "tar_obj") and self.tar_obj is not None:
263
+ self.tar_obj.close()
264
+ self.tar_obj = None
265
+
266
+ def get_pred_name(rgb_basename, name_mode, suffix=".png"):
267
+ if DepthFileNameMode.rgb_id == name_mode:
268
+ pred_basename = "pred_" + rgb_basename.split("_")[1]
269
+ elif DepthFileNameMode.i_d_rgb == name_mode:
270
+ pred_basename = rgb_basename.replace("_rgb.", "_pred.")
271
+ elif DepthFileNameMode.id == name_mode:
272
+ pred_basename = "pred_" + rgb_basename
273
+ elif DepthFileNameMode.rgb_i_d == name_mode:
274
+ pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:])
275
+ else:
276
+ raise NotImplementedError
277
+ # change suffix
278
+ pred_basename = os.path.splitext(pred_basename)[0] + suffix
279
+
280
+ return pred_basename
src/dataset/depthanything_dataset.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
24
+ import torch
25
+ from torchvision.transforms import InterpolationMode, Resize, CenterCrop
26
+ import torchvision.transforms as transforms
27
+
28
+ class DepthAnythingDataset(BaseDepthDataset):
29
+ def __init__(
30
+ self,
31
+ **kwargs,
32
+ ) -> None:
33
+ super().__init__(
34
+ # ScanNet data parameter
35
+ min_depth=-1,
36
+ max_depth=256,
37
+ has_filled_depth=False,
38
+ name_mode=DepthFileNameMode.id,
39
+ **kwargs,
40
+ )
41
+
42
+ def _read_depth_file(self, rel_path):
43
+ depth_in = self._read_image(rel_path)
44
+ # Decode ScanNet depth
45
+ # depth_decoded = depth_in / 1000.0
46
+ return depth_in
47
+
48
+ def _training_preprocess(self, rasters):
49
+ # Augmentation
50
+ if self.augm_args is not None:
51
+ rasters = self._augment_data(rasters)
52
+
53
+ # Normalization
54
+ rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0
55
+ rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0
56
+
57
+ # Set invalid pixel to far plane
58
+ if self.move_invalid_to_far_plane:
59
+ if self.depth_transform.far_plane_at_max:
60
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
61
+ self.depth_transform.norm_max
62
+ )
63
+ else:
64
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
65
+ self.depth_transform.norm_min
66
+ )
67
+
68
+ # Resize
69
+ if self.resize_to_hw is not None:
70
+ T = transforms.Compose([
71
+ Resize(self.resize_to_hw[0]),
72
+ CenterCrop(self.resize_to_hw),
73
+ ])
74
+ rasters = {k: T(v) for k, v in rasters.items()}
75
+ return rasters
76
+
77
+ # def _load_depth_data(self, depth_rel_path, filled_rel_path):
78
+ # # Read depth data
79
+ # outputs = {}
80
+ # depth_raw = self._read_depth_file(depth_rel_path).squeeze()
81
+ # depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) # [1, H, W] [0, 255]
82
+ # outputs["depth_raw_linear"] = depth_raw_linear.clone()
83
+ #
84
+ # if self.has_filled_depth:
85
+ # depth_filled = self._read_depth_file(filled_rel_path).squeeze()
86
+ # depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
87
+ # outputs["depth_filled_linear"] = depth_filled_linear
88
+ # else:
89
+ # outputs["depth_filled_linear"] = depth_raw_linear.clone()
90
+ #
91
+ # return outputs
src/dataset/diode_dataset.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-02-26
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import os
24
+ import tarfile
25
+ from io import BytesIO
26
+
27
+ import numpy as np
28
+ import torch
29
+
30
+ from .eval_base_dataset import EvaluateBaseDataset, DepthFileNameMode, DatasetMode
31
+
32
+
33
+ class DIODEDataset(EvaluateBaseDataset):
34
+ def __init__(
35
+ self,
36
+ **kwargs,
37
+ ) -> None:
38
+ super().__init__(
39
+ # DIODE data parameter
40
+ min_depth=0.6,
41
+ max_depth=350,
42
+ has_filled_depth=False,
43
+ name_mode=DepthFileNameMode.id,
44
+ **kwargs,
45
+ )
46
+
47
+ def _read_npy_file(self, rel_path):
48
+ if self.is_tar:
49
+ if self.tar_obj is None:
50
+ self.tar_obj = tarfile.open(self.dataset_dir)
51
+ fileobj = self.tar_obj.extractfile("./" + rel_path)
52
+ npy_path_or_content = BytesIO(fileobj.read())
53
+ else:
54
+ npy_path_or_content = os.path.join(self.dataset_dir, rel_path)
55
+ data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :]
56
+ return data
57
+
58
+ def _read_depth_file(self, rel_path):
59
+ depth = self._read_npy_file(rel_path)
60
+ return depth
61
+
62
+ def _get_data_path(self, index):
63
+ return self.filenames[index]
64
+
65
+ def _get_data_item(self, index):
66
+ # Special: depth mask is read from data
67
+
68
+ rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index)
69
+
70
+ rasters = {}
71
+
72
+ # RGB data
73
+ rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))
74
+
75
+ # Depth data
76
+ if DatasetMode.RGB_ONLY != self.mode:
77
+ # load data
78
+ depth_data = self._load_depth_data(
79
+ depth_rel_path=depth_rel_path, filled_rel_path=None
80
+ )
81
+ rasters.update(depth_data)
82
+
83
+ # valid mask
84
+ mask = self._read_npy_file(mask_rel_path).astype(bool)
85
+ mask = torch.from_numpy(mask).bool()
86
+ rasters["valid_mask_raw"] = mask.clone()
87
+ rasters["valid_mask_filled"] = mask.clone()
88
+
89
+ other = {"index": index, "rgb_relative_path": rgb_rel_path}
90
+
91
+ return rasters, other
src/dataset/eth3d_dataset.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import torch
24
+ import tarfile
25
+ import os
26
+ import numpy as np
27
+
28
+ from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset
29
+
30
+
31
+ class ETH3DDataset(EvaluateBaseDataset):
32
+ HEIGHT, WIDTH = 4032, 6048
33
+
34
+ def __init__(
35
+ self,
36
+ **kwargs,
37
+ ) -> None:
38
+ super().__init__(
39
+ # ETH3D data parameter
40
+ min_depth=1e-5,
41
+ max_depth=torch.inf,
42
+ has_filled_depth=False,
43
+ name_mode=DepthFileNameMode.id,
44
+ **kwargs,
45
+ )
46
+
47
+ def _read_depth_file(self, rel_path):
48
+ # Read special binary data: https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats
49
+ if self.is_tar:
50
+ if self.tar_obj is None:
51
+ self.tar_obj = tarfile.open(self.dataset_dir)
52
+ binary_data = self.tar_obj.extractfile("./" + rel_path)
53
+ binary_data = binary_data.read()
54
+
55
+ else:
56
+ depth_path = os.path.join(self.dataset_dir, rel_path)
57
+ with open(depth_path, "rb") as file:
58
+ binary_data = file.read()
59
+ # Convert the binary data to a numpy array of 32-bit floats
60
+ depth_decoded = np.frombuffer(binary_data, dtype=np.float32).copy()
61
+
62
+ depth_decoded[depth_decoded == torch.inf] = 0.0
63
+
64
+ depth_decoded = depth_decoded.reshape((self.HEIGHT, self.WIDTH))
65
+ return depth_decoded
src/dataset/eval_base_dataset.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-04-30
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import io
24
+ import os
25
+ import random
26
+ import tarfile
27
+ from enum import Enum
28
+ from typing import Union
29
+
30
+ import numpy as np
31
+ import torch
32
+ from PIL import Image
33
+ from torch.utils.data import Dataset
34
+ from torchvision.transforms import InterpolationMode, Resize
35
+
36
+ from src.util.depth_transform import DepthNormalizerBase
37
+
38
+
39
+ class DatasetMode(Enum):
40
+ RGB_ONLY = "rgb_only"
41
+ EVAL = "evaluate"
42
+ TRAIN = "train"
43
+
44
+
45
+ class DepthFileNameMode(Enum):
46
+ """Prediction file naming modes"""
47
+
48
+ id = 1 # id.png
49
+ rgb_id = 2 # rgb_id.png
50
+ i_d_rgb = 3 # i_d_1_rgb.png
51
+ rgb_i_d = 4
52
+
53
+
54
+ def read_image_from_tar(tar_obj, img_rel_path):
55
+ image = tar_obj.extractfile("./" + img_rel_path)
56
+ image = image.read()
57
+ image = Image.open(io.BytesIO(image))
58
+
59
+
60
+ class EvaluateBaseDataset(Dataset):
61
+ def __init__(
62
+ self,
63
+ mode: DatasetMode,
64
+ filename_ls_path: str,
65
+ dataset_dir: str,
66
+ disp_name: str,
67
+ min_depth: float,
68
+ max_depth: float,
69
+ has_filled_depth: bool,
70
+ name_mode: DepthFileNameMode,
71
+ depth_transform: Union[DepthNormalizerBase, None] = None,
72
+ augmentation_args: dict = None,
73
+ resize_to_hw=None,
74
+ move_invalid_to_far_plane: bool = True,
75
+ rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1],
76
+ **kwargs,
77
+ ) -> None:
78
+ super().__init__()
79
+ self.mode = mode
80
+ # dataset info
81
+ self.filename_ls_path = filename_ls_path
82
+ self.dataset_dir = dataset_dir
83
+ assert os.path.exists(
84
+ self.dataset_dir
85
+ ), f"Dataset does not exist at: {self.dataset_dir}"
86
+ self.disp_name = disp_name
87
+ self.has_filled_depth = has_filled_depth
88
+ self.name_mode: DepthFileNameMode = name_mode
89
+ self.min_depth = min_depth
90
+ self.max_depth = max_depth
91
+
92
+ # training arguments
93
+ self.depth_transform: DepthNormalizerBase = depth_transform
94
+ self.augm_args = augmentation_args
95
+ self.resize_to_hw = resize_to_hw
96
+ self.rgb_transform = rgb_transform
97
+ self.move_invalid_to_far_plane = move_invalid_to_far_plane
98
+
99
+ # Load filenames
100
+ with open(self.filename_ls_path, "r") as f:
101
+ self.filenames = [
102
+ s.split() for s in f.readlines()
103
+ ] # [['rgb.png', 'depth.tif'], [], ...]
104
+
105
+ # Tar dataset
106
+ self.tar_obj = None
107
+ self.is_tar = (
108
+ True
109
+ if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
110
+ else False
111
+ )
112
+
113
+ def __len__(self):
114
+ return len(self.filenames)
115
+
116
+ def __getitem__(self, index):
117
+ rasters, other = self._get_data_item(index)
118
+ if DatasetMode.TRAIN == self.mode:
119
+ rasters = self._training_preprocess(rasters)
120
+ # merge
121
+ outputs = rasters
122
+ outputs.update(other)
123
+ return outputs
124
+
125
+ def _get_data_item(self, index):
126
+ rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index)
127
+
128
+ rasters = {}
129
+
130
+ # RGB data
131
+ rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))
132
+
133
+ # Depth data
134
+ if DatasetMode.RGB_ONLY != self.mode:
135
+ # load data
136
+ depth_data = self._load_depth_data(
137
+ depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path
138
+ )
139
+ rasters.update(depth_data)
140
+ # valid mask
141
+ rasters["valid_mask_raw"] = self._get_valid_mask(
142
+ rasters["depth_raw_linear"]
143
+ ).clone()
144
+ rasters["valid_mask_filled"] = self._get_valid_mask(
145
+ rasters["depth_filled_linear"]
146
+ ).clone()
147
+
148
+ other = {"index": index, "rgb_relative_path": rgb_rel_path}
149
+
150
+ return rasters, other
151
+
152
+ def _load_rgb_data(self, rgb_rel_path):
153
+ # Read RGB data
154
+ rgb = self._read_rgb_file(rgb_rel_path)
155
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
156
+
157
+ outputs = {
158
+ "rgb_int": torch.from_numpy(rgb).int(),
159
+ "rgb_norm": torch.from_numpy(rgb_norm).float(),
160
+ }
161
+ return outputs
162
+
163
+ def _load_depth_data(self, depth_rel_path, filled_rel_path):
164
+ # Read depth data
165
+ outputs = {}
166
+ depth_raw = self._read_depth_file(depth_rel_path).squeeze()
167
+ depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) # [1, H, W]
168
+ outputs["depth_raw_linear"] = depth_raw_linear.clone()
169
+
170
+ if self.has_filled_depth:
171
+ depth_filled = self._read_depth_file(filled_rel_path).squeeze()
172
+ depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
173
+ outputs["depth_filled_linear"] = depth_filled_linear
174
+ else:
175
+ outputs["depth_filled_linear"] = depth_raw_linear.clone()
176
+
177
+ return outputs
178
+
179
+ def _get_data_path(self, index):
180
+ filename_line = self.filenames[index]
181
+
182
+ # Get data path
183
+ rgb_rel_path = filename_line[0]
184
+
185
+ depth_rel_path, filled_rel_path = None, None
186
+ if DatasetMode.RGB_ONLY != self.mode:
187
+ depth_rel_path = filename_line[1]
188
+ if self.has_filled_depth:
189
+ filled_rel_path = filename_line[2]
190
+ return rgb_rel_path, depth_rel_path, filled_rel_path
191
+
192
+ def _read_image(self, img_rel_path) -> np.ndarray:
193
+ if self.is_tar:
194
+ if self.tar_obj is None:
195
+ self.tar_obj = tarfile.open(self.dataset_dir)
196
+ image_to_read = self.tar_obj.extractfile("./" + img_rel_path)
197
+ image_to_read = image_to_read.read()
198
+ image_to_read = io.BytesIO(image_to_read)
199
+ else:
200
+ image_to_read = os.path.join(self.dataset_dir, img_rel_path)
201
+ image = Image.open(image_to_read) # [H, W, rgb]
202
+ image = np.asarray(image)
203
+ return image
204
+
205
+ def _read_rgb_file(self, rel_path) -> np.ndarray:
206
+ rgb = self._read_image(rel_path)
207
+ rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W]
208
+ return rgb
209
+
210
+ def _read_depth_file(self, rel_path):
211
+ depth_in = self._read_image(rel_path)
212
+ # Replace code below to decode depth according to dataset definition
213
+ depth_decoded = depth_in
214
+
215
+ return depth_decoded
216
+
217
+ def _get_valid_mask(self, depth: torch.Tensor):
218
+ valid_mask = torch.logical_and(
219
+ (depth > self.min_depth), (depth < self.max_depth)
220
+ ).bool()
221
+ return valid_mask
222
+
223
+ def _training_preprocess(self, rasters):
224
+ # Augmentation
225
+ if self.augm_args is not None:
226
+ rasters = self._augment_data(rasters)
227
+
228
+ # Normalization
229
+ rasters["depth_raw_norm"] = self.depth_transform(
230
+ rasters["depth_raw_linear"], rasters["valid_mask_raw"]
231
+ ).clone()
232
+ rasters["depth_filled_norm"] = self.depth_transform(
233
+ rasters["depth_filled_linear"], rasters["valid_mask_filled"]
234
+ ).clone()
235
+
236
+ # Set invalid pixel to far plane
237
+ if self.move_invalid_to_far_plane:
238
+ if self.depth_transform.far_plane_at_max:
239
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
240
+ self.depth_transform.norm_max
241
+ )
242
+ else:
243
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
244
+ self.depth_transform.norm_min
245
+ )
246
+
247
+ # Resize
248
+ if self.resize_to_hw is not None:
249
+ resize_transform = Resize(
250
+ size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT
251
+ )
252
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
253
+
254
+ return rasters
255
+
256
+ def _augment_data(self, rasters_dict):
257
+ # lr flipping
258
+ lr_flip_p = self.augm_args.lr_flip_p
259
+ if random.random() < lr_flip_p:
260
+ rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}
261
+
262
+ return rasters_dict
263
+
264
+ def __del__(self):
265
+ if hasattr(self, "tar_obj") and self.tar_obj is not None:
266
+ self.tar_obj.close()
267
+ self.tar_obj = None
268
+
269
+ def get_pred_name(rgb_basename, name_mode, suffix=".png"):
270
+ if DepthFileNameMode.rgb_id == name_mode:
271
+ pred_basename = "pred_" + rgb_basename.split("_")[1]
272
+ elif DepthFileNameMode.i_d_rgb == name_mode:
273
+ pred_basename = rgb_basename.replace("_rgb.", "_pred.")
274
+ elif DepthFileNameMode.id == name_mode:
275
+ pred_basename = "pred_" + rgb_basename
276
+ elif DepthFileNameMode.rgb_i_d == name_mode:
277
+ pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:])
278
+ else:
279
+ raise NotImplementedError
280
+ # change suffix
281
+ pred_basename = os.path.splitext(pred_basename)[0] + suffix
282
+
283
+ return pred_basename
src/dataset/hypersim_dataset.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+
24
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
25
+
26
+ class HypersimDataset(BaseDepthDataset):
27
+ def __init__(
28
+ self,
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(
32
+ # Hypersim data parameter
33
+ min_depth=1e-5,
34
+ max_depth=65.0,
35
+ has_filled_depth=False,
36
+ name_mode=DepthFileNameMode.rgb_i_d,
37
+ **kwargs,
38
+ )
39
+
40
+ def _read_depth_file(self, rel_path):
41
+ depth_in = self._read_image(rel_path)
42
+ # Decode Hypersim depth
43
+ depth_decoded = depth_in / 1000.0
44
+ return depth_decoded
src/dataset/inpaint_dataset.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2024-04-30
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+ import glob
23
+ import io
24
+ import json
25
+ import os
26
+ import pdb
27
+ import random
28
+ import tarfile
29
+ from enum import Enum
30
+ from typing import Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from torch.utils.data import Dataset
36
+ from torchvision.transforms import InterpolationMode, Resize, CenterCrop
37
+ import torchvision.transforms as transforms
38
+ from transformers import CLIPTextModel, CLIPTokenizer
39
+ from src.util.depth_transform import DepthNormalizerBase
40
+ import random
41
+
42
+ from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode
43
+
44
+
45
+ def read_image_from_tar(tar_obj, img_rel_path):
46
+ image = tar_obj.extractfile("./" + img_rel_path)
47
+ image = image.read()
48
+ image = Image.open(io.BytesIO(image))
49
+
50
+
51
+ class BaseDepthDataset(Dataset):
52
+ def __init__(
53
+ self,
54
+ mode: DatasetMode,
55
+ filename_ls_path: str,
56
+ dataset_dir: str,
57
+ disp_name: str,
58
+ min_depth: float,
59
+ max_depth: float,
60
+ has_filled_depth: bool,
61
+ name_mode: DepthFileNameMode,
62
+ depth_transform: Union[DepthNormalizerBase, None] = None,
63
+ tokenizer: CLIPTokenizer = None,
64
+ augmentation_args: dict = None,
65
+ resize_to_hw=None,
66
+ move_invalid_to_far_plane: bool = True,
67
+ rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1],
68
+ **kwargs,
69
+ ) -> None:
70
+ super().__init__()
71
+ self.mode = mode
72
+ # dataset info
73
+ self.filename_ls_path = filename_ls_path
74
+ self.disp_name = disp_name
75
+ self.has_filled_depth = has_filled_depth
76
+ self.name_mode: DepthFileNameMode = name_mode
77
+ self.min_depth = min_depth
78
+ self.max_depth = max_depth
79
+ # training arguments
80
+ self.depth_transform: DepthNormalizerBase = depth_transform
81
+ self.augm_args = augmentation_args
82
+ self.resize_to_hw = resize_to_hw
83
+ self.rgb_transform = rgb_transform
84
+ self.move_invalid_to_far_plane = move_invalid_to_far_plane
85
+ self.tokenizer = tokenizer
86
+ # Load filenames
87
+ self.filenames = []
88
+ filename_paths = glob.glob(self.filename_ls_path)
89
+ for path in filename_paths:
90
+ with open(path, "r") as f:
91
+ self.filenames += json.load(f)
92
+ # Tar dataset
93
+ self.tar_obj = None
94
+ self.is_tar = (
95
+ True
96
+ if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
97
+ else False
98
+ )
99
+
100
+ def __len__(self):
101
+ return len(self.filenames)
102
+
103
+ def __getitem__(self, index):
104
+ rasters, other = self._get_data_item(index)
105
+ if DatasetMode.TRAIN == self.mode:
106
+ rasters = self._training_preprocess(rasters)
107
+ # merge
108
+ outputs = rasters
109
+ outputs.update(other)
110
+ return outputs
111
+
112
+ def _get_data_item(self, index):
113
+ rgb_path = self.filenames[index]['rgb_path']
114
+ depth_path = self.filenames[index]['depth_path']
115
+ mask_path = None
116
+ if 'valid_mask' in self.filenames[index]:
117
+ mask_path = self.filenames[index]['valid_mask']
118
+ if self.filenames[index]['caption'] is not None:
119
+ coca_caption = self.filenames[index]['caption']['coca_caption']
120
+ spatial_caption = self.filenames[index]['caption']['spatial_caption']
121
+ empty_caption = ''
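+ # Randomly pick one caption variant per sample: CoCa caption (40%), spatial caption (40%), empty prompt (20%)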
122
+ caption_choices = [coca_caption, spatial_caption, empty_caption]
123
+ probabilities = [0.4, 0.4, 0.2]
124
+ caption = random.choices(caption_choices, probabilities)[0]
125
+ else:
126
+ caption = ''
127
+
128
+ rasters = {}
129
+ # RGB data
130
+ rasters.update(self._load_rgb_data(rgb_path))
131
+
132
+ # Depth data
133
+ if DatasetMode.RGB_ONLY != self.mode and depth_path is not None:
134
+ # load data
135
+ depth_data = self._load_depth_data(depth_path)
136
+ rasters.update(depth_data)
137
+ # valid mask
138
+ if mask_path is not None:
139
+ valid_mask_raw = Image.open(mask_path)
140
+ valid_mask_filled = Image.open(mask_path)
141
+ rasters["valid_mask_raw"] = torch.from_numpy(np.asarray(valid_mask_raw)).unsqueeze(0).bool()
142
+ rasters["valid_mask_filled"] = torch.from_numpy(np.asarray(valid_mask_filled)).unsqueeze(0).bool()
143
+ else:
144
+ rasters["valid_mask_raw"] = self._get_valid_mask(
145
+ rasters["depth_raw_linear"]
146
+ ).clone()
147
+ rasters["valid_mask_filled"] = self._get_valid_mask(
148
+ rasters["depth_filled_linear"]
149
+ ).clone()
150
+
151
+ other = {"index": index, "rgb_path": rgb_path, 'text': caption}
152
+
153
+ if self.resize_to_hw is not None:
154
+ resize_transform = transforms.Compose([
155
+ Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
156
+ CenterCrop(size=self.resize_to_hw)])
157
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
158
+
159
+ return rasters, other
160
+
161
+ def _load_rgb_data(self, rgb_path):
162
+ # Read RGB data
163
+ rgb = self._read_rgb_file(rgb_path)
164
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
165
+
166
+ outputs = {
167
+ "rgb_int": torch.from_numpy(rgb).int(),
168
+ "rgb_norm": torch.from_numpy(rgb_norm).float(),
169
+ }
170
+ return outputs
171
+
172
+ def _load_depth_data(self, depth_path, filled_rel_path=None):
173
+ # Read depth data
174
+ outputs = {}
175
+ depth_raw = self._read_depth_file(depth_path).squeeze()
176
+ depth_raw_linear = torch.from_numpy(depth_raw.copy()).float().unsqueeze(0) # [1, H, W]
177
+ outputs["depth_raw_linear"] = depth_raw_linear.clone()
178
+
179
+ if self.has_filled_depth:
180
+ depth_filled = self._read_depth_file(filled_rel_path).squeeze()
181
+ depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
182
+ outputs["depth_filled_linear"] = depth_filled_linear
183
+ else:
184
+ outputs["depth_filled_linear"] = depth_raw_linear.clone()
185
+
186
+ return outputs
187
+
188
+ def _get_data_path(self, index):
189
+ filename_line = self.filenames[index]
190
+
191
+ # Get data path
192
+ rgb_rel_path = filename_line[0]
193
+
194
+ depth_rel_path, text_rel_path = None, None
195
+ if DatasetMode.RGB_ONLY != self.mode:
196
+ depth_rel_path = filename_line[1]
197
+ if len(filename_line) > 2:
198
+ text_rel_path = filename_line[2]
199
+ return rgb_rel_path, depth_rel_path, text_rel_path
200
+
201
+ def _read_image(self, img_path) -> np.ndarray:
202
+ image_to_read = img_path
203
+ image = Image.open(image_to_read) # [H, W, rgb]
204
+ image = np.asarray(image)
205
+ return image
206
+
207
+ def _read_rgb_file(self, path) -> np.ndarray:
208
+ rgb = self._read_image(path)
209
+ rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W]
210
+ return rgb
211
+
212
+ def _read_depth_file(self, path):
213
+ depth_in = self._read_image(path)
214
+ # Replace code below to decode depth according to dataset definition
215
+ depth_decoded = depth_in
216
+ return depth_decoded
217
+
218
+ def _get_valid_mask(self, depth: torch.Tensor):
219
+ valid_mask = torch.logical_and(
220
+ (depth > self.min_depth), (depth < self.max_depth)
221
+ ).bool()
222
+ return valid_mask
223
+
224
+ def _training_preprocess(self, rasters):
225
+ # Augmentation
226
+ if self.augm_args is not None:
227
+ rasters = self._augment_data(rasters)
228
+
229
+ # Normalization
230
+ # rasters["depth_raw_norm"] = rasters["depth_raw_linear"] / 255.0 * 2.0 - 1.0
231
+ # rasters["depth_filled_norm"] = rasters["depth_filled_linear"] / 255.0 * 2.0 - 1.0
232
+
233
+ rasters["depth_raw_norm"] = self.depth_transform(
234
+ rasters["depth_raw_linear"], rasters["valid_mask_raw"]
235
+ ).clone()
236
+ rasters["depth_filled_norm"] = self.depth_transform(
237
+ rasters["depth_filled_linear"], rasters["valid_mask_filled"]
238
+ ).clone()
239
+
240
+ # Set invalid pixel to far plane
241
+ if self.move_invalid_to_far_plane:
242
+ if self.depth_transform.far_plane_at_max:
243
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
244
+ self.depth_transform.norm_max
245
+ )
246
+ else:
247
+ rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
248
+ self.depth_transform.norm_min
249
+ )
250
+
251
+ # Resize
252
+ if self.resize_to_hw is not None:
253
+ resize_transform = transforms.Compose([
254
+ Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
255
+ CenterCrop(size=self.resize_to_hw)])
256
+ rasters = {k: resize_transform(v) for k, v in rasters.items()}
257
+ return rasters
258
+
259
+ def _augment_data(self, rasters_dict):
260
+ # lr flipping
261
+ lr_flip_p = self.augm_args.lr_flip_p
262
+ if random.random() < lr_flip_p:
263
+ rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}
264
+
265
+ return rasters_dict
266
+
267
+ def __del__(self):
268
+ if hasattr(self, "tar_obj") and self.tar_obj is not None:
269
+ self.tar_obj.close()
270
+ self.tar_obj = None
271
+
272
+ def get_pred_name(rgb_basename, name_mode, suffix=".png"):
273
+ if DepthFileNameMode.rgb_id == name_mode:
274
+ pred_basename = "pred_" + rgb_basename.split("_")[1]
275
+ elif DepthFileNameMode.i_d_rgb == name_mode:
276
+ pred_basename = rgb_basename.replace("_rgb.", "_pred.")
277
+ elif DepthFileNameMode.id == name_mode:
278
+ pred_basename = "pred_" + rgb_basename
279
+ elif DepthFileNameMode.rgb_i_d == name_mode:
280
+ pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:])
281
+ else:
282
+ raise NotImplementedError
283
+ # change suffix
284
+ pred_basename = os.path.splitext(pred_basename)[0] + suffix
285
+
286
+ return pred_basename
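+ # Hypothetical example: with DepthFileNameMode.rgb_i_d, "rgb_0001_left.png" -> "pred_0001_left.png"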
src/dataset/kitti_dataset.py ADDED
@@ -0,0 +1,124 @@
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import torch
24
+
25
+ from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset
26
+
27
+
28
+ class KITTIDataset(EvaluateBaseDataset):
29
+ def __init__(
30
+ self,
31
+ kitti_bm_crop, # Crop to KITTI benchmark size
32
+ valid_mask_crop, # Evaluation mask. [None, garg or eigen]
33
+ **kwargs,
34
+ ) -> None:
35
+ super().__init__(
36
+ # KITTI data parameter
37
+ min_depth=1e-5,
38
+ max_depth=80,
39
+ has_filled_depth=False,
40
+ name_mode=DepthFileNameMode.id,
41
+ **kwargs,
42
+ )
43
+ self.kitti_bm_crop = kitti_bm_crop
44
+ self.valid_mask_crop = valid_mask_crop
45
+ assert self.valid_mask_crop in [
46
+ None,
47
+ "garg", # set evaluation mask according to Garg ECCV16
48
+ "eigen", # set evaluation mask according to Eigen NIPS14
49
+ ], f"Unknown crop type: {self.valid_mask_crop}"
50
+
51
+ # Filter out empty depth
52
+ self.filenames = [f for f in self.filenames if "None" != f[1]]
53
+
54
+ def _read_depth_file(self, rel_path):
55
+ depth_in = self._read_image(rel_path)
56
+ # Decode KITTI depth
57
+ depth_decoded = depth_in / 256.0
58
+ return depth_decoded
59
+
60
+ def _load_rgb_data(self, rgb_rel_path):
61
+ rgb_data = super()._load_rgb_data(rgb_rel_path)
62
+ if self.kitti_bm_crop:
63
+ rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()}
64
+ return rgb_data
65
+
66
+ def _load_depth_data(self, depth_rel_path, filled_rel_path):
67
+ depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
68
+ if self.kitti_bm_crop:
69
+ depth_data = {
70
+ k: self.kitti_benchmark_crop(v) for k, v in depth_data.items()
71
+ }
72
+ return depth_data
73
+
74
+ @staticmethod
75
+ def kitti_benchmark_crop(input_img):
76
+ """
77
+ Crop images to KITTI benchmark size
78
+ Args:
79
+ `input_img` (torch.Tensor): Input image to be cropped.
80
+
81
+ Returns:
82
+ torch.Tensor: Cropped image.
83
+ """
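+ # For a typical 375x1242 KITTI frame this yields a 352x1216 crop anchored at the bottom and centered horizontally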
84
+ KB_CROP_HEIGHT = 352
85
+ KB_CROP_WIDTH = 1216
86
+
87
+ height, width = input_img.shape[-2:]
88
+ top_margin = int(height - KB_CROP_HEIGHT)
89
+ left_margin = int((width - KB_CROP_WIDTH) / 2)
90
+ if 2 == len(input_img.shape):
91
+ out = input_img[
92
+ top_margin : top_margin + KB_CROP_HEIGHT,
93
+ left_margin : left_margin + KB_CROP_WIDTH,
94
+ ]
95
+ elif 3 == len(input_img.shape):
96
+ out = input_img[
97
+ :,
98
+ top_margin : top_margin + KB_CROP_HEIGHT,
99
+ left_margin : left_margin + KB_CROP_WIDTH,
100
+ ]
101
+ return out
102
+
103
+ def _get_valid_mask(self, depth: torch.Tensor):
104
+ # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
105
+ valid_mask = super()._get_valid_mask(depth) # [1, H, W]
106
+
107
+ if self.valid_mask_crop is not None:
108
+ eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
109
+ gt_height, gt_width = eval_mask.shape
110
+
111
+ if "garg" == self.valid_mask_crop:
112
+ eval_mask[
113
+ int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
114
+ int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
115
+ ] = 1
116
+ elif "eigen" == self.valid_mask_crop:
117
+ eval_mask[
118
+ int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
119
+ int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
120
+ ] = 1
121
+
122
+ eval_mask = eval_mask.reshape(valid_mask.shape)
123
+ valid_mask = torch.logical_and(valid_mask, eval_mask)
124
+ return valid_mask
src/dataset/mixed_sampler.py ADDED
@@ -0,0 +1,149 @@
1
+ # Last modified: 2024-04-18
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import torch
24
+ from torch.utils.data import (
25
+ BatchSampler,
26
+ RandomSampler,
27
+ SequentialSampler,
28
+ )
29
+
30
+
31
+ class MixedBatchSampler(BatchSampler):
32
+ """Sample one batch from a selected dataset with given probability.
33
+ Compatible with datasets at different resolution
34
+ """
35
+
36
+ def __init__(
37
+ self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
38
+ ):
39
+ self.base_sampler = None
40
+ self.batch_size = batch_size
41
+ self.shuffle = shuffle
42
+ self.drop_last = drop_last
43
+ self.generator = generator
44
+
45
+ self.src_dataset_ls = src_dataset_ls
46
+ self.n_dataset = len(self.src_dataset_ls)
47
+
48
+ # Dataset length
49
+ self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
50
+ self.cum_dataset_length = [
51
+ sum(self.dataset_length[:i]) for i in range(self.n_dataset)
52
+ ] # cumulative dataset length
53
+
54
+ # BatchSamplers for each source dataset
55
+ if self.shuffle:
56
+ self.src_batch_samplers = [
57
+ BatchSampler(
58
+ sampler=RandomSampler(
59
+ ds, replacement=False, generator=self.generator
60
+ ),
61
+ batch_size=self.batch_size,
62
+ drop_last=self.drop_last,
63
+ )
64
+ for ds in self.src_dataset_ls
65
+ ]
66
+ else:
67
+ self.src_batch_samplers = [
68
+ BatchSampler(
69
+ sampler=SequentialSampler(ds),
70
+ batch_size=self.batch_size,
71
+ drop_last=self.drop_last,
72
+ )
73
+ for ds in self.src_dataset_ls
74
+ ]
75
+ self.raw_batches = [
76
+ list(bs) for bs in self.src_batch_samplers
77
+ ] # index in original dataset
78
+ self.n_batches = [len(b) for b in self.raw_batches]
79
+ self.n_total_batch = sum(self.n_batches)
80
+
81
+ # sampling probability
82
+ if prob is None:
83
+ # if not given, decide by dataset length
84
+ self.prob = torch.tensor(self.n_batches) / self.n_total_batch
85
+ else:
86
+ self.prob = torch.as_tensor(prob)
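+ # prob is used as multinomial weights below, so it does not need to sum to 1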
87
+
88
+ def __iter__(self):
89
+ """Yield index batches, selecting the source dataset for each batch according to self.prob.
90
+
91
+ Yields:
92
+ list(int): a batch of indices, corresponding to the ConcatDataset of src_dataset_ls
93
+ """
94
+ for _ in range(self.n_total_batch):
95
+ idx_ds = torch.multinomial(
96
+ self.prob, 1, replacement=True, generator=self.generator
97
+ ).item()
98
+ # if batch list is empty, generate new list
99
+ if 0 == len(self.raw_batches[idx_ds]):
100
+ self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
101
+ # get a batch from list
102
+ batch_raw = self.raw_batches[idx_ds].pop()
103
+ # shift by cumulative dataset length
104
+ shift = self.cum_dataset_length[idx_ds]
105
+ batch = [n + shift for n in batch_raw]
106
+
107
+ yield batch
108
+
109
+ def __len__(self):
110
+ return self.n_total_batch
111
+
112
+
113
+ # Unit test
114
+ if "__main__" == __name__:
115
+ from torch.utils.data import ConcatDataset, DataLoader, Dataset
116
+
117
+ class SimpleDataset(Dataset):
118
+ def __init__(self, start, len) -> None:
119
+ super().__init__()
120
+ self.start = start
121
+ self.len = len
122
+
123
+ def __len__(self):
124
+ return self.len
125
+
126
+ def __getitem__(self, index):
127
+ return self.start + index
128
+
129
+ dataset_1 = SimpleDataset(0, 10)
130
+ dataset_2 = SimpleDataset(200, 20)
131
+ dataset_3 = SimpleDataset(1000, 50)
132
+
133
+ concat_dataset = ConcatDataset(
134
+ [dataset_1, dataset_2, dataset_3]
135
+ ) # will directly concatenate
136
+
137
+ mixed_sampler = MixedBatchSampler(
138
+ src_dataset_ls=[dataset_1, dataset_2, dataset_3],
139
+ batch_size=4,
140
+ drop_last=True,
141
+ shuffle=False,
142
+ prob=[0.6, 0.3, 0.1],
143
+ generator=torch.Generator().manual_seed(0),
144
+ )
145
+
146
+ loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler)
147
+
148
+ for d in loader:
149
+ print(d)
src/dataset/nyu_dataset.py ADDED
@@ -0,0 +1,61 @@
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import torch
24
+
25
+ from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset
26
+
27
+
28
+ class NYUDataset(EvaluateBaseDataset):
29
+ def __init__(
30
+ self,
31
+ eigen_valid_mask: bool,
32
+ **kwargs,
33
+ ) -> None:
34
+ super().__init__(
35
+ # NYUv2 dataset parameter
36
+ min_depth=1e-3,
37
+ max_depth=10.0,
38
+ has_filled_depth=True,
39
+ name_mode=DepthFileNameMode.rgb_id,
40
+ **kwargs,
41
+ )
42
+
43
+ self.eigen_valid_mask = eigen_valid_mask
44
+
45
+ def _read_depth_file(self, rel_path):
46
+ depth_in = self._read_image(rel_path)
47
+ # Decode NYU depth
48
+ depth_decoded = depth_in / 1000.0
49
+ return depth_decoded
50
+
51
+ def _get_valid_mask(self, depth: torch.Tensor):
52
+ valid_mask = super()._get_valid_mask(depth)
53
+
54
+ # Eigen crop for evaluation
55
+ if self.eigen_valid_mask:
56
+ eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
57
+ eval_mask[45:471, 41:601] = 1
58
+ eval_mask = eval_mask.reshape(valid_mask.shape)
59
+ valid_mask = torch.logical_and(valid_mask, eval_mask)
60
+
61
+ return valid_mask
src/dataset/scannet_dataset.py ADDED
@@ -0,0 +1,44 @@
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ from .eval_base_dataset import DepthFileNameMode, EvaluateBaseDataset
24
+
25
+
26
+ class ScanNetDataset(EvaluateBaseDataset):
27
+ def __init__(
28
+ self,
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(
32
+ # ScanNet data parameter
33
+ min_depth=1e-3,
34
+ max_depth=10,
35
+ has_filled_depth=False,
36
+ name_mode=DepthFileNameMode.id,
37
+ **kwargs,
38
+ )
39
+
40
+ def _read_depth_file(self, rel_path):
41
+ depth_in = self._read_image(rel_path)
42
+ # Decode ScanNet depth
43
+ depth_decoded = depth_in / 1000.0
44
+ return depth_decoded
src/dataset/vkitti_dataset.py ADDED
@@ -0,0 +1,97 @@
1
+ # Last modified: 2024-02-08
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import torch
24
+
25
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
26
+ from .kitti_dataset import KITTIDataset
27
+
28
+ class VirtualKITTIDataset(BaseDepthDataset):
29
+ def __init__(
30
+ self,
31
+ kitti_bm_crop, # Crop to KITTI benchmark size
32
+ valid_mask_crop, # Evaluation mask. [None, garg or eigen]
33
+ **kwargs,
34
+ ) -> None:
35
+ super().__init__(
36
+ # virtual KITTI data parameter
37
+ min_depth=1e-5,
38
+ max_depth=80, # 655.35
39
+ has_filled_depth=False,
40
+ name_mode=DepthFileNameMode.id,
41
+ **kwargs,
42
+ )
43
+ self.kitti_bm_crop = kitti_bm_crop
44
+ self.valid_mask_crop = valid_mask_crop
45
+ assert self.valid_mask_crop in [
46
+ None,
47
+ "garg", # set evaluation mask according to Garg ECCV16
48
+ "eigen", # set evaluation mask according to Eigen NIPS14
49
+ ], f"Unknown crop type: {self.valid_mask_crop}"
50
+
51
+ # Filter out empty depth
52
+ self.filenames = self.filenames
53
+
54
+ def _read_depth_file(self, rel_path):
55
+ depth_in = self._read_image(rel_path)
56
+ # Decode vKITTI depth
57
+ depth_decoded = depth_in / 100.0
58
+ return depth_decoded
59
+
60
+ def _load_rgb_data(self, rgb_rel_path):
61
+ rgb_data = super()._load_rgb_data(rgb_rel_path)
62
+ if self.kitti_bm_crop:
63
+ rgb_data = {
64
+ k: KITTIDataset.kitti_benchmark_crop(v) for k, v in rgb_data.items()
65
+ }
66
+ return rgb_data
67
+
68
+ def _load_depth_data(self, depth_rel_path, filled_rel_path=None):
69
+ depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
70
+ if self.kitti_bm_crop:
71
+ depth_data = {
72
+ k: KITTIDataset.kitti_benchmark_crop(v) for k, v in depth_data.items()
73
+ }
74
+ return depth_data
75
+
76
+ def _get_valid_mask(self, depth: torch.Tensor):
77
+ # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
78
+ valid_mask = super()._get_valid_mask(depth) # [1, H, W]
79
+
80
+ if self.valid_mask_crop is not None:
81
+ eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
82
+ gt_height, gt_width = eval_mask.shape
83
+
84
+ if "garg" == self.valid_mask_crop:
85
+ eval_mask[
86
+ int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
87
+ int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
88
+ ] = 1
89
+ elif "eigen" == self.valid_mask_crop:
90
+ eval_mask[
91
+ int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
92
+ int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
93
+ ] = 1
94
+
95
+ eval_mask = eval_mask.reshape(valid_mask.shape)
96
+ valid_mask = torch.logical_and(valid_mask, eval_mask)
97
+ return valid_mask
src/trainer/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-05-17
3
+
4
+ from .marigold_trainer import MarigoldTrainer
5
+ from .marigold_xl_trainer import MarigoldXLTrainer
6
+ from .marigold_inpaint_trainer import MarigoldInpaintTrainer
7
+
8
+ trainer_cls_name_dict = {
9
+ "MarigoldTrainer": MarigoldTrainer,
10
+ "MarigoldXLTrainer": MarigoldXLTrainer,
11
+ "MarigoldInpaintTrainer": MarigoldInpaintTrainer
12
+ }
13
+
14
+
15
+ def get_trainer_cls(trainer_name):
16
+ return trainer_cls_name_dict[trainer_name]
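+ # Hypothetical usage sketch (config and dataloader plumbing are not shown in this diff):
+ #   trainer_cls = get_trainer_cls("MarigoldInpaintTrainer")
+ #   trainer = trainer_cls(cfg=cfg, model=pipeline, train_dataloader=train_loader, device=device, ...)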
src/trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (519 Bytes). View file
 
src/trainer/__pycache__/marigold_inpaint_trainer.cpython-310.pyc ADDED
Binary file (17.2 kB). View file
 
src/trainer/__pycache__/marigold_trainer.cpython-310.pyc ADDED
Binary file (22.5 kB). View file
 
src/trainer/__pycache__/marigold_xl_trainer.cpython-310.pyc ADDED
Binary file (22.4 kB). View file
 
src/trainer/marigold_inpaint_trainer.py ADDED
@@ -0,0 +1,665 @@
1
+ # An official reimplementation of the Marigold training script.
2
+ # Last modified: 2024-04-29
3
+ #
4
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # --------------------------------------------------------------------------
18
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
19
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
20
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
21
+ # More information about the method can be found at https://marigoldmonodepth.github.io
22
+ # --------------------------------------------------------------------------
23
+ from diffusers import StableDiffusionInpaintPipeline
24
+ import logging
25
+ import os
26
+ import pdb
27
+ import cv2
28
+ import shutil
29
+ import json
30
+ from pycocotools import mask as coco_mask
31
+ from datetime import datetime
32
+ from typing import List, Union
33
+ import random
34
+ import safetensors
35
+ import numpy as np
36
+ import torch
37
+ from diffusers import DDPMScheduler
38
+ from omegaconf import OmegaConf
39
+ from torch.nn import Conv2d
40
+ from torch.nn.parameter import Parameter
41
+ from torch.optim import Adam
42
+ from torch.optim.lr_scheduler import LambdaLR
43
+ from torch.utils.data import DataLoader, Dataset
44
+ from tqdm import tqdm
45
+ from PIL import Image
46
+ # import torch.optim.lr_scheduler
47
+
48
+ from diffusers.schedulers import PNDMScheduler
49
+ from torchvision.transforms.functional import pil_to_tensor
50
+ from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput
51
+ from src.util import metric
52
+ from src.util.data_loader import skip_first_batches
53
+ from src.util.logging_util import tb_logger, eval_dic_to_text
54
+ from src.util.loss import get_loss
55
+ from src.util.lr_scheduler import IterExponential
56
+ from src.util.metric import MetricTracker
57
+ from src.util.multi_res_noise import multi_res_noise_like
58
+ from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth
59
+ from src.util.seeding import generate_seed_sequence
60
+ from accelerate import Accelerator
61
+ import os
62
+ from torchvision.transforms import InterpolationMode, Resize, CenterCrop
63
+ import torchvision.transforms as transforms
64
+ # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
65
+
66
+ class MarigoldInpaintTrainer:
67
+ def __init__(
68
+ self,
69
+ cfg: OmegaConf,
70
+ model: MarigoldPipeline,
71
+ train_dataloader: DataLoader,
72
+ device,
73
+ base_ckpt_dir,
74
+ out_dir_ckpt,
75
+ out_dir_eval,
76
+ out_dir_vis,
77
+ accumulation_steps: int,
78
+ depth_model = None,
79
+ separate_list: List = None,
80
+ val_dataloaders: List[DataLoader] = None,
81
+ vis_dataloaders: List[DataLoader] = None,
82
+ train_dataset: Dataset = None,
83
+ timestep_method: str = 'unidiffuser',
84
+ connection: bool = False
85
+ ):
86
+ self.cfg: OmegaConf = cfg
87
+ self.model: MarigoldPipeline = model
88
+ self.depth_model = depth_model
89
+ self.device = device
90
+ self.seed: Union[int, None] = (
91
+ self.cfg.trainer.init_seed
92
+ ) # used to generate seed sequence, set to `None` to train w/o seeding
93
+ self.out_dir_ckpt = out_dir_ckpt
94
+ self.out_dir_eval = out_dir_eval
95
+ self.out_dir_vis = out_dir_vis
96
+ self.train_loader: DataLoader = train_dataloader
97
+ self.val_loaders: List[DataLoader] = val_dataloaders
98
+ self.vis_loaders: List[DataLoader] = vis_dataloaders
99
+ self.accumulation_steps: int = accumulation_steps
100
+ self.separate_list = separate_list
101
+ self.timestep_method = timestep_method
102
+ self.train_dataset = train_dataset
103
+ self.connection = connection
104
+ # Adapt input layers
105
+ # if 8 != self.model.unet.config["in_channels"]:
106
+ # self._replace_unet_conv_in()
107
+ # if 8 != self.model.unet.config["out_channels"]:
108
+ # self._replace_unet_conv_out()
109
+
110
+ self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss'])
111
+ # self.generator = torch.Generator('cuda:0').manual_seed(1024)
112
+
113
+ # Encode empty text prompt
114
+ self.model.encode_empty_text()
115
+ self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
116
+
117
+ self.model.unet.enable_xformers_memory_efficient_attention()
118
+
119
+ # Trainability
120
+ self.model.text_encoder.requires_grad_(False)
121
+ # self.model.unet.requires_grad_(True)
122
+
123
+ grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters())
124
+
125
+ # Optimizer !should be defined after input layer is adapted
126
+ lr = self.cfg.lr
127
+ self.optimizer = Adam(grad_part, lr=lr)
128
+
129
+ total_params = sum(p.numel() for p in self.model.unet.parameters())
130
+ total_params_m = total_params / 1_000_000
131
+ print(f"Total parameters: {total_params_m:.2f}M")
132
+ trainable_params = sum(p.numel() for p in self.model.unet.parameters() if p.requires_grad)
133
+ trainable_params_m = trainable_params / 1_000_000
134
+ print(f"Trainable parameters: {trainable_params_m:.2f}M")
135
+
136
+ # LR scheduler
137
+ lr_func = IterExponential(
138
+ total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
139
+ final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
140
+ warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
141
+ )
142
+ self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
143
+
144
+ # Loss
145
+ self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
146
+
147
+ # Training noise scheduler
148
+ # self.rgb_training_noise_scheduler: PNDMScheduler = PNDMScheduler.from_pretrained(
149
+ # os.path.join(
150
+ # cfg.trainer.rgb_training_noise_scheduler.pretrained_path,
151
+ # "scheduler",
152
+ # )
153
+ # )
154
+
155
+ self.rgb_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
156
+ cfg.trainer.depth_training_noise_scheduler.pretrained_path, subfolder="scheduler")
157
+ self.depth_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
158
+ cfg.trainer.depth_training_noise_scheduler.pretrained_path, subfolder="scheduler")
159
+
160
+ self.rgb_prediction_type = self.rgb_training_noise_scheduler.config.prediction_type
161
+ # assert (
162
+ # self.rgb_prediction_type == self.model.rgb_scheduler.config.prediction_type
163
+ # ), "Different prediction types"
164
+ self.depth_prediction_type = self.depth_training_noise_scheduler.config.prediction_type
165
+ assert (
166
+ self.depth_prediction_type == self.model.depth_scheduler.config.prediction_type
167
+ ), "Different prediction types"
168
+ self.scheduler_timesteps = (
169
+ self.rgb_training_noise_scheduler.config.num_train_timesteps
170
+ )
171
+
172
+ # Settings
173
+ self.max_epoch = self.cfg.max_epoch
174
+ self.max_iter = self.cfg.max_iter
175
+ self.gradient_accumulation_steps = accumulation_steps
176
+ self.gt_depth_type = self.cfg.gt_depth_type
177
+ self.gt_mask_type = self.cfg.gt_mask_type
178
+ self.save_period = self.cfg.trainer.save_period
179
+ self.backup_period = self.cfg.trainer.backup_period
180
+ self.val_period = self.cfg.trainer.validation_period
181
+ self.vis_period = self.cfg.trainer.visualization_period
182
+
183
+ # Multi-resolution noise
184
+ self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
185
+ if self.apply_multi_res_noise:
186
+ self.mr_noise_strength = self.cfg.multi_res_noise.strength
187
+ self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
188
+ self.mr_noise_downscale_strategy = (
189
+ self.cfg.multi_res_noise.downscale_strategy
190
+ )
191
+
192
+ # Internal variables
193
+ self.epoch = 0
194
+ self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training
195
+ self.effective_iter = 0 # how many times optimizer.step() is called
196
+ self.in_evaluation = False
197
+ self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
198
+
199
+ def _replace_unet_conv_in(self):
200
+ # replace the first layer to accept 8 in_channels
201
+ _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3]
202
+ _bias = self.model.unet.conv_in.bias.clone() # [320]
203
+ zero_weight = torch.zeros(_weight.shape).to(_weight.device)
204
+ _weight = torch.cat([_weight, zero_weight], dim=1)
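+ # the 4 extra input channels are zero-initialized, so at initialization the layer reproduces the pretrained 4-channel behaviour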
205
+ # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
206
+ # half the activation magnitude
207
+ # _weight *= 0.5
208
+ # new conv_in channel
209
+ _n_convin_out_channel = self.model.unet.conv_in.out_channels
210
+ _new_conv_in = Conv2d(
211
+ 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
212
+ )
213
+ _new_conv_in.weight = Parameter(_weight)
214
+ _new_conv_in.bias = Parameter(_bias)
215
+ self.model.unet.conv_in = _new_conv_in
216
+ logging.info("Unet conv_in layer is replaced")
217
+ # replace config
218
+ self.model.unet.config["in_channels"] = 8
219
+ logging.info("Unet config is updated")
220
+ return
221
+
222
+ def parallel_train(self, t_end=None, accelerator=None):
223
+ logging.info("Start training")
224
+ self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
225
+ self.model, self.optimizer, self.train_loader, self.lr_scheduler
226
+ )
227
+ self.depth_model = accelerator.prepare(self.depth_model)
228
+
229
+ self.accelerator = accelerator
230
+ if os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')):
231
+ accelerator.load_state(os.path.join(self.out_dir_ckpt, 'latest'))
232
+ self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest'))
233
+
234
+ # if accelerator.is_main_process:
235
+ # self._inpaint_rgbd()
236
+
237
+ self.train_metrics.reset()
238
+ accumulated_step = 0
239
+ for epoch in range(self.epoch, self.max_epoch + 1):
240
+ self.epoch = epoch
241
+ logging.debug(f"epoch: {self.epoch}")
242
+
243
+ # Skip previous batches when resume
244
+ for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
245
+ self.model.unet.train()
246
+
247
+ # globally consistent random generators
248
+ if self.seed is not None:
249
+ local_seed = self._get_next_seed()
250
+ rand_num_generator = torch.Generator(device=self.model.device)
251
+ rand_num_generator.manual_seed(local_seed)
252
+ else:
253
+ rand_num_generator = None
254
+
255
+ # >>> With gradient accumulation >>>
256
+
257
+ # Get data
258
+ rgb = batch["rgb_norm"].to(self.model.device)
259
+ with torch.no_grad():
260
+ disparities = self.depth_model(batch["rgb_int"].numpy().astype(np.uint8), 518, device=self.model.device)
261
+
262
+ if len(disparities.shape) == 2:
263
+ disparities = disparities.unsqueeze(0)
264
+
265
+ depth_gt_for_latent = []
266
+ for disparity_map in disparities:
267
+ depth_map = ((disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min())) * 2 - 1
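+ # per-image min-max normalization of the predicted disparity to [-1, 1]; this normalized map is used as the depth target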
268
+ depth_gt_for_latent.append(depth_map)
269
+ depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0)
270
+
271
+ batch_size = rgb.shape[0]
272
+
273
+ mask = self.model.mask_processor.preprocess(batch['mask'] * 255).to(self.model.device)
274
+
275
+ rgb_timesteps = torch.randint(
276
+ 0,
277
+ self.scheduler_timesteps,
278
+ (batch_size,),
279
+ device=self.model.device,
280
+ generator=rand_num_generator,
281
+ ).long() # [B]
282
+ depth_timesteps = rgb_timesteps
283
+
284
+ rgb_flag = 1
285
+ depth_flag = 1
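+ # mask schedule: 'joint' always inpaints rgb and depth together; 'partition' mixes joint (50%), depth-only (25%) and rgb-only (25%) prediction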
286
+
287
+ if self.timestep_method == 'joint':
288
+ rgb_mask = mask
289
+ depth_mask = mask
290
+
291
+ elif self.timestep_method == 'partition':
292
+ rand_num = random.random()
293
+ if rand_num < 0.5: # joint prediction
294
+ rgb_mask = mask
295
+ depth_mask = mask
296
+ elif rand_num < 0.75: # full rgb; depth prediction
297
+ rgb_flag = 0
298
+ rgb_mask = torch.zeros_like(mask)
299
+ depth_mask = mask
300
+ else:
301
+ depth_flag = 0
302
+ rgb_mask = mask
303
+ if random.random() < 0.5:
304
+ depth_mask = torch.zeros_like(mask) # full depth; rgb prediction
305
+ else:
306
+ depth_mask = mask # partial depth; rgb prediction
307
+
308
+ masked_rgb = rgb * (rgb_mask < 0.5)
309
+ masked_depth = depth_gt_for_latent * (depth_mask.squeeze() < 0.5)
310
+ with torch.no_grad():
311
+ # Encode image
312
+ rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w]
313
+ mask_rgb_latent = self.model.encode_rgb(masked_rgb)
314
+
315
+ if depth_timesteps.sum() == 0:
316
+ gt_depth_latent = self.encode_depth(masked_depth)
317
+ else:
318
+ gt_depth_latent = self.encode_depth(depth_gt_for_latent)
319
+ mask_depth_latent = self.encode_depth(masked_depth)
320
+
321
+ rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
322
+ depth_mask = torch.nn.functional.interpolate(depth_mask, size=gt_depth_latent.shape[-2:])
323
+
324
+ # Sample noise
325
+ rgb_noise = torch.randn(
326
+ rgb_latent.shape,
327
+ device=self.model.device,
328
+ generator=rand_num_generator,
329
+ ) # [B, 4, h, w]
330
+ depth_noise = torch.randn(
331
+ gt_depth_latent.shape,
332
+ device=self.model.device,
333
+ generator=rand_num_generator,
334
+ ) # [B, 4, h, w]
335
+
336
+ if rgb_timesteps.sum() == 0:
337
+ noisy_rgb_latents = rgb_latent
338
+ else:
339
+ noisy_rgb_latents = self.rgb_training_noise_scheduler.add_noise(
340
+ rgb_latent, rgb_noise, rgb_timesteps
341
+ ) # [B, 4, h, w]
342
+ if depth_timesteps.sum() == 0:
343
+ noisy_depth_latents = gt_depth_latent
344
+ else:
345
+ noisy_depth_latents = self.depth_training_noise_scheduler.add_noise(
346
+ gt_depth_latent, depth_noise, depth_timesteps
347
+ ) # [B, 4, h, w]
348
+
349
+ noisy_latents = torch.cat(
350
+ [noisy_rgb_latents, rgb_mask, mask_rgb_latent, mask_depth_latent, noisy_depth_latents, depth_mask, mask_rgb_latent, mask_depth_latent], dim=1
351
+ ).float() # [B, 13*2, h, w]: per branch, noisy latent (4) + resized mask (1) + masked-RGB latent (4) + masked-depth latent (4)
352
+
353
+ # Text embedding
354
+ input_ids = self.model.tokenizer(
355
+ batch['text'],
356
+ padding="max_length",
357
+ max_length=self.model.tokenizer.model_max_length,
358
+ truncation=True,
359
+ return_tensors="pt",
360
+ )
361
+ input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()}
362
+ text_embed = self.model.text_encoder(**input_ids)[0]
363
+
364
+ model_pred = self.model.unet(
365
+ noisy_latents, rgb_timesteps, depth_timesteps, text_embed, controlnet_connection=self.connection
366
+ ).sample # [B, 8, h, w]
367
+
368
+ if torch.isnan(model_pred).any():
369
+ logging.warning("model_pred contains NaN.")
370
+
371
+ # Get the target for loss depending on the prediction type
372
+ if "sample" == self.rgb_prediction_type:
373
+ rgb_target = rgb_latent
374
+ elif "epsilon" == self.rgb_prediction_type:
375
+ rgb_target = rgb_latent
376
+ elif "v_prediction" == self.rgb_prediction_type:
377
+ rgb_target = self.rgb_training_noise_scheduler.get_velocity(
378
+ rgb_latent, rgb_noise, rgb_timesteps
379
+ ) # [B, 4, h, w]
380
+ else:
381
+ raise ValueError(f"Unknown rgb prediction type {self.rgb_prediction_type}")
382
+
383
+ if "sample" == self.depth_prediction_type:
384
+ depth_target = gt_depth_latent
385
+ elif "epsilon" == self.depth_prediction_type:
386
+ depth_target = gt_depth_latent
387
+ elif "v_prediction" == self.depth_prediction_type:
388
+ depth_target = self.depth_training_noise_scheduler.get_velocity(
389
+ gt_depth_latent, depth_noise, depth_timesteps
390
+ ) # [B, 4, h, w]
391
+ else:
392
+ raise ValueError(f"Unknown depth prediction type {self.depth_prediction_type}")
393
+ # Masked latent loss
394
+ with accelerator.accumulate(self.model):
395
+
396
+ rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
397
+ depth_loss = self.loss(model_pred[:, 4:, :, :].float(), depth_target.float())
398
+
399
+ if rgb_flag == 0:
400
+ loss = depth_loss
401
+ elif depth_flag == 0:
402
+ loss = rgb_loss
403
+ else:
404
+ loss = self.cfg.loss.depth_factor * depth_loss + (1 - self.cfg.loss.depth_factor) * rgb_loss
405
+
406
+ self.train_metrics.update("loss", loss.item())
407
+ self.train_metrics.update("rgb_loss", rgb_loss.item())
408
+ self.train_metrics.update("depth_loss", depth_loss.item())
409
+ # loss = loss / self.gradient_accumulation_steps
410
+ accelerator.backward(loss)
411
+ self.optimizer.step()
412
+ self.optimizer.zero_grad()
413
+ # loss.backward()
414
+ self.n_batch_in_epoch += 1
415
+ # print(accelerator.process_index, self.lr_scheduler.get_last_lr())
416
+ self.lr_scheduler.step(self.effective_iter)
417
+
418
+ if accelerator.sync_gradients:
419
+ accumulated_step += 1
420
+
421
+ if accumulated_step >= self.gradient_accumulation_steps:
422
+ accumulated_step = 0
423
+ self.effective_iter += 1
424
+
425
+ if accelerator.is_main_process:
426
+ # Log to tensorboard
427
+ if self.effective_iter == 1:
428
+ self._inpaint_rgbd()
429
+
430
+ accumulated_loss = self.train_metrics.result()["loss"]
431
+ rgb_loss = self.train_metrics.result()["rgb_loss"]
432
+ depth_loss = self.train_metrics.result()["depth_loss"]
433
+ tb_logger.log_dic(
434
+ {
435
+ f"train/{k}": v
436
+ for k, v in self.train_metrics.result().items()
437
+ },
438
+ global_step=self.effective_iter,
439
+ )
440
+ tb_logger.writer.add_scalar(
441
+ "lr",
442
+ self.lr_scheduler.get_last_lr()[0],
443
+ global_step=self.effective_iter,
444
+ )
445
+ tb_logger.writer.add_scalar(
446
+ "n_batch_in_epoch",
447
+ self.n_batch_in_epoch,
448
+ global_step=self.effective_iter,
449
+ )
450
+ logging.info(
451
+ f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}"
452
+ )
453
+ accelerator.wait_for_everyone()
454
+
455
+ if self.save_period > 0 and 0 == self.effective_iter % self.save_period:
456
+ accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest'))
457
+ unwrapped_model = accelerator.unwrap_model(self.model)
458
+ if accelerator.is_main_process:
459
+ accelerator.save_model(unwrapped_model.unet,
460
+ os.path.join(self.out_dir_ckpt, 'latest'), safe_serialization=False)
461
+ self.save_miscs('latest')
462
+ self._inpaint_rgbd()
463
+ accelerator.wait_for_everyone()
464
+
465
+ if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
466
+ unwrapped_model = accelerator.unwrap_model(self.model)
467
+ if accelerator.is_main_process:
468
+ accelerator.save_model(unwrapped_model.unet,
469
+ os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()), safe_serialization=False)
470
+ accelerator.wait_for_everyone()
471
+
472
+ # End of training
473
+ if self.max_iter > 0 and self.effective_iter >= self.max_iter:
474
+ unwrapped_model = accelerator.unwrap_model(self.model)
475
+ if accelerator.is_main_process:
476
+ unwrapped_model.unet.save_pretrained(
477
+ os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
478
+ accelerator.wait_for_everyone()
479
+ return
480
+
481
+ torch.cuda.empty_cache()
482
+ # <<< Effective batch end <<<
483
+
484
+ # Epoch end
485
+ self.n_batch_in_epoch = 0
486
+
487
+ def _inpaint_rgbd(self):
488
+ image_path = ['/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
489
+ '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
490
+ '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg']
491
+ prompt = ['A white car is parked in front of the factory',
492
+ 'church with cemetery next to it',
493
+ 'A house with a red brick roof']
494
+
495
+ imgs = [pil_to_tensor(Image.open(p)) for p in image_path]
496
+ depth_imgs = [self.depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]
497
+
498
+ masks = []
499
+ for rgb_path in image_path:
500
+ anno = json.load(open(rgb_path.replace('.jpg', '.json')))['annotations']
501
+ random.shuffle(anno)
502
+ object_num = random.randint(5, 10)
503
+ mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
504
+ for single_anno in (anno[0:object_num] if len(anno)>object_num else anno):
505
+ mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
506
+ masks.append(torch.from_numpy(mask))
507
+
508
+ resize_transform = transforms.Compose([
509
+ Resize(size=512, interpolation=InterpolationMode.NEAREST_EXACT),
510
+ CenterCrop(size=[512, 512])])
511
+ imgs = [resize_transform(img) for img in imgs]
512
+ depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs]
513
+ masks = [resize_transform(mask.unsqueeze(0)) for mask in masks]
514
+ # pdb.set_trace()
515
+
516
+ for i in range(len(imgs)):
517
+ output_image = self.model._rgbd_inpaint(imgs[i], depth_imgs[i], masks[i], [prompt[i]], processing_res=512, mode='joint_inpaint')
518
+ tb_logger.writer.add_image(f'{prompt[i]}', pil_to_tensor(output_image), self.effective_iter)
519
+
520
+ def encode_depth(self, depth_in):
521
+ # stack depth into 3-channel
522
+ stacked = self.stack_depth_images(depth_in)
523
+ # encode using VAE encoder
524
+ depth_latent = self.model.encode_rgb(stacked)
525
+ return depth_latent
526
+
527
+ @staticmethod
528
+ def stack_depth_images(depth_in):
529
+ if 4 == len(depth_in.shape):
530
+ stacked = depth_in.repeat(1, 3, 1, 1)
531
+ elif 3 == len(depth_in.shape):
532
+ stacked = depth_in.unsqueeze(1)
533
+ stacked = stacked.repeat(1, 3, 1, 1)
534
+ elif 2 == len(depth_in.shape):
535
+ stacked = depth_in.unsqueeze(0).unsqueeze(0)
536
+ stacked = stacked.repeat(1, 3, 1, 1)
537
+ return stacked
538
+
539
+ def visualize(self):
540
+ for val_loader in self.vis_loaders:
541
+ vis_dataset_name = val_loader.dataset.disp_name
542
+ vis_out_dir = os.path.join(
543
+ self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
544
+ )
545
+ os.makedirs(vis_out_dir, exist_ok=True)
546
+ _ = self.validate_single_dataset(
547
+ data_loader=val_loader,
548
+ metric_tracker=self.val_metrics,
549
+ save_to_dir=vis_out_dir,
550
+ )
551
+
552
+ def _get_next_seed(self):
553
+ if 0 == len(self.global_seed_sequence):
554
+ self.global_seed_sequence = generate_seed_sequence(
555
+ initial_seed=self.seed,
556
+ length=self.max_iter * self.gradient_accumulation_steps,
557
+ )
558
+ logging.info(
559
+ f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
560
+ )
561
+ return self.global_seed_sequence.pop()
562
+
563
+ def save_miscs(self, ckpt_name):
564
+ ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
565
+ state = {
566
+ "config": self.cfg,
567
+ "effective_iter": self.effective_iter,
568
+ "epoch": self.epoch,
569
+ "n_batch_in_epoch": self.n_batch_in_epoch,
570
+ "global_seed_sequence": self.global_seed_sequence,
571
+ }
572
+ train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
573
+ torch.save(state, train_state_path)
574
+
575
+ logging.info(f"Misc state is saved to: {train_state_path}")
576
+
577
+ def load_miscs(self, ckpt_path):
578
+ checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
579
+ self.effective_iter = checkpoint["effective_iter"]
580
+ self.epoch = checkpoint["epoch"]
581
+ self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
582
+ self.global_seed_sequence = checkpoint["global_seed_sequence"]
583
+
584
+ logging.info(f"Misc state is loaded from {ckpt_path}")
585
+
586
+
587
+ def save_checkpoint(self, ckpt_name, save_train_state):
588
+ ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
589
+ logging.info(f"Saving checkpoint to: {ckpt_dir}")
590
+ # Backup previous checkpoint
591
+ temp_ckpt_dir = None
592
+ if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
593
+ temp_ckpt_dir = os.path.join(
594
+ os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
595
+ )
596
+ if os.path.exists(temp_ckpt_dir):
597
+ shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
598
+ os.rename(ckpt_dir, temp_ckpt_dir)
599
+ logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
600
+
601
+ # Save UNet
602
+ unet_path = os.path.join(ckpt_dir, "unet")
603
+ self.model.unet.save_pretrained(unet_path, safe_serialization=False)
604
+ logging.info(f"UNet is saved to: {unet_path}")
605
+
606
+ if save_train_state:
607
+ state = {
608
+ "config": self.cfg,
609
+ "effective_iter": self.effective_iter,
610
+ "epoch": self.epoch,
611
+ "n_batch_in_epoch": self.n_batch_in_epoch,
612
+ "best_metric": self.best_metric,
613
+ "in_evaluation": self.in_evaluation,
614
+ "global_seed_sequence": self.global_seed_sequence,
615
+ }
616
+ train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
617
+ torch.save(state, train_state_path)
618
+ # iteration indicator
619
+ f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w")
620
+ f.close()
621
+
622
+ logging.info(f"Trainer state is saved to: {train_state_path}")
623
+
624
+ # Remove temp ckpt
625
+ if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
626
+ shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
627
+ logging.debug("Old checkpoint backup is removed.")
628
+
629
+ def load_checkpoint(
630
+ self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
631
+ ):
632
+ logging.info(f"Loading checkpoint from: {ckpt_path}")
633
+ # Load UNet
634
+ _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
635
+ self.model.unet.load_state_dict(
636
+ torch.load(_model_path, map_location=self.device)
637
+ )
638
+ self.model.unet.to(self.device)
639
+ logging.info(f"UNet parameters are loaded from {_model_path}")
640
+
641
+ # Load training states
642
+ if load_trainer_state:
643
+ checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
644
+ self.effective_iter = checkpoint["effective_iter"]
645
+ self.epoch = checkpoint["epoch"]
646
+ self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
647
+ self.in_evaluation = checkpoint["in_evaluation"]
648
+ self.global_seed_sequence = checkpoint["global_seed_sequence"]
649
+
650
+ self.best_metric = checkpoint["best_metric"]
651
+
652
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
653
+ logging.info(f"optimizer state is loaded from {ckpt_path}")
654
+
655
+ if resume_lr_scheduler:
656
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
657
+ logging.info(f"LR scheduler state is loaded from {ckpt_path}")
658
+
659
+ logging.info(
660
+ f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
661
+ )
662
+ return
663
+
664
+ def _get_backup_ckpt_name(self):
665
+ return f"iter_{self.effective_iter:06d}"
src/trainer/marigold_trainer.py ADDED
@@ -0,0 +1,968 @@
1
+ # An official reimplementation of the Marigold training script.
2
+ # Last modified: 2024-04-29
3
+ #
4
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # --------------------------------------------------------------------------
18
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
19
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
20
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
21
+ # More information about the method can be found at https://marigoldmonodepth.github.io
22
+ # --------------------------------------------------------------------------
23
+
24
+
25
+ import logging
26
+ import os
27
+ import pdb
28
+ import shutil
29
+ from datetime import datetime
30
+ from typing import List, Union
31
+ import random
32
+ import safetensors
33
+ import numpy as np
34
+ import torch
35
+ from diffusers import DDPMScheduler
36
+ from omegaconf import OmegaConf
37
+ from torch.nn import Conv2d
38
+ from torch.nn.parameter import Parameter
39
+ from torch.optim import Adam
40
+ from torch.optim.lr_scheduler import LambdaLR
41
+ from torch.utils.data import DataLoader
42
+ from tqdm import tqdm
43
+ from PIL import Image
44
+ # import torch.optim.lr_scheduler
45
+
46
+ from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput
47
+ from src.util import metric
48
+ from src.util.data_loader import skip_first_batches
49
+ from src.util.logging_util import tb_logger, eval_dic_to_text
50
+ from src.util.loss import get_loss
51
+ from src.util.lr_scheduler import IterExponential
52
+ from src.util.metric import MetricTracker
53
+ from src.util.multi_res_noise import multi_res_noise_like
54
+ from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth
55
+ from src.util.seeding import generate_seed_sequence
56
+ from accelerate import Accelerator
57
+ import os
58
+ # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
59
+
60
+ class MarigoldTrainer:
61
+ def __init__(
62
+ self,
63
+ cfg: OmegaConf,
64
+ model: MarigoldPipeline,
65
+ train_dataloader: DataLoader,
66
+ device,
67
+ base_ckpt_dir,
68
+ out_dir_ckpt,
69
+ out_dir_eval,
70
+ out_dir_vis,
71
+ accumulation_steps: int,
72
+ depth_model = None,
73
+ separate_list: List = None,
74
+ val_dataloaders: List[DataLoader] = None,
75
+ vis_dataloaders: List[DataLoader] = None,
76
+ timestep_method: str = 'unidiffuser'
77
+ ):
78
+ self.cfg: OmegaConf = cfg
79
+ self.model: MarigoldPipeline = model
80
+ self.depth_model = depth_model
81
+ self.device = device
82
+ self.seed: Union[int, None] = (
83
+ self.cfg.trainer.init_seed
84
+ ) # used to generate seed sequence, set to `None` to train w/o seeding
85
+ self.out_dir_ckpt = out_dir_ckpt
86
+ self.out_dir_eval = out_dir_eval
87
+ self.out_dir_vis = out_dir_vis
88
+ self.train_loader: DataLoader = train_dataloader
89
+ self.val_loaders: List[DataLoader] = val_dataloaders
90
+ self.vis_loaders: List[DataLoader] = vis_dataloaders
91
+ self.accumulation_steps: int = accumulation_steps
92
+ self.separate_list = separate_list
93
+ self.timestep_method = timestep_method
94
+ # Adapt input layers
95
+ # if 8 != self.model.unet.config["in_channels"]:
96
+ # self._replace_unet_conv_in()
97
+ # if 8 != self.model.unet.config["out_channels"]:
98
+ # self._replace_unet_conv_out()
99
+
100
+ self.prompt = ['a view of a city skyline from a bridge',
101
+ 'a man and a woman sitting on a couch',
102
+ 'a black car parked in a parking lot next to the water',
103
+ 'Enchanted forest with glowing plants, fairies, and ancient castle.',
104
+ 'Futuristic city with skyscrapers, neon lights, and hovering vehicles.',
105
+ 'Fantasy mountain landscape with waterfalls, dragons, and mythical creatures.']
106
+ # self.generator = torch.Generator('cuda:0').manual_seed(1024)
107
+
108
+ # Encode empty text prompt
109
+ self.model.encode_empty_text()
110
+ self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
111
+
112
+ self.model.unet.enable_xformers_memory_efficient_attention()
113
+
114
+ # Trainability
115
+ self.model.text_encoder.requires_grad_(False)
116
+ # self.model.unet.requires_grad_(True)
117
+
118
+ grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters())
119
+
120
+ # Optimizer !should be defined after input layer is adapted
121
+ lr = self.cfg.lr
122
+ self.optimizer = Adam(grad_part, lr=lr)
123
+
124
+ total_params = sum(p.numel() for p in self.model.unet.parameters())
125
+ total_params_m = total_params / 1_000_000
126
+ print(f"Total parameters: {total_params_m:.2f}M")
127
+ trainable_params = sum(p.numel() for p in self.model.unet.parameters() if p.requires_grad)
128
+ trainable_params_m = trainable_params / 1_000_000
129
+ print(f"Trainable parameters: {trainable_params_m:.2f}M")
130
+
131
+ # LR scheduler
132
+ lr_func = IterExponential(
133
+ total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
134
+ final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
135
+ warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
136
+ )
137
+ self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
138
+
139
+ # Loss
140
+ self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
141
+
142
+ # Training noise scheduler
143
+ self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
144
+ os.path.join(
145
+ cfg.trainer.training_noise_scheduler.pretrained_path,
146
+ "scheduler",
147
+ )
148
+ )
149
+ # pdb.set_trace()
150
+ self.prediction_type = self.training_noise_scheduler.config.prediction_type
151
+ assert (
152
+ self.prediction_type == self.model.scheduler.config.prediction_type
153
+ ), "Different prediction types"
154
+ self.scheduler_timesteps = (
155
+ self.training_noise_scheduler.config.num_train_timesteps
156
+ )
157
+
158
+ # Eval metrics
159
+ self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
160
+ self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss'])
161
+ self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
162
+ # main metric for best checkpoint saving
163
+ self.main_val_metric = cfg.validation.main_val_metric
164
+ self.main_val_metric_goal = cfg.validation.main_val_metric_goal
165
+ assert (
166
+ self.main_val_metric in cfg.eval.eval_metrics
167
+ ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
168
+ self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
169
+
170
+ # Settings
171
+ self.max_epoch = self.cfg.max_epoch
172
+ self.max_iter = self.cfg.max_iter
173
+ self.gradient_accumulation_steps = accumulation_steps
174
+ self.gt_depth_type = self.cfg.gt_depth_type
175
+ self.gt_mask_type = self.cfg.gt_mask_type
176
+ self.save_period = self.cfg.trainer.save_period
177
+ self.backup_period = self.cfg.trainer.backup_period
178
+ self.val_period = self.cfg.trainer.validation_period
179
+ self.vis_period = self.cfg.trainer.visualization_period
180
+
181
+ # Multi-resolution noise
182
+ self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
183
+ if self.apply_multi_res_noise:
184
+ self.mr_noise_strength = self.cfg.multi_res_noise.strength
185
+ self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
186
+ self.mr_noise_downscale_strategy = (
187
+ self.cfg.multi_res_noise.downscale_strategy
188
+ )
189
+
190
+ # Internal variables
191
+ self.epoch = 0
192
+ self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training
193
+ self.effective_iter = 0 # how many times optimizer.step() is called
194
+ self.in_evaluation = False
195
+ self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
196
+
197
+ def _replace_unet_conv_in(self):
198
+ # replace the first layer to accept 8 in_channels
199
+ _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3]
200
+ _bias = self.model.unet.conv_in.bias.clone() # [320]
201
+ zero_weight = torch.zeros(_weight.shape).to(_weight.device)
202
+ _weight = torch.cat([_weight, zero_weight], dim=1)
203
+ # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
204
+ # half the activation magnitude
205
+ # _weight *= 0.5
206
+ # new conv_in channel
207
+ _n_convin_out_channel = self.model.unet.conv_in.out_channels
208
+ _new_conv_in = Conv2d(
209
+ 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
210
+ )
211
+ _new_conv_in.weight = Parameter(_weight)
212
+ _new_conv_in.bias = Parameter(_bias)
213
+ self.model.unet.conv_in = _new_conv_in
214
+ logging.info("Unet conv_in layer is replaced")
215
+ # replace config
216
+ self.model.unet.config["in_channels"] = 8
217
+ logging.info("Unet config is updated")
218
+ return
219
+
220
+ def _replace_unet_conv_out(self):
221
+ # replace the first layer to accept 8 in_channels
222
+ _weight = self.model.unet.conv_out.weight.clone() # [8, 320, 3, 3]
223
+ _bias = self.model.unet.conv_out.bias.clone() # [320]
224
+ _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s)
225
+ _bias = _bias.repeat((2))
226
+ # half the activation magnitude
227
+
228
+ # new conv_in channel
229
+ _n_convin_out_channel = self.model.unet.conv_out.out_channels
230
+ _new_conv_out = Conv2d(
231
+ _n_convin_out_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
232
+ )
233
+ _new_conv_out.weight = Parameter(_weight)
234
+ _new_conv_out.bias = Parameter(_bias)
235
+ self.model.unet.conv_out = _new_conv_out
236
+ logging.info("Unet conv_out layer is replaced")
237
+ # replace config
238
+ self.model.unet.config["out_channels"] = 8
239
+ logging.info("Unet config is updated")
240
+ return
241
+
242
+ def parallel_train(self, t_end=None, accelerator=None):
243
+ logging.info("Start training")
244
+ # pdb.set_trace()
245
+ self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
246
+ self.model, self.optimizer, self.train_loader, self.lr_scheduler
247
+ )
248
+ self.depth_model = accelerator.prepare(self.depth_model)
249
+
250
+ self.accelerator = accelerator
251
+ if self.val_loaders is not None:
252
+ for idx, loader in enumerate(self.val_loaders):
253
+ self.val_loaders[idx] = accelerator.prepare(loader)
254
+
255
+ if os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')):
256
+ accelerator.load_state(os.path.join(self.out_dir_ckpt, 'latest'))
257
+ self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest'))
258
+
259
+ self.train_metrics.reset()
260
+ accumulated_step = 0
261
+ for epoch in range(self.epoch, self.max_epoch + 1):
262
+ self.epoch = epoch
263
+ logging.debug(f"epoch: {self.epoch}")
264
+
265
+ # Skip previous batches when resume
266
+ for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
267
+ self.model.unet.train()
268
+
269
+ # globally consistent random generators
270
+ if self.seed is not None:
271
+ local_seed = self._get_next_seed()
272
+ rand_num_generator = torch.Generator(device=self.model.device)
273
+ rand_num_generator.manual_seed(local_seed)
274
+ else:
275
+ rand_num_generator = None
276
+
277
+ # >>> With gradient accumulation >>>
278
+
279
+ # Get data
280
+ rgb = batch["rgb_norm"].to(self.model.device)
281
+ if self.gt_depth_type not in batch:
282
+ with torch.no_grad():
283
+ disparities = self.depth_model(batch["rgb_int"].numpy().astype(np.uint8), 518, device=self.model.device)
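+ # the depth model returns per-image disparity maps; each map is min-max normalized to [-1, 1] below to match the VAE input range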
284
+ depth_gt_for_latent = []
285
+ for disparity_map in disparities:
286
+ depth_map = ((disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min())) * 2 - 1
287
+ depth_gt_for_latent.append(depth_map)
288
+ depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0)
289
+ else:
290
+ if "least_square_disparity" == self.cfg.eval.alignment:
291
+ # convert GT depth -> GT disparity
292
+ depth_raw_ts = batch["depth_raw_linear"].squeeze()
293
+ depth_raw = depth_raw_ts.cpu().numpy()
294
+ # pdb.set_trace()
295
+ disparities = depth2disparity(
296
+ depth=depth_raw
297
+ )
298
+ depth_gt_for_latent = []
299
+ for disparity_map in disparities:
300
+ depth_map = ((disparity_map - disparity_map.min()) / (
301
+ disparity_map.max() - disparity_map.min())) * 2 - 1
302
+ depth_gt_for_latent.append(torch.from_numpy(depth_map))
303
+ depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0).to(self.model.device)
304
+ else:
305
+ depth_gt_for_latent = batch[self.gt_depth_type].to(self.model.device)
306
+
307
+ batch_size = rgb.shape[0]
308
+
309
+ if self.gt_mask_type is not None:
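+ # downsample the per-pixel validity mask by 8 (max-pooling the invalid mask) to match the latent resolution, then repeat it across the 4 latent channels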
310
+ valid_mask_for_latent = batch[self.gt_mask_type].to(self.model.device)
311
+ invalid_mask = ~valid_mask_for_latent
312
+ valid_mask_down = ~torch.max_pool2d(
313
+ invalid_mask.float(), 8, 8
314
+ ).bool()
315
+ valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
316
+
317
+ with torch.no_grad():
318
+ # Encode image
319
+ rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w]
320
+ # Encode GT depth
321
+ gt_depth_latent = self.encode_depth(
322
+ depth_gt_for_latent
323
+ ) # [B, 4, h, w]
324
+ # Sample a random timestep for each image
325
+ if self.cfg.loss.depth_factor == 1:
326
+ rgb_timesteps = torch.zeros(
327
+ (batch_size),
328
+ device=self.model.device
329
+ ).long() # [B]
330
+ depth_timesteps = torch.randint(
331
+ 0,
332
+ self.scheduler_timesteps,
333
+ (batch_size,),
334
+ device=self.model.device,
335
+ generator=rand_num_generator,
336
+ ).long() # [B]
337
+ elif self.timestep_method == 'unidiffuser':
338
+ rgb_timesteps = torch.randint(
339
+ 0,
340
+ self.scheduler_timesteps,
341
+ (batch_size,),
342
+ device=self.model.device,
343
+ generator=rand_num_generator,
344
+ ).long() # [B]
345
+ depth_timesteps = torch.randint(
346
+ 0,
347
+ self.scheduler_timesteps,
348
+ (batch_size,),
349
+ device=self.model.device,
350
+ generator=rand_num_generator,
351
+ ).long() # [B]
352
+ elif self.timestep_method == 'joint':
353
+ rgb_timesteps = torch.randint(
354
+ 0,
355
+ self.scheduler_timesteps,
356
+ (batch_size,),
357
+ device=self.model.device,
358
+ generator=rand_num_generator,
359
+ ).long() # [B]
360
+ depth_timesteps = rgb_timesteps # [B]
361
+ elif self.timestep_method == 'partition':
362
+ rand_num = random.random()
363
+ if rand_num < 0.3333:
364
+ # joint generation
365
+ rgb_timesteps = torch.randint(
366
+ 0,
367
+ self.scheduler_timesteps,
368
+ (batch_size,),
369
+ device=self.model.device,
370
+ generator=rand_num_generator,
371
+ ).long() # [B]
372
+ depth_timesteps = rgb_timesteps
373
+ elif rand_num < 0.6666:
374
+ # image2depth generation
375
+ rgb_timesteps = torch.zeros(
376
+ (batch_size),
377
+ device=self.model.device
378
+ ).long() # [B]
379
+ depth_timesteps = torch.randint(
380
+ 0,
381
+ self.scheduler_timesteps,
382
+ (batch_size,),
383
+ device=self.model.device,
384
+ generator=rand_num_generator,
385
+ ).long() # [B]
386
+ else:
387
+ # depth2image generation
388
+ rgb_timesteps = torch.randint(
389
+ 0,
390
+ self.scheduler_timesteps,
391
+ (batch_size,),
392
+ device=self.model.device,
393
+ generator=rand_num_generator,
394
+ ).long() # [B]
395
+ depth_timesteps = torch.zeros(
396
+ (batch_size),
397
+ device=self.model.device
398
+ ).long() # [B]
399
+
400
+ # Sample noise
401
+ if self.apply_multi_res_noise:
402
+ rgb_strength = self.mr_noise_strength
403
+ if self.annealed_mr_noise:
404
+ # calculate strength depending on t
405
+ rgb_strength = rgb_strength * (rgb_timesteps / self.scheduler_timesteps)
406
+ rgb_noise = multi_res_noise_like(
407
+ rgb_latent,
408
+ strength=rgb_strength,
409
+ downscale_strategy=self.mr_noise_downscale_strategy,
410
+ generator=rand_num_generator,
411
+ device=self.model.device,
412
+ )
413
+
414
+ depth_strength = self.mr_noise_strength
415
+ if self.annealed_mr_noise:
416
+ # calculate strength depending on t
417
+ depth_strength = depth_strength * (depth_timesteps / self.scheduler_timesteps)
418
+ depth_noise = multi_res_noise_like(
419
+ gt_depth_latent,
420
+ strength=depth_strength,
421
+ downscale_strategy=self.mr_noise_downscale_strategy,
422
+ generator=rand_num_generator,
423
+ device=self.model.device,
424
+ )
425
+ else:
426
+ rgb_noise = torch.randn(
427
+ rgb_latent.shape,
428
+ device=self.model.device,
429
+ generator=rand_num_generator,
430
+ ) # [B, 4, h, w]
431
+
432
+ depth_noise = torch.randn(
433
+ gt_depth_latent.shape,
434
+ device=self.model.device,
435
+ generator=rand_num_generator,
436
+ ) # [B, 4, h, w]
437
+ # Add noise to the latents (diffusion forward process)
438
+
439
+ if rgb_timesteps.sum() == 0:  # all RGB timesteps are zero (image-to-depth case), keep the RGB latents clean
440
+ noisy_rgb_latents = rgb_latent
441
+ else:
442
+ noisy_rgb_latents = self.training_noise_scheduler.add_noise(
443
+ rgb_latent, rgb_noise, rgb_timesteps
444
+ ) # [B, 4, h, w]
445
+
446
+ noisy_depth_latents = self.training_noise_scheduler.add_noise(
447
+ gt_depth_latent, depth_noise, depth_timesteps
448
+ ) # [B, 4, h, w]
449
+
450
+ noisy_latents = torch.cat(
451
+ [noisy_rgb_latents, noisy_depth_latents], dim=1
452
+ ).float() # [B, 8, h, w]
453
+
454
+ # Text embedding
455
+ input_ids = self.model.tokenizer(
456
+ batch['text'],
457
+ padding="max_length",
458
+ max_length=self.model.tokenizer.model_max_length,
459
+ truncation=True,
460
+ return_tensors="pt",
461
+ )
462
+ input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()}
463
+ text_embed = self.model.text_encoder(**input_ids)[0]
464
+ # text_embed = self.empty_text_embed.to(device).repeat(
465
+ # (batch_size, 1, 1)
466
+ # ) # [B, 77, 1024]
467
+ model_pred = self.model.unet(
468
+ noisy_latents, rgb_timesteps, depth_timesteps, text_embed
469
+ ).sample # [B, 4, h, w]
470
+ if torch.isnan(model_pred).any():
471
+ logging.warning("model_pred contains NaN.")
472
+
473
+ # Get the target for loss depending on the prediction type
474
+ if "sample" == self.prediction_type:
475
+ rgb_target = rgb_latent
476
+ depth_target = gt_depth_latent
477
+ elif "epsilon" == self.prediction_type:
478
+ rgb_target = rgb_noise  # epsilon prediction targets the added noise
479
+ depth_target = depth_noise
480
+ elif "v_prediction" == self.prediction_type:
481
+ rgb_target = self.training_noise_scheduler.get_velocity(
482
+ rgb_latent, rgb_noise, rgb_timesteps
483
+ ) # [B, 4, h, w]
484
+ depth_target = self.training_noise_scheduler.get_velocity(
485
+ gt_depth_latent, depth_noise, depth_timesteps
486
+ ) # [B, 4, h, w]
487
+ else:
488
+ raise ValueError(f"Unknown prediction type {self.prediction_type}")
489
+ # Masked latent loss
490
+ with accelerator.accumulate(self.model):
491
+ if self.gt_mask_type is not None:
492
+ depth_loss = self.loss(
493
+ model_pred[:, 4:, :, :][valid_mask_down].float(),
494
+ depth_target[valid_mask_down].float(),
495
+ )
496
+ else:
497
+ depth_loss = self.loss(model_pred[:, 4:, :, :].float(),depth_target.float())
498
+
499
+ rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
500
+
501
+ if torch.sum(rgb_timesteps) == 0 or torch.sum(rgb_timesteps) == len(rgb_timesteps) * self.scheduler_timesteps:
502
+ loss = depth_loss
503
+ elif torch.sum(depth_timesteps) == 0 or torch.sum(depth_timesteps) == len(depth_timesteps) * self.scheduler_timesteps:
504
+ loss = rgb_loss
505
+ else:
506
+ loss = self.cfg.loss.depth_factor * depth_loss + (1 - self.cfg.loss.depth_factor) * rgb_loss
507
+
508
+ self.train_metrics.update("loss", loss.item())
509
+ self.train_metrics.update("rgb_loss", rgb_loss.item())
510
+ self.train_metrics.update("depth_loss", depth_loss.item())
511
+ # loss = loss / self.gradient_accumulation_steps
512
+ accelerator.backward(loss)
513
+ self.optimizer.step()
514
+ self.optimizer.zero_grad()
515
+ # loss.backward()
516
+ self.n_batch_in_epoch += 1
517
+ # print(accelerator.process_index, self.lr_scheduler.get_last_lr())
518
+ self.lr_scheduler.step(self.effective_iter)
519
+
520
+ if accelerator.sync_gradients:
521
+ accumulated_step += 1
522
+
523
+ if accumulated_step >= self.gradient_accumulation_steps:
524
+ accumulated_step = 0
525
+ self.effective_iter += 1
526
+
527
+ if accelerator.is_main_process:
528
+ # Log to tensorboard
529
+ if self.effective_iter == 1:
530
+ generator = torch.Generator(self.model.device).manual_seed(1024)
531
+ img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator,
532
+ show_pbar=True)
533
+ for idx in range(len(self.prompt)):
534
+ tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)
535
+ self._depth2image()
536
+ self._image2depth()
537
+
538
+ accumulated_loss = self.train_metrics.result()["loss"]
539
+ rgb_loss = self.train_metrics.result()["rgb_loss"]
540
+ depth_loss = self.train_metrics.result()["depth_loss"]
541
+ tb_logger.log_dic(
542
+ {
543
+ f"train/{k}": v
544
+ for k, v in self.train_metrics.result().items()
545
+ },
546
+ global_step=self.effective_iter,
547
+ )
548
+ tb_logger.writer.add_scalar(
549
+ "lr",
550
+ self.lr_scheduler.get_last_lr()[0],
551
+ global_step=self.effective_iter,
552
+ )
553
+ tb_logger.writer.add_scalar(
554
+ "n_batch_in_epoch",
555
+ self.n_batch_in_epoch,
556
+ global_step=self.effective_iter,
557
+ )
558
+ logging.info(
559
+ f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}"
560
+ )
561
+ accelerator.wait_for_everyone()
562
+
563
+ if self.save_period > 0 and 0 == self.effective_iter % self.save_period:
564
+ accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest'))
565
+ unwrapped_model = accelerator.unwrap_model(self.model)
566
+ if accelerator.is_main_process:
567
+ accelerator.save_model(unwrapped_model.unet,
568
+ os.path.join(self.out_dir_ckpt, 'latest'), safe_serialization=False)
569
+ self.save_miscs('latest')
570
+
571
+ # RGB-D joint generation
572
+ generator = torch.Generator(self.model.device).manual_seed(1024)
573
+ img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator, show_pbar=False, height=64, width=64)
574
+ for idx in range(len(self.prompt)):
575
+ tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)
576
+
577
+ # depth to RGB generation
578
+ self._depth2image()
579
+ # # RGB to depth generation
580
+ self._image2depth()
581
+
582
+ accelerator.wait_for_everyone()
583
+
584
+ if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
585
+ unwrapped_model = accelerator.unwrap_model(self.model)
586
+ if accelerator.is_main_process:
587
+ unwrapped_model.unet.save_pretrained(
588
+ os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
589
+ accelerator.wait_for_everyone()
590
+
591
+ if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
592
+ self.validate()
593
+
594
+ # End of training
595
+ if self.max_iter > 0 and self.effective_iter >= self.max_iter:
596
+ unwrapped_model = accelerator.unwrap_model(self.model)
597
+ if accelerator.is_main_process:
598
+ unwrapped_model.unet.save_pretrained(
599
+ os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
600
+ accelerator.wait_for_everyone()
601
+ return
602
+
603
+ torch.cuda.empty_cache()
604
+ # <<< Effective batch end <<<
605
+
606
+ # Epoch end
607
+ self.n_batch_in_epoch = 0
608
+
609
+ def _image2depth(self):
610
+ generator = torch.Generator(self.model.device).manual_seed(1024)
611
+ image2depth_paths = ['/home/aiops/wangzh/data/scannet/scene0593_00/color/000100.jpg',
612
+ '/home/aiops/wangzh/data/scannet/scene0593_00/color/000700.jpg',
613
+ '/home/aiops/wangzh/data/scannet/scene0591_01/color/000600.jpg',
614
+ '/home/aiops/wangzh/data/scannet/scene0591_01/color/001500.jpg']
615
+ for img_idx, image_path in enumerate(image2depth_paths):
616
+ rgb_input = Image.open(image_path)
617
+ depth_pred: MarigoldDepthOutput = self.model.image2depth(
618
+ rgb_input,
619
+ denoising_steps=self.cfg.validation.denoising_steps,
620
+ ensemble_size=self.cfg.validation.ensemble_size,
621
+ processing_res=self.cfg.validation.processing_res,
622
+ match_input_res=self.cfg.validation.match_input_res,
623
+ generator=generator,
624
+ batch_size=self.cfg.validation.ensemble_size,
625
+ # batch size equals the ensemble size here; set it to 1 to increase reproducibility
626
+ color_map="Spectral",
627
+ show_progress_bar=False,
628
+ resample_method=self.cfg.validation.resample_method,
629
+ )
630
+ img = self.model.post_process_rgbd(['None'], [rgb_input], [depth_pred['depth_colored']])
631
+ tb_logger.writer.add_image(f'image2depth_{img_idx}', img[0], self.effective_iter)
632
+
633
+ def _depth2image(self):
634
+ generator = torch.Generator(self.model.device).manual_seed(1024)
635
+ if "least_square_disparity" == self.cfg.eval.alignment:
636
+ depth2image_path = ['/home/aiops/wangzh/data/ori_depth_part0-0/sa_10000335.jpg',
637
+ '/home/aiops/wangzh/data/ori_depth_part0-0/sa_3572319.jpg',
638
+ '/home/aiops/wangzh/data/ori_depth_part0-0/sa_457934.jpg']
639
+ else:
640
+ depth2image_path = ['/home/aiops/wangzh/data/sa_001000/sa_10000335.jpg',
641
+ '/home/aiops/wangzh/data/sa_000357/sa_3572319.jpg',
642
+ '/home/aiops/wangzh/data/sa_000045/sa_457934.jpg']
643
+ prompts = ['Red car parked in the factory',
644
+ 'White gothic church with cemetery next to it',
645
+ 'House with red roof and starry sky in the background']
646
+ for img_idx, depth_path in enumerate(depth2image_path):
647
+ depth_input = Image.open(depth_path)
648
+ image_pred = self.model.single_depth2image(
649
+ depth_input,
650
+ prompts[img_idx],
651
+ num_inference_steps=50,
652
+ processing_res=self.cfg.validation.processing_res,
653
+ generator=generator,
654
+ show_pbar=False,
655
+ resample_method=self.cfg.validation.resample_method,
656
+ )
657
+ img = self.model.post_process_rgbd([prompts[img_idx]], [image_pred], [depth_input])
658
+ tb_logger.writer.add_image(f'depth2image_{img_idx}', img[0], self.effective_iter)
659
+
660
+ def encode_depth(self, depth_in):
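+ # depth is expected in [-1, 1]; it is replicated to 3 channels below so the Stable Diffusion VAE encoder can be reused for depth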
661
+ # stack depth into 3-channel
662
+ stacked = self.stack_depth_images(depth_in)
663
+ # encode using VAE encoder
664
+ depth_latent = self.model.encode_rgb(stacked)
665
+ return depth_latent
666
+
667
+ @staticmethod
668
+ def stack_depth_images(depth_in):
669
+ if 4 == len(depth_in.shape):
670
+ stacked = depth_in.repeat(1, 3, 1, 1)
671
+ elif 3 == len(depth_in.shape):
672
+ stacked = depth_in.unsqueeze(1)
673
+ stacked = stacked.repeat(1, 3, 1, 1)
674
+ return stacked
675
+
676
+ def validate(self):
677
+ for i, val_loader in enumerate(self.val_loaders):
678
+ val_dataset_name = val_loader.dataset.disp_name
679
+ val_metric_dic = self.validate_single_dataset(
680
+ data_loader=val_loader, metric_tracker=self.val_metrics
681
+ )
682
+
683
+ if self.accelerator.is_main_process:
684
+ val_metric_dic = {k:torch.tensor(v).cuda() for k,v in val_metric_dic.items()}
685
+
686
+ tb_logger.log_dic(
687
+ {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
688
+ global_step=self.effective_iter,
689
+ )
690
+ # save to file
691
+ eval_text = eval_dic_to_text(
692
+ val_metrics=val_metric_dic,
693
+ dataset_name=val_dataset_name,
694
+ sample_list_path=val_loader.dataset.filename_ls_path,
695
+ )
696
+ _save_to = os.path.join(
697
+ self.out_dir_eval,
698
+ f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
699
+ )
700
+ with open(_save_to, "w+") as f:
701
+ f.write(eval_text)
702
+
703
+ # Update main eval metric
704
+ if 0 == i:
705
+ main_eval_metric = val_metric_dic[self.main_val_metric]
706
+ if (
707
+ "minimize" == self.main_val_metric_goal
708
+ and main_eval_metric < self.best_metric
709
+ or "maximize" == self.main_val_metric_goal
710
+ and main_eval_metric > self.best_metric
711
+ ):
712
+ self.best_metric = main_eval_metric
713
+ logging.info(
714
+ f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
715
+ )
716
+ # Save a checkpoint
717
+ self.save_checkpoint(
718
+ ckpt_name='best', save_train_state=False
719
+ )
720
+
721
+ self.accelerator.wait_for_everyone()
722
+
723
+ def visualize(self):
724
+ for val_loader in self.vis_loaders:
725
+ vis_dataset_name = val_loader.dataset.disp_name
726
+ vis_out_dir = os.path.join(
727
+ self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
728
+ )
729
+ os.makedirs(vis_out_dir, exist_ok=True)
730
+ _ = self.validate_single_dataset(
731
+ data_loader=val_loader,
732
+ metric_tracker=self.val_metrics,
733
+ save_to_dir=vis_out_dir,
734
+ )
735
+
736
+ @torch.no_grad()
737
+ def validate_single_dataset(
738
+ self,
739
+ data_loader: DataLoader,
740
+ metric_tracker: MetricTracker,
741
+ save_to_dir: str = None,
742
+ ):
743
+ self.model.to(self.device)
744
+ metric_tracker.reset()
745
+
746
+ # Generate seed sequence for consistent evaluation
747
+ val_init_seed = self.cfg.validation.init_seed
748
+ val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
749
+
750
+ for i, batch in enumerate(
751
+ tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
752
+ start=1,
753
+ ):
754
+
755
+ rgb_int = batch["rgb_int"] # [3, H, W]
756
+ # GT depth
757
+ depth_raw_ts = batch["depth_raw_linear"].squeeze()
758
+ depth_raw = depth_raw_ts.cpu().numpy()
759
+ depth_raw_ts = depth_raw_ts.to(self.device)
760
+ valid_mask_ts = batch["valid_mask_raw"].squeeze()
761
+ valid_mask = valid_mask_ts.cpu().numpy()
762
+ valid_mask_ts = valid_mask_ts.to(self.device)
763
+
764
+ # Random number generator
765
+ seed = val_seed_ls.pop()
766
+ if seed is None:
767
+ generator = None
768
+ else:
769
+ generator = torch.Generator(device=self.device)
770
+ generator.manual_seed(seed)
771
+
772
+ # Predict depth
773
+ pipe_out: MarigoldDepthOutput = self.model.image2depth(
774
+ rgb_int,
775
+ denoising_steps=self.cfg.validation.denoising_steps,
776
+ ensemble_size=self.cfg.validation.ensemble_size,
777
+ processing_res=self.cfg.validation.processing_res,
778
+ match_input_res=self.cfg.validation.match_input_res,
779
+ generator=generator,
780
+ batch_size=self.cfg.validation.ensemble_size, # batch size equals the ensemble size here; set it to 1 to increase reproducibility
781
+ color_map=None,
782
+ show_progress_bar=False,
783
+ resample_method=self.cfg.validation.resample_method,
784
+ )
785
+
786
+ depth_pred: np.ndarray = pipe_out.depth_np
787
+
788
+ if "least_square" == self.cfg.eval.alignment:
789
+ depth_pred, scale, shift = align_depth_least_square(
790
+ gt_arr=depth_raw,
791
+ pred_arr=depth_pred,
792
+ valid_mask_arr=valid_mask,
793
+ return_scale_shift=True,
794
+ max_resolution=self.cfg.eval.align_max_res,
795
+ )
796
+ elif "least_square_disparity" == self.cfg.eval.alignment:
797
+ # convert GT depth -> GT disparity
798
+ gt_disparity, gt_non_neg_mask = depth2disparity(
799
+ depth=depth_raw, return_mask=True
800
+ )
801
+
802
+ pred_non_neg_mask = depth_pred > 0
803
+ valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
804
+
805
+ disparity_pred, scale, shift = align_depth_least_square(
806
+ gt_arr=gt_disparity,
807
+ pred_arr=depth_pred,
808
+ valid_mask_arr=valid_nonnegative_mask,
809
+ return_scale_shift=True,
810
+ max_resolution=self.cfg.eval.align_max_res,
811
+ )
812
+ # convert to depth
813
+ disparity_pred = np.clip(
814
+ disparity_pred, a_min=1e-3, a_max=None
815
+ ) # avoid 0 disparity
816
+ depth_pred = disparity2depth(disparity_pred)
817
+
818
+ # Clip to dataset min max
819
+ depth_pred = np.clip(
820
+ depth_pred,
821
+ a_min=data_loader.dataset.min_depth,
822
+ a_max=data_loader.dataset.max_depth,
823
+ )
824
+
825
+ # clip to d > 0 for evaluation
826
+ depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
827
+
828
+ # Evaluate
829
+ sample_metric = []
830
+ depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)
831
+
832
+ for met_func in self.metric_funcs:
833
+ _metric_name = met_func.__name__
834
+ _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).cuda(self.accelerator.process_index)
835
+ self.accelerator.wait_for_everyone()
836
+ _metric = self.accelerator.gather_for_metrics(_metric.unsqueeze(0)).mean().item()
837
+ sample_metric.append(_metric.__str__())
838
+ metric_tracker.update(_metric_name, _metric)
839
+
840
+ self.accelerator.wait_for_everyone()
841
+ # Save as 16-bit uint png
842
+ if save_to_dir is not None:
843
+ img_name = batch["rgb_relative_path"][0].replace("/", "_")
844
+ png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
845
+ depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16)
846
+ Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")
847
+
848
+ return metric_tracker.result()
849
+
850
+ def _get_next_seed(self):
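+ # the seed sequence is (re)generated lazily from the initial seed, so a resumed run pops the same per-batch seeds as the original run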
851
+ if 0 == len(self.global_seed_sequence):
852
+ self.global_seed_sequence = generate_seed_sequence(
853
+ initial_seed=self.seed,
854
+ length=self.max_iter * self.gradient_accumulation_steps,
855
+ )
856
+ logging.info(
857
+ f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
858
+ )
859
+ return self.global_seed_sequence.pop()
860
+
861
+ def save_miscs(self, ckpt_name):
862
+ ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
863
+ state = {
864
+ "config": self.cfg,
865
+ "effective_iter": self.effective_iter,
866
+ "epoch": self.epoch,
867
+ "n_batch_in_epoch": self.n_batch_in_epoch,
868
+ "best_metric": self.best_metric,
869
+ "in_evaluation": self.in_evaluation,
870
+ "global_seed_sequence": self.global_seed_sequence,
871
+ }
872
+ train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
873
+ torch.save(state, train_state_path)
874
+
875
+ logging.info(f"Misc state is saved to: {train_state_path}")
876
+
877
+ def load_miscs(self, ckpt_path):
878
+ checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
879
+ self.effective_iter = checkpoint["effective_iter"]
880
+ self.epoch = checkpoint["epoch"]
881
+ self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
882
+ self.in_evaluation = checkpoint["in_evaluation"]
883
+ self.global_seed_sequence = checkpoint["global_seed_sequence"]
884
+
885
+ self.best_metric = checkpoint["best_metric"]
886
+
887
+ logging.info(f"Misc state is loaded from {ckpt_path}")
888
+
889
+
890
+ def save_checkpoint(self, ckpt_name, save_train_state):
891
+ ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
892
+ logging.info(f"Saving checkpoint to: {ckpt_dir}")
893
+ # Backup previous checkpoint
894
+ temp_ckpt_dir = None
895
+ if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
896
+ temp_ckpt_dir = os.path.join(
897
+ os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
898
+ )
899
+ if os.path.exists(temp_ckpt_dir):
900
+ shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
901
+ os.rename(ckpt_dir, temp_ckpt_dir)
902
+ logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
903
+
904
+ # Save UNet
905
+ unet_path = os.path.join(ckpt_dir, "unet")
906
+ self.model.unet.save_pretrained(unet_path, safe_serialization=False)
907
+ logging.info(f"UNet is saved to: {unet_path}")
908
+
909
+ if save_train_state:
910
+ state = {
911
+ "config": self.cfg,
912
+ "effective_iter": self.effective_iter,
913
+ "epoch": self.epoch,
914
+ "n_batch_in_epoch": self.n_batch_in_epoch,
915
+ "best_metric": self.best_metric,
916
+ "in_evaluation": self.in_evaluation,
917
+ "global_seed_sequence": self.global_seed_sequence,
918
+ }
919
+ train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
920
+ torch.save(state, train_state_path)
921
+ # iteration indicator
922
+ f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w")
923
+ f.close()
924
+
925
+ logging.info(f"Trainer state is saved to: {train_state_path}")
926
+
927
+ # Remove temp ckpt
928
+ if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
929
+ shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
930
+ logging.debug("Old checkpoint backup is removed.")
931
+
932
+ def load_checkpoint(
933
+ self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
934
+ ):
935
+ logging.info(f"Loading checkpoint from: {ckpt_path}")
936
+ # Load UNet
937
+ _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
938
+ self.model.unet.load_state_dict(
939
+ torch.load(_model_path, map_location=self.device)
940
+ )
941
+ self.model.unet.to(self.device)
942
+ logging.info(f"UNet parameters are loaded from {_model_path}")
943
+
944
+ # Load training states
945
+ if load_trainer_state:
946
+ checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
947
+ self.effective_iter = checkpoint["effective_iter"]
948
+ self.epoch = checkpoint["epoch"]
949
+ self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
950
+ self.in_evaluation = checkpoint["in_evaluation"]
951
+ self.global_seed_sequence = checkpoint["global_seed_sequence"]
952
+
953
+ self.best_metric = checkpoint["best_metric"]
954
+
955
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
956
+ logging.info(f"optimizer state is loaded from {ckpt_path}")
957
+
958
+ if resume_lr_scheduler:
959
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
960
+ logging.info(f"LR scheduler state is loaded from {ckpt_path}")
961
+
962
+ logging.info(
963
+ f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
964
+ )
965
+ return
966
+
967
+ def _get_backup_ckpt_name(self):
968
+ return f"iter_{self.effective_iter:06d}"
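The trainer above samples separate diffusion timesteps for the RGB and depth branches according to timestep_method ('joint', 'unidiffuser', 'partition'). A condensed, self-contained sketch of that sampling logic, for illustration only (the helper name is not part of the repository):

import random
import torch

def sample_rgbd_timesteps(method, batch_size, num_train_timesteps, device, generator=None):
    # draws one timestep per sample for the RGB branch and one for the depth branch
    def rand_t():
        return torch.randint(0, num_train_timesteps, (batch_size,),
                             device=device, generator=generator).long()
    zeros = torch.zeros(batch_size, device=device, dtype=torch.long)
    if method == "joint":  # same noise level for RGB and depth
        t = rand_t()
        return t, t
    if method == "unidiffuser":  # independent noise levels
        return rand_t(), rand_t()
    if method == "partition":  # mix joint, image2depth and depth2image cases
        r = random.random()
        if r < 1 / 3:
            t = rand_t()
            return t, t
        if r < 2 / 3:
            return zeros, rand_t()  # clean RGB conditions depth generation
        return rand_t(), zeros      # clean depth conditions RGB generation
    raise ValueError(f"Unknown timestep method: {method}")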
src/trainer/marigold_xl_trainer.py ADDED
@@ -0,0 +1,948 @@
1
+ # An official reimplementation of the Marigold training script.
2
+ # Last modified: 2024-04-29
3
+ #
4
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # --------------------------------------------------------------------------
18
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
19
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
20
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
21
+ # More information about the method can be found at https://marigoldmonodepth.github.io
22
+ # --------------------------------------------------------------------------
23
+
24
+
25
+ import logging
26
+ import os
27
+ import pdb
28
+ import shutil
29
+ from datetime import datetime
30
+ from typing import List, Union
31
+ import safetensors
32
+ import numpy as np
33
+ import torch
34
+ from diffusers import DDPMScheduler
35
+ from omegaconf import OmegaConf
36
+ from torch.nn import Conv2d
37
+ from torch.nn.parameter import Parameter
38
+ from torch.optim import Adam
39
+ from torch.optim.lr_scheduler import LambdaLR
40
+ from torch.utils.data import DataLoader
41
+ from tqdm import tqdm
42
+ from PIL import Image
43
+ # import torch.optim.lr_scheduler
44
+
45
+ from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput
46
+ from src.util import metric
47
+ from src.util.data_loader import skip_first_batches
48
+ from src.util.logging_util import tb_logger, eval_dic_to_text
49
+ from src.util.loss import get_loss
50
+ from src.util.lr_scheduler import IterExponential
51
+ from src.util.metric import MetricTracker
52
+ from src.util.multi_res_noise import multi_res_noise_like
53
+ from src.util.alignment import align_depth_least_square
54
+ from src.util.seeding import generate_seed_sequence
55
+ from accelerate import Accelerator
56
+ import random
57
+
58
+ class MarigoldXLTrainer:
59
+ def __init__(
60
+ self,
61
+ cfg: OmegaConf,
62
+ model: MarigoldPipeline,
63
+ train_dataloader: DataLoader,
64
+ device,
65
+ base_ckpt_dir,
66
+ out_dir_ckpt,
67
+ out_dir_eval,
68
+ out_dir_vis,
69
+ accumulation_steps: int,
70
+ separate_list: List = None,
71
+ val_dataloaders: List[DataLoader] = None,
72
+ vis_dataloaders: List[DataLoader] = None,
73
+ timestep_method: str = 'unidiffuser'
74
+ ):
75
+ self.cfg: OmegaConf = cfg
76
+ self.model: MarigoldPipeline = model
77
+ self.device = device
78
+ self.seed: Union[int, None] = (
79
+ self.cfg.trainer.init_seed
80
+ ) # used to generate seed sequence, set to `None` to train w/o seeding
81
+ self.out_dir_ckpt = out_dir_ckpt
82
+ self.out_dir_eval = out_dir_eval
83
+ self.out_dir_vis = out_dir_vis
84
+ self.train_loader: DataLoader = train_dataloader
85
+ self.val_loaders: List[DataLoader] = val_dataloaders
86
+ self.vis_loaders: List[DataLoader] = vis_dataloaders
87
+ self.accumulation_steps: int = accumulation_steps
88
+ self.separate_list = separate_list
89
+ self.timestep_method = timestep_method
90
+ # Adapt input layers
91
+ # if 8 != self.model.unet.config["in_channels"]:
92
+ # self._replace_unet_conv_in()
93
+ # if 8 != self.model.unet.config["out_channels"]:
94
+ # self._replace_unet_conv_out()
95
+
96
+ self.prompt = ['a view of a city skyline from a bridge',
97
+ 'a man and a woman sitting on a couch',
98
+ 'a black car parked in a parking lot next to the water',
99
+ 'Enchanted forest with glowing plants, fairies, and ancient castle.',
100
+ 'Futuristic city with skyscrapers, neon lights, and hovering vehicles.',
101
+ 'Fantasy mountain landscape with waterfalls, dragons, and mythical creatures.']
102
+ # self.generator = torch.Generator('cuda:0').manual_seed(1024)
103
+
104
+ # Encode empty text prompt
105
+ # self.model.encode_empty_text()
106
+ # self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
107
+
108
+ self.model.unet.enable_xformers_memory_efficient_attention()
109
+
110
+ # Trainability
111
+ self.model.vae.requires_grad_(False)
112
+ self.model.text_encoder.requires_grad_(False)
113
+ # self.model.unet.requires_grad_(True)
114
+
115
+ grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters())
116
+
117
+ # Optimizer !should be defined after input layer is adapted
118
+ lr = self.cfg.lr
119
+ self.optimizer = Adam(grad_part, lr=lr)
120
+
121
+ total_params = sum(p.numel() for p in self.model.unet.parameters())
122
+ total_params_m = total_params / 1_000_000
123
+ print(f"Total parameters: {total_params_m:.2f}M")
124
+ trainable_params = sum(p.numel() for p in self.model.unet.parameters() if p.requires_grad)
125
+ trainable_params_m = trainable_params / 1_000_000
126
+ print(f"Trainable parameters: {trainable_params_m:.2f}M")
127
+
128
+ # LR scheduler
129
+ lr_func = IterExponential(
130
+ total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
131
+ final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
132
+ warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
133
+ )
134
+ self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
135
+
136
+ # Loss
137
+ self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
138
+
139
+ # Training noise scheduler
140
+ self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
141
+ os.path.join(
142
+ cfg.trainer.training_noise_scheduler.pretrained_path,
143
+ "scheduler",
144
+ )
145
+ )
146
+ self.prediction_type = self.training_noise_scheduler.config.prediction_type
147
+ assert (
148
+ self.prediction_type == self.model.scheduler.config.prediction_type
149
+ ), "Different prediction types"
150
+ self.scheduler_timesteps = (
151
+ self.training_noise_scheduler.config.num_train_timesteps
152
+ )
153
+
154
+ # Eval metrics
155
+ self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
156
+ self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss'])
157
+ self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
158
+ # main metric for best checkpoint saving
159
+ self.main_val_metric = cfg.validation.main_val_metric
160
+ self.main_val_metric_goal = cfg.validation.main_val_metric_goal
161
+ assert (
162
+ self.main_val_metric in cfg.eval.eval_metrics
163
+ ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
164
+ self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
165
+
166
+ # Settings
167
+ self.max_epoch = self.cfg.max_epoch
168
+ self.max_iter = self.cfg.max_iter
169
+ self.gradient_accumulation_steps = accumulation_steps
170
+ self.gt_depth_type = self.cfg.gt_depth_type
171
+ self.gt_mask_type = self.cfg.gt_mask_type
172
+ self.save_period = self.cfg.trainer.save_period
173
+ self.backup_period = self.cfg.trainer.backup_period
174
+ self.val_period = self.cfg.trainer.validation_period
175
+ self.vis_period = self.cfg.trainer.visualization_period
176
+
177
+ # Multi-resolution noise
178
+ self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
179
+ if self.apply_multi_res_noise:
180
+ self.mr_noise_strength = self.cfg.multi_res_noise.strength
181
+ self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
182
+ self.mr_noise_downscale_strategy = (
183
+ self.cfg.multi_res_noise.downscale_strategy
184
+ )
185
+
186
+ # Internal variables
187
+ self.epoch = 0
188
+ self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training
189
+ self.effective_iter = 0 # how many times optimizer.step() is called
190
+ self.in_evaluation = False
191
+ self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
192
+
193
+ def _replace_unet_conv_in(self):
194
+ # replace the first layer to accept 8 in_channels
195
+ _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3]
196
+ _bias = self.model.unet.conv_in.bias.clone() # [320]
197
+ zero_weight = torch.zeros(_weight.shape).to(_weight.device)
198
+ _weight = torch.cat([_weight, zero_weight], dim=1)
199
+ # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
200
+ # half the activation magnitude
201
+ # _weight *= 0.5
202
+ # new conv_in channel
203
+ _n_convin_out_channel = self.model.unet.conv_in.out_channels
204
+ _new_conv_in = Conv2d(
205
+ 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
206
+ )
207
+ _new_conv_in.weight = Parameter(_weight)
208
+ _new_conv_in.bias = Parameter(_bias)
209
+ self.model.unet.conv_in = _new_conv_in
210
+ logging.info("Unet conv_in layer is replaced")
211
+ # replace config
212
+ self.model.unet.config["in_channels"] = 8
213
+ logging.info("Unet config is updated")
214
+ return
215
+
216
+ def _replace_unet_conv_out(self):
217
+ # replace the first layer to accept 8 in_channels
218
+ _weight = self.model.unet.conv_out.weight.clone() # [8, 320, 3, 3]
219
+ _bias = self.model.unet.conv_out.bias.clone() # [320]
220
+ _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s)
221
+ _bias = _bias.repeat((2))
222
+ # half the activation magnitude
223
+
224
+ # new conv_in channel
225
+ _n_convin_out_channel = self.model.unet.conv_out.out_channels
226
+ _new_conv_out = Conv2d(
227
+ _n_convin_out_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
228
+ )
229
+ _new_conv_out.weight = Parameter(_weight)
230
+ _new_conv_out.bias = Parameter(_bias)
231
+ self.model.unet.conv_out = _new_conv_out
232
+ logging.info("Unet conv_out layer is replaced")
233
+ # replace config
234
+ self.model.unet.config["out_channels"] = 8
235
+ logging.info("Unet config is updated")
236
+ return
237
+
238
+ def parallel_train(self, t_end=None, accelerator=None):
239
+ logging.info("Start training")
240
+
241
+ self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
242
+ self.model, self.optimizer, self.train_loader, self.lr_scheduler
243
+ )
244
+ self.accelerator = accelerator
245
+ if self.val_loaders is not None:
246
+ for idx, loader in enumerate(self.val_loaders):
247
+ self.val_loaders[idx] = accelerator.prepare(loader)
248
+
249
+ if os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')):
250
+ accelerator.load_state(os.path.join(self.out_dir_ckpt, 'latest'))
251
+ self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest'))
252
+
253
+ self.train_metrics.reset()
254
+ accumulated_step = 0
255
+ for epoch in range(self.epoch, self.max_epoch + 1):
256
+ self.epoch = epoch
257
+ logging.debug(f"epoch: {self.epoch}")
258
+
259
+ # Skip previous batches when resume
260
+ for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
261
+ self.model.unet.train()
262
+
263
+ # globally consistent random generators
264
+ if self.seed is not None:
265
+ local_seed = self._get_next_seed()
266
+ rand_num_generator = torch.Generator(device=self.model.device)
267
+ rand_num_generator.manual_seed(local_seed)
268
+ else:
269
+ rand_num_generator = None
270
+
271
+ # >>> With gradient accumulation >>>
272
+
273
+ # Get data
274
+ rgb = batch["rgb_norm"].to(self.model.device)
275
+ depth_gt_for_latent = batch[self.gt_depth_type].to(self.model.device)
276
+ batch_size = rgb.shape[0]
277
+
278
+ if self.gt_mask_type is not None:
279
+ valid_mask_for_latent = batch[self.gt_mask_type].to(self.model.device)
280
+ invalid_mask = ~valid_mask_for_latent
281
+ valid_mask_down = ~torch.max_pool2d(
282
+ invalid_mask.float(), 8, 8
283
+ ).bool()
284
+ valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
285
+
286
+ with torch.no_grad():
287
+ # Encode image
288
+ rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w]
289
+ # Encode GT depth
290
+ gt_depth_latent = self.encode_depth(
291
+ depth_gt_for_latent
292
+ ) # [B, 4, h, w]
293
+
294
+ # Sample a random timestep for each image
295
+ if self.cfg.loss.depth_factor == 1:
296
+ rgb_timesteps = torch.zeros(
297
+ (batch_size),
298
+ device=self.model.device
299
+ ).long() # [B]
300
+ depth_timesteps = torch.randint(
301
+ 0,
302
+ self.scheduler_timesteps,
303
+ (batch_size,),
304
+ device=self.model.device,
305
+ generator=rand_num_generator,
306
+ ).long() # [B]
307
+ elif self.timestep_method == 'unidiffuser':
308
+ rgb_timesteps = torch.randint(
309
+ 0,
310
+ self.scheduler_timesteps,
311
+ (batch_size,),
312
+ device=self.model.device,
313
+ generator=rand_num_generator,
314
+ ).long() # [B]
315
+ depth_timesteps = torch.randint(
316
+ 0,
317
+ self.scheduler_timesteps,
318
+ (batch_size,),
319
+ device=self.model.device,
320
+ generator=rand_num_generator,
321
+ ).long() # [B]
322
+ elif self.timestep_method == 'partition':
323
+ rand_num = random.random()
324
+ if rand_num < 0.3333:
325
+ # joint generation
326
+ rgb_timesteps = torch.randint(
327
+ 0,
328
+ self.scheduler_timesteps,
329
+ (batch_size,),
330
+ device=self.model.device,
331
+ generator=rand_num_generator,
332
+ ).long() # [B]
333
+ depth_timesteps = rgb_timesteps
334
+ elif rand_num < 0.6666:
335
+ # image2depth generation
336
+ rgb_timesteps = torch.zeros(
337
+ (batch_size),
338
+ device=self.model.device
339
+ ).long() # [B]
340
+ depth_timesteps = torch.randint(
341
+ 0,
342
+ self.scheduler_timesteps,
343
+ (batch_size,),
344
+ device=self.model.device,
345
+ generator=rand_num_generator,
346
+ ).long() # [B]
347
+ else:
348
+ # depth2image generation
349
+ rgb_timesteps = torch.randint(
350
+ 0,
351
+ self.scheduler_timesteps,
352
+ (batch_size,),
353
+ device=self.model.device,
354
+ generator=rand_num_generator,
355
+ ).long() # [B]
356
+ depth_timesteps = torch.zeros(
357
+ (batch_size),
358
+ device=self.model.device
359
+ ).long() # [B]
360
+
361
+ # Sample noise
362
+ if self.apply_multi_res_noise:
363
+ rgb_strength = self.mr_noise_strength
364
+ if self.annealed_mr_noise:
365
+ # calculate strength depending on t
366
+ rgb_strength = rgb_strength * (rgb_timesteps / self.scheduler_timesteps)
367
+ rgb_noise = multi_res_noise_like(
368
+ rgb_latent,
369
+ strength=rgb_strength,
370
+ downscale_strategy=self.mr_noise_downscale_strategy,
371
+ generator=rand_num_generator,
372
+ device=self.model.device,
373
+ )
374
+
375
+ depth_strength = self.mr_noise_strength
376
+ if self.annealed_mr_noise:
377
+ # calculate strength depending on t
378
+ depth_strength = depth_strength * (depth_timesteps / self.scheduler_timesteps)
379
+ depth_noise = multi_res_noise_like(
380
+ gt_depth_latent,
381
+ strength=depth_strength,
382
+ downscale_strategy=self.mr_noise_downscale_strategy,
383
+ generator=rand_num_generator,
384
+ device=self.model.device,
385
+ )
386
+ else:
387
+ rgb_noise = torch.randn(
388
+ rgb_latent.shape,
389
+ device=self.model.device,
390
+ generator=rand_num_generator,
391
+ ) # [B, 4, h, w]
392
+
393
+ depth_noise = torch.randn(
394
+ gt_depth_latent.shape,
395
+ device=self.model.device,
396
+ generator=rand_num_generator,
397
+ ) # [B, 4, h, w]
398
+ # Add noise to the latents (diffusion forward process)
399
+
400
+ noisy_rgb_latents = self.training_noise_scheduler.add_noise(
401
+ rgb_latent, rgb_noise, rgb_timesteps
402
+ ) # [B, 4, h, w]
403
+ noisy_depth_latents = self.training_noise_scheduler.add_noise(
404
+ gt_depth_latent, depth_noise, depth_timesteps
405
+ ) # [B, 4, h, w]
406
+
407
+ noisy_latents = torch.cat(
408
+ [noisy_rgb_latents, noisy_depth_latents], dim=1
409
+ ).float() # [B, 8, h, w]
410
+
411
+ # Text embedding
412
+ batch_text_embed = []
413
+ batch_pooled_text_embed = []
414
+ for p in batch['text']:
415
+ prompt_embed, pooled_prompt_embed = self.model.encode_text(p)
416
+ batch_text_embed.append(prompt_embed)
417
+ batch_pooled_text_embed.append(pooled_prompt_embed)
418
+ batch_text_embed = torch.cat(batch_text_embed, dim=0)
419
+ batch_pooled_text_embed = torch.cat(batch_pooled_text_embed, dim=0)
420
+ # input_ids = {k:v.squeeze().to(self.model.device) for k,v in batch['text'].items()}
421
+ # prompt_embed, pooled_prompt_embed = self.model.encode_text(batch['text'])
422
+ # text_embed = self.empty_text_embed.to(device).repeat(
423
+ # (batch_size, 1, 1)
424
+ # ) # [B, 77, 1024]
425
+ # Predict the noise residual
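+ # SDXL micro-conditioning: the original size, crop top-left and target size are packed into add_time_ids and passed via added_cond_kwargs together with the pooled text embedding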
426
+ add_time_ids = self.model._get_add_time_ids(
427
+ (batch['rgb_int'].shape[-2], batch['rgb_int'].shape[-1]), (0, 0), (batch['rgb_int'].shape[-2], batch['rgb_int'].shape[-1]), dtype=batch_text_embed.dtype
428
+ )
429
+ # pdb.set_trace()  # leftover debugging breakpoint, disabled so training is not interrupted
430
+ dtype = self.model.unet.dtype
431
+ added_cond_kwargs = {"text_embeds": batch_pooled_text_embed.to(self.model.device).to(dtype), "time_ids": add_time_ids.to(self.model.device).to(dtype)}
432
+ model_pred = self.model.unet(
433
+ noisy_latents.to(self.model.unet.dtype), rgb_timesteps, depth_timesteps, encoder_hidden_states=batch_text_embed.to(dtype),
434
+ added_cond_kwargs=added_cond_kwargs, separate_list=self.separate_list
435
+ ).sample # [B, 4, h, w]
436
+ if torch.isnan(model_pred).any():
437
+ logging.warning("model_pred contains NaN.")
438
+
439
+ # Get the target for loss depending on the prediction type
440
+ if "sample" == self.prediction_type:
441
+ rgb_target = rgb_latent
442
+ depth_target = gt_depth_latent
443
+ elif "epsilon" == self.prediction_type:
444
+ rgb_target = rgb_noise  # epsilon prediction targets the added noise
445
+ depth_target = depth_noise
446
+ elif "v_prediction" == self.prediction_type:
447
+ rgb_target = self.training_noise_scheduler.get_velocity(
448
+ rgb_latent, rgb_noise, rgb_timesteps
449
+ ) # [B, 4, h, w]
450
+ depth_target = self.training_noise_scheduler.get_velocity(
451
+ gt_depth_latent, depth_noise, depth_timesteps
452
+ ) # [B, 4, h, w]
453
+ else:
454
+ raise ValueError(f"Unknown prediction type {self.prediction_type}")
455
+ # Masked latent loss
456
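+ # channels 0:4 of the prediction correspond to the RGB latent and channels 4:8 to the depth latent; depth_factor balances the two loss terms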
+ with accelerator.accumulate(self.model):
457
+ if self.gt_mask_type is not None:
458
+ depth_loss = self.loss(
459
+ model_pred[:, 4:, :, :][valid_mask_down].float(),
460
+ depth_target[valid_mask_down].float(),
461
+ )
462
+ else:
463
+ depth_loss = self.cfg.loss.depth_factor * self.loss(model_pred[:, 4:, :, :].float(),depth_target.float())
464
+
465
+ rgb_loss = (1 - self.cfg.loss.depth_factor) * self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
466
+ if self.cfg.loss.depth_factor == 1:
467
+ loss = depth_loss
468
+ else:
469
+ loss = rgb_loss + depth_loss
470
+
471
+ self.train_metrics.update("loss", loss.item())
472
+ self.train_metrics.update("rgb_loss", rgb_loss.item())
473
+ self.train_metrics.update("depth_loss", depth_loss.item())
474
+ # loss = loss / self.gradient_accumulation_steps
475
+ accelerator.backward(loss)
476
+ self.optimizer.step()
477
+ self.optimizer.zero_grad()
478
+ # loss.backward()
479
+ self.n_batch_in_epoch += 1
480
+ # print(accelerator.process_index, self.lr_scheduler.get_last_lr())
481
+ self.lr_scheduler.step(self.effective_iter)
482
+
483
+ if accelerator.sync_gradients:
484
+ accumulated_step += 1
485
+
486
+ if accumulated_step >= self.gradient_accumulation_steps:
487
+ accumulated_step = 0
488
+ self.effective_iter += 1
489
+
490
+ if accelerator.is_main_process:
491
+ # Log to tensorboard
492
+ if self.effective_iter == 1:
493
+ generator = torch.Generator(self.model.device).manual_seed(1024)
494
+ img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator,
495
+ show_pbar=True)
496
+ for idx in range(len(self.prompt)):
497
+ tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)
498
+
499
+ accumulated_loss = self.train_metrics.result()["loss"]
500
+ rgb_loss = self.train_metrics.result()["rgb_loss"]
501
+ depth_loss = self.train_metrics.result()["depth_loss"]
502
+ tb_logger.log_dic(
503
+ {
504
+ f"train/{k}": v
505
+ for k, v in self.train_metrics.result().items()
506
+ },
507
+ global_step=self.effective_iter,
508
+ )
509
+ tb_logger.writer.add_scalar(
510
+ "lr",
511
+ self.lr_scheduler.get_last_lr()[0],
512
+ global_step=self.effective_iter,
513
+ )
514
+ tb_logger.writer.add_scalar(
515
+ "n_batch_in_epoch",
516
+ self.n_batch_in_epoch,
517
+ global_step=self.effective_iter,
518
+ )
519
+ logging.info(
520
+ f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}"
521
+ )
522
+ accelerator.wait_for_everyone()
523
+
524
+ if self.save_period > 0 and 0 == self.effective_iter % self.save_period:
525
+ accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest'))
526
+ unwrapped_model = accelerator.unwrap_model(self.model)
527
+ if accelerator.is_main_process:
528
+ accelerator.save_model(unwrapped_model.unet,
529
+ os.path.join(self.out_dir_ckpt, 'latest'), safe_serialization=False)
530
+ self.save_miscs('latest')
531
+
532
+ # RGB-D joint generation
533
+ generator = torch.Generator(self.model.device).manual_seed(1024)
534
+ img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator,show_pbar=False)
535
+ for idx in range(len(self.prompt)):
536
+ tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)
537
+
538
+ # depth to RGB generation
539
+ self._depth2image()
540
+ # RGB to depth generation
542
+ self._image2depth()
543
+
544
+ accelerator.wait_for_everyone()
545
+
546
+ accelerator.wait_for_everyone()
547
+
548
+ if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
549
+ unwrapped_model = accelerator.unwrap_model(self.model)
550
+ if accelerator.is_main_process:
551
+ unwrapped_model.unet.save_pretrained(
552
+ os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
553
+ accelerator.wait_for_everyone()
554
+
555
+ if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
556
+ self.validate()
557
+
558
+ # End of training
559
+ if self.max_iter > 0 and self.effective_iter >= self.max_iter:
560
+ unwrapped_model = accelerator.unwrap_model(self.model)
561
+ if accelerator.is_main_process:
562
+ unwrapped_model.unet.save_pretrained(
563
+ os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
564
+ accelerator.wait_for_everyone()
565
+ return
566
+
567
+ torch.cuda.empty_cache()
568
+ # <<< Effective batch end <<<
569
+
570
+ # Epoch end
571
+ self.n_batch_in_epoch = 0
572
+
573
+ def _image2depth(self):
574
+ generator = torch.Generator(self.model.device).manual_seed(1024)
575
+ image2depth_paths = ['/home/aiops/wangzh/data/scannet/scene0593_00/color/000100.jpg',
576
+ '/home/aiops/wangzh/data/scannet/scene0593_00/color/000700.jpg',
577
+ '/home/aiops/wangzh/data/scannet/scene0591_01/color/000600.jpg',
578
+ '/home/aiops/wangzh/data/scannet/scene0591_01/color/001500.jpg']
579
+ for img_idx, image_path in enumerate(image2depth_paths):
580
+ rgb_input = Image.open(image_path)
581
+ depth_pred: MarigoldDepthOutput = self.model.image2depth(
582
+ rgb_input,
583
+ denoising_steps=self.cfg.validation.denoising_steps,
584
+ ensemble_size=self.cfg.validation.ensemble_size,
585
+ processing_res=self.cfg.validation.processing_res,
586
+ match_input_res=self.cfg.validation.match_input_res,
587
+ generator=generator,
588
+ batch_size=self.cfg.validation.ensemble_size,
589
+ # use batch size 1 to increase reproducibility
590
+ color_map="Spectral",
591
+ show_progress_bar=False,
592
+ resample_method=self.cfg.validation.resample_method,
593
+ )
594
+ img = self.model.post_process_rgbd(['None'], [rgb_input], [depth_pred['depth_colored']])
595
+ tb_logger.writer.add_image(f'image2depth_{img_idx}', img[0], self.effective_iter)
596
+
597
+ def _depth2image(self):
598
+ generator = torch.Generator(self.model.device).manual_seed(1024)
599
+ if "least_square_disparity" == self.cfg.eval.alignment:
600
+ depth2image_path = ['/home/aiops/wangzh/data/ori_depth_part0-0/sa_10000335.jpg',
601
+ '/home/aiops/wangzh/data/ori_depth_part0-0/sa_3572319.jpg',
602
+ '/home/aiops/wangzh/data/ori_depth_part0-0/sa_457934.jpg']
603
+ else:
604
+ depth2image_path = ['/home/aiops/wangzh/data/depth_part0-0/sa_10000335.jpg',
605
+ '/home/aiops/wangzh/data/depth_part0-0/sa_3572319.jpg',
606
+ '/home/aiops/wangzh/data/depth_part0-0/sa_457934.jpg']
607
+ prompts = ['Red car parked in the factory',
608
+ 'White gothic church with cemetery next to it',
609
+ 'House with red roof and starry sky in the background']
610
+ for img_idx, depth_path in enumerate(depth2image_path):
611
+ depth_input = Image.open(depth_path)
612
+ image_pred = self.model.single_depth2image(
613
+ depth_input,
614
+ prompts[img_idx],
615
+ num_inference_steps=50,
616
+ processing_res=1024,
617
+ generator=generator,
618
+ show_pbar=False,
619
+ resample_method=self.cfg.validation.resample_method,
620
+ )
621
+ img = self.model.post_process_rgbd([prompts[img_idx]], [image_pred], [depth_input])
622
+ tb_logger.writer.add_image(f'depth2image_{img_idx}', img[0], self.effective_iter)
623
+
624
+ def encode_depth(self, depth_in):
625
+ # stack depth into 3-channel
626
+ stacked = self.stack_depth_images(depth_in)
627
+ # encode using VAE encoder
628
+ depth_latent = self.model.encode_rgb(stacked)
629
+ return depth_latent
630
+
631
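+ # single-channel depth maps are replicated to 3 channels so the RGB VAE encoder can be reused for depth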
+ @staticmethod
632
+ def stack_depth_images(depth_in):
633
+ if 4 == len(depth_in.shape):
634
+ stacked = depth_in.repeat(1, 3, 1, 1)
635
+ elif 3 == len(depth_in.shape):
636
+ stacked = depth_in.unsqueeze(1)
637
+ stacked = stacked.repeat(1, 3, 1, 1)
638
+ return stacked
639
+
640
+ def _train_step_callback(self):
641
+ """Executed after every iteration"""
642
+ # Save backup (with a larger interval, without training states)
643
+ if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
644
+ self.save_checkpoint(
645
+ ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
646
+ )
647
+
648
+ _is_latest_saved = False
649
+ # Validation
650
+ if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
651
+ self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished
652
+ self.save_checkpoint(ckpt_name="latest", save_train_state=True)
653
+ _is_latest_saved = True
654
+ self.validate()
655
+ self.in_evaluation = False
656
+ self.save_checkpoint(ckpt_name="latest", save_train_state=True)
657
+
658
+ # Save training checkpoint (can be resumed)
659
+ if (
660
+ self.save_period > 0
661
+ and 0 == self.effective_iter % self.save_period
662
+ and not _is_latest_saved
663
+ ):
664
+ generator = torch.Generator(self.model.device).manual_seed(1024)
665
+ img = self.model.generate_rgbd(self.prompt, num_inference_steps=50, generator=generator, show_pbar=True)
666
+ for idx in range(len(self.prompt)):
667
+ tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)
668
+
669
+ self.save_checkpoint(ckpt_name="latest", save_train_state=True)
670
+
671
+ # Visualization
672
+ if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
673
+ self.visualize()
674
+
675
+ def validate(self):
676
+ for i, val_loader in enumerate(self.val_loaders):
677
+ val_dataset_name = val_loader.dataset.disp_name
678
+ val_metric_dic = self.validate_single_dataset(
679
+ data_loader=val_loader, metric_tracker=self.val_metrics
680
+ )
681
+
682
+ if self.accelerator.is_main_process:
683
+ val_metric_dic = {k:torch.tensor(v).cuda() for k,v in val_metric_dic.items()}
684
+
685
+ tb_logger.log_dic(
686
+ {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
687
+ global_step=self.effective_iter,
688
+ )
689
+ # save to file
690
+ eval_text = eval_dic_to_text(
691
+ val_metrics=val_metric_dic,
692
+ dataset_name=val_dataset_name,
693
+ sample_list_path=val_loader.dataset.filename_ls_path,
694
+ )
695
+ _save_to = os.path.join(
696
+ self.out_dir_eval,
697
+ f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
698
+ )
699
+ with open(_save_to, "w+") as f:
700
+ f.write(eval_text)
701
+
702
+ # Update main eval metric
703
+ if 0 == i:
704
+ main_eval_metric = val_metric_dic[self.main_val_metric]
705
+ if (
706
+ "minimize" == self.main_val_metric_goal
707
+ and main_eval_metric < self.best_metric
708
+ or "maximize" == self.main_val_metric_goal
709
+ and main_eval_metric > self.best_metric
710
+ ):
711
+ self.best_metric = main_eval_metric
712
+ logging.info(
713
+ f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
714
+ )
715
+ # Save a checkpoint
716
+ self.save_checkpoint(
717
+ ckpt_name='best', save_train_state=False
718
+ )
719
+
720
+ self.accelerator.wait_for_everyone()
721
+
722
+ def visualize(self):
723
+ for val_loader in self.vis_loaders:
724
+ vis_dataset_name = val_loader.dataset.disp_name
725
+ vis_out_dir = os.path.join(
726
+ self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
727
+ )
728
+ os.makedirs(vis_out_dir, exist_ok=True)
729
+ _ = self.validate_single_dataset(
730
+ data_loader=val_loader,
731
+ metric_tracker=self.val_metrics,
732
+ save_to_dir=vis_out_dir,
733
+ )
734
+
735
+ @torch.no_grad()
736
+ def validate_single_dataset(
737
+ self,
738
+ data_loader: DataLoader,
739
+ metric_tracker: MetricTracker,
740
+ save_to_dir: str = None,
741
+ ):
742
+ self.model.to(self.device)
743
+ metric_tracker.reset()
744
+
745
+ # Generate seed sequence for consistent evaluation
746
+ val_init_seed = self.cfg.validation.init_seed
747
+ val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
748
+
749
+ for i, batch in enumerate(
750
+ tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
751
+ start=1,
752
+ ):
753
+
754
+ rgb_int = batch["rgb_int"] # [3, H, W]
755
+ # GT depth
756
+ depth_raw_ts = batch["depth_raw_linear"].squeeze()
757
+ depth_raw = depth_raw_ts.cpu().numpy()
758
+ depth_raw_ts = depth_raw_ts.to(self.device)
759
+ valid_mask_ts = batch["valid_mask_raw"].squeeze()
760
+ valid_mask = valid_mask_ts.cpu().numpy()
761
+ valid_mask_ts = valid_mask_ts.to(self.device)
762
+
763
+ # Random number generator
764
+ seed = val_seed_ls.pop()
765
+ if seed is None:
766
+ generator = None
767
+ else:
768
+ generator = torch.Generator(device=self.device)
769
+ generator.manual_seed(seed)
770
+
771
+ # Predict depth
772
+ pipe_out: MarigoldDepthOutput = self.model.image2depth(
773
+ rgb_int,
774
+ denoising_steps=self.cfg.validation.denoising_steps,
775
+ ensemble_size=self.cfg.validation.ensemble_size,
776
+ processing_res=self.cfg.validation.processing_res,
777
+ match_input_res=self.cfg.validation.match_input_res,
778
+ generator=generator,
779
+ batch_size=self.cfg.validation.ensemble_size, # use batch size 1 to increase reproducibility
780
+ color_map=None,
781
+ show_progress_bar=False,
782
+ resample_method=self.cfg.validation.resample_method,
783
+ )
784
+
785
+ depth_pred: np.ndarray = pipe_out.depth_np
786
+
787
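+ # predictions are affine-invariant, so align them to the ground truth with a per-sample scale and shift before computing metrics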
+ if "least_square" == self.cfg.eval.alignment:
788
+ depth_pred, scale, shift = align_depth_least_square(
789
+ gt_arr=depth_raw,
790
+ pred_arr=depth_pred,
791
+ valid_mask_arr=valid_mask,
792
+ return_scale_shift=True,
793
+ max_resolution=self.cfg.eval.align_max_res,
794
+ )
795
+ else:
796
+ raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")
797
+
798
+ # Clip to dataset min max
799
+ depth_pred = np.clip(
800
+ depth_pred,
801
+ a_min=data_loader.dataset.min_depth,
802
+ a_max=data_loader.dataset.max_depth,
803
+ )
804
+
805
+ # clip to d > 0 for evaluation
806
+ depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
807
+
808
+ # Evaluate
809
+ sample_metric = []
810
+ depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)
811
+
812
+ for met_func in self.metric_funcs:
813
+ _metric_name = met_func.__name__
814
+ _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).cuda(self.accelerator.process_index)
815
+ self.accelerator.wait_for_everyone()
816
+ _metric = self.accelerator.gather_for_metrics(_metric.unsqueeze(0)).mean().item()
817
+ sample_metric.append(str(_metric))
818
+ metric_tracker.update(_metric_name, _metric)
819
+
820
+ self.accelerator.wait_for_everyone()
821
+ # Save as 16-bit uint png
822
+ if save_to_dir is not None:
823
+ img_name = batch["rgb_relative_path"][0].replace("/", "_")
824
+ png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
825
+ depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16)
826
+ Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")
827
+
828
+ return metric_tracker.result()
829
+
830
+ def _get_next_seed(self):
831
+ if 0 == len(self.global_seed_sequence):
832
+ self.global_seed_sequence = generate_seed_sequence(
833
+ initial_seed=self.seed,
834
+ length=self.max_iter * self.gradient_accumulation_steps,
835
+ )
836
+ logging.info(
837
+ f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
838
+ )
839
+ return self.global_seed_sequence.pop()
840
+
841
+ def save_miscs(self, ckpt_name):
842
+ ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
843
+ state = {
844
+ "config": self.cfg,
845
+ "effective_iter": self.effective_iter,
846
+ "epoch": self.epoch,
847
+ "n_batch_in_epoch": self.n_batch_in_epoch,
848
+ "best_metric": self.best_metric,
849
+ "in_evaluation": self.in_evaluation,
850
+ "global_seed_sequence": self.global_seed_sequence,
851
+ }
852
+ train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
853
+ torch.save(state, train_state_path)
854
+
855
+ logging.info(f"Misc state is saved to: {train_state_path}")
856
+
857
+ def load_miscs(self, ckpt_path):
858
+ checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
859
+ self.effective_iter = checkpoint["effective_iter"]
860
+ self.epoch = checkpoint["epoch"]
861
+ self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
862
+ self.in_evaluation = checkpoint["in_evaluation"]
863
+ self.global_seed_sequence = checkpoint["global_seed_sequence"]
864
+
865
+ self.best_metric = checkpoint["best_metric"]
866
+
867
+ logging.info(f"Misc state is loaded from {ckpt_path}")
868
+
869
+
870
+ def save_checkpoint(self, ckpt_name, save_train_state):
871
+ ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
872
+ logging.info(f"Saving checkpoint to: {ckpt_dir}")
873
+ # Backup previous checkpoint
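+ # the existing checkpoint dir is renamed first and only removed after the new one is fully written, so an interrupted save keeps the previous checkpoint intact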
874
+ temp_ckpt_dir = None
875
+ if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
876
+ temp_ckpt_dir = os.path.join(
877
+ os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
878
+ )
879
+ if os.path.exists(temp_ckpt_dir):
880
+ shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
881
+ os.rename(ckpt_dir, temp_ckpt_dir)
882
+ logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
883
+
884
+ # Save UNet
885
+ unet_path = os.path.join(ckpt_dir, "unet")
886
+ self.model.unet.save_pretrained(unet_path, safe_serialization=False)
887
+ logging.info(f"UNet is saved to: {unet_path}")
888
+
889
+ if save_train_state:
890
+ state = {
891
+ "config": self.cfg,
892
+ "effective_iter": self.effective_iter,
893
+ "epoch": self.epoch,
894
+ "n_batch_in_epoch": self.n_batch_in_epoch,
895
+ "best_metric": self.best_metric,
896
+ "in_evaluation": self.in_evaluation,
897
+ "global_seed_sequence": self.global_seed_sequence,
898
+ }
899
+ train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
900
+ torch.save(state, train_state_path)
901
+ # iteration indicator
902
+ f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w")
903
+ f.close()
904
+
905
+ logging.info(f"Trainer state is saved to: {train_state_path}")
906
+
907
+ # Remove temp ckpt
908
+ if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
909
+ shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
910
+ logging.debug("Old checkpoint backup is removed.")
911
+
912
+ def load_checkpoint(
913
+ self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
914
+ ):
915
+ logging.info(f"Loading checkpoint from: {ckpt_path}")
916
+ # Load UNet
917
+ _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
918
+ self.model.unet.load_state_dict(
919
+ torch.load(_model_path, map_location=self.device)
920
+ )
921
+ self.model.unet.to(self.device)
922
+ logging.info(f"UNet parameters are loaded from {_model_path}")
923
+
924
+ # Load training states
925
+ if load_trainer_state:
926
+ checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
927
+ self.effective_iter = checkpoint["effective_iter"]
928
+ self.epoch = checkpoint["epoch"]
929
+ self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
930
+ self.in_evaluation = checkpoint["in_evaluation"]
931
+ self.global_seed_sequence = checkpoint["global_seed_sequence"]
932
+
933
+ self.best_metric = checkpoint["best_metric"]
934
+
935
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
936
+ logging.info(f"optimizer state is loaded from {ckpt_path}")
937
+
938
+ if resume_lr_scheduler:
939
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
940
+ logging.info(f"LR scheduler state is loaded from {ckpt_path}")
941
+
942
+ logging.info(
943
+ f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
944
+ )
945
+ return
946
+
947
+ def _get_backup_ckpt_name(self):
948
+ return f"iter_{self.effective_iter:06d}"
src/util/__pycache__/alignment.cpython-310.pyc ADDED
Binary file (1.63 kB). View file
 
src/util/__pycache__/config_util.cpython-310.pyc ADDED
Binary file (1.21 kB). View file
 
src/util/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (3.41 kB). View file
 
src/util/__pycache__/depth_transform.cpython-310.pyc ADDED
Binary file (3.03 kB). View file
 
src/util/__pycache__/logging_util.cpython-310.pyc ADDED
Binary file (3.25 kB). View file
 
src/util/__pycache__/loss.cpython-310.pyc ADDED
Binary file (3.87 kB). View file
 
src/util/__pycache__/lr_scheduler.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
src/util/__pycache__/metric.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
src/util/__pycache__/multi_res_noise.cpython-310.pyc ADDED
Binary file (1.52 kB). View file
 
src/util/__pycache__/seeding.cpython-310.pyc ADDED
Binary file (937 Bytes). View file
 
src/util/__pycache__/slurm_util.cpython-310.pyc ADDED
Binary file (483 Bytes). View file
 
src/util/alignment.py ADDED
@@ -0,0 +1,72 @@
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-01-11
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def align_depth_least_square(
9
+ gt_arr: np.ndarray,
10
+ pred_arr: np.ndarray,
11
+ valid_mask_arr: np.ndarray,
12
+ return_scale_shift=True,
13
+ max_resolution=None,
14
+ ):
15
+ ori_shape = pred_arr.shape # input shape
16
+
17
+ gt = gt_arr.squeeze() # [H, W]
18
+ pred = pred_arr.squeeze()
19
+ valid_mask = valid_mask_arr.squeeze()
20
+
21
+ # Downsample
22
+ if max_resolution is not None:
23
+ scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
24
+ if scale_factor < 1:
25
+ downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
26
+ gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy()
27
+ pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy()
28
+ valid_mask = (
29
+ downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float())
30
+ .bool()
31
+ .numpy()
32
+ )
33
+
34
+ assert (
35
+ gt.shape == pred.shape == valid_mask.shape
36
+ ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"
37
+
38
+ gt_masked = gt[valid_mask].reshape((-1, 1))
39
+ pred_masked = pred[valid_mask].reshape((-1, 1))
40
+
41
+ # numpy solver
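+ # solves gt ≈ scale * pred + shift in the least-squares sense over the valid pixels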
42
+ _ones = np.ones_like(pred_masked)
43
+ A = np.concatenate([pred_masked, _ones], axis=-1)
44
+ X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
45
+ scale, shift = X
46
+
47
+ aligned_pred = pred_arr * scale + shift
48
+
49
+ # restore dimensions
50
+ aligned_pred = aligned_pred.reshape(ori_shape)
51
+
52
+ if return_scale_shift:
53
+ return aligned_pred, scale, shift
54
+ else:
55
+ return aligned_pred
56
+
57
+
58
+ # ******************** disparity space ********************
59
+ def depth2disparity(depth, return_mask=False):
60
+ if isinstance(depth, torch.Tensor):
61
+ disparity = torch.zeros_like(depth)
62
+ elif isinstance(depth, np.ndarray):
63
+ disparity = np.zeros_like(depth)
64
+ non_negative_mask = depth > 0
65
+ disparity[non_negative_mask] = 1.0 / depth[non_negative_mask]
66
+ if return_mask:
67
+ return disparity, non_negative_mask
68
+ else:
69
+ return disparity
70
+
71
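+ # depth and disparity are reciprocals of each other, so the same element-wise inversion converts in both directions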
+ def disparity2depth(disparity, **kwargs):
72
+ return depth2disparity(disparity, **kwargs)
src/util/config_util.py ADDED
@@ -0,0 +1,49 @@
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-02-14
3
+
4
+ import omegaconf
5
+ from omegaconf import OmegaConf
6
+
7
+
8
+ def recursive_load_config(config_path: str) -> OmegaConf:
9
+ conf = OmegaConf.load(config_path)
10
+
11
+ output_conf = OmegaConf.create({})
12
+
13
+ # Load base config. Later configs on the list will overwrite previous
14
+ base_configs = conf.get("base_config", default_value=None)
15
+ if base_configs is not None:
16
+ assert isinstance(base_configs, omegaconf.listconfig.ListConfig)
17
+ for _path in base_configs:
18
+ assert (
19
+ _path != config_path
20
+ ), "Circular merge: base_config should not include itself."
21
+ _base_conf = recursive_load_config(_path)
22
+ output_conf = OmegaConf.merge(output_conf, _base_conf)
23
+
24
+ # Merge configs and overwrite values
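+ # the current file is merged last, so its values override anything inherited via base_config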
25
+ output_conf = OmegaConf.merge(output_conf, conf)
26
+
27
+ return output_conf
28
+
29
+
30
+ def find_value_in_omegaconf(search_key, config):
31
+ result_list = []
32
+
33
+ if isinstance(config, omegaconf.DictConfig):
34
+ for key, value in config.items():
35
+ if key == search_key:
36
+ result_list.append(value)
37
+ elif isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)):
38
+ result_list.extend(find_value_in_omegaconf(search_key, value))
39
+ elif isinstance(config, omegaconf.ListConfig):
40
+ for item in config:
41
+ if isinstance(item, (omegaconf.DictConfig, omegaconf.ListConfig)):
42
+ result_list.extend(find_value_in_omegaconf(search_key, item))
43
+
44
+ return result_list
45
+
46
+
47
+ if "__main__" == __name__:
48
+ conf = recursive_load_config("config/train_base.yaml")
49
+ print(OmegaConf.to_yaml(conf))