guangkaixu committed on
Commit
562fd4c
1 Parent(s): cdba047
README.md CHANGED
@@ -1,13 +1,24 @@
 ---
-title: GenPercept
+title: "GenPercept: Diffusion Models Trained with Large Data Are Transferable Visual Models"
 emoji: ⚡
 colorFrom: indigo
 colorTo: red
 sdk: gradio
 sdk_version: 4.25.0
 app_file: app.py
-pinned: false
-license: mit
+pinned: true
+models:
+- guangkaixu/GenPercept
+license: cc0-1.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+If you find this project useful, please cite our paper:
+
+```
+@article{xu2024diffusion,
+  title={Diffusion Models Trained with Large Data Are Transferable Visual Models},
+  author={Xu, Guangkai and Ge, Yongtao and Liu, Mingyu and Fan, Chengxiang and Xie, Kangyang and Zhao, Zhiyue and Chen, Hao and Shen, Chunhua},
+  journal={arXiv preprint arXiv:2403.06090},
+  year={2024}
+}
+```
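The `models:` field above points at the released checkpoint. For completeness, here is a small, hedged sketch of fetching it with `huggingface_hub`; the assumption that the repo ships the `unet/` and `vae/` subfolders and `empty_text_embed.npy` loaded by `app.py` below is ours, not something stated in the README:

```python
# Sketch only: download the checkpoint referenced in the Space metadata above.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="guangkaixu/GenPercept")  # model repo from `models:`
print("Checkpoint files downloaded to:", local_dir)
```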
app.py ADDED
@@ -0,0 +1,289 @@
+# Copyright 2024 Guangkai Xu, Zhejiang University. All rights reserved.
+#
+# Licensed under the CC0-1.0 license;
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://github.com/aim-uofa/GenPercept/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# This code is based on the Marigold and diffusers codebases:
+# https://github.com/prs-eth/marigold
+# https://github.com/huggingface/diffusers
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find the bibtex at: https://github.com/aim-uofa/GenPercept#%EF%B8%8F-citation
+# More information about the method can be found at https://github.com/aim-uofa/GenPercept
+# --------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import functools
+import os
+import tempfile
+import warnings
+
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+from PIL import Image
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from gradio_imageslider import ImageSlider
+
+from gradio_patches.examples import Examples
+from pipeline_genpercept import GenPerceptPipeline
+
+warnings.filterwarnings(
+    "ignore", message=".*LoginButton created outside of a Blocks context.*"
+)
+
+default_image_processing_res = 768
+default_image_reproducible = True
+
+
+def process_image_check(path_input):
+    if path_input is None:
+        raise gr.Error(
+            "Missing image in the first pane: upload a file or use one from the gallery below."
+        )
+
+
+def process_image(
+    pipe,
+    path_input,
+    processing_res=default_image_processing_res,
+):
+    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+    print(f"Processing image {name_base}{name_ext}")
+
+    path_output_dir = tempfile.mkdtemp()
+    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
+    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
+    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
+
+    input_image = Image.open(path_input)
+
+    pipe_out = pipe(
+        input_image,
+        processing_res=processing_res,
+        batch_size=1 if processing_res == 0 else 0,
+        show_progress_bar=False,
+    )
+
+    # GenPerceptOutput exposes `pred_np` / `pred_colored` (see pipeline_genpercept.py).
+    depth_pred = pipe_out.pred_np
+    depth_colored = pipe_out.pred_colored
+    depth_16bit = (depth_pred * 65535.0).astype(np.uint16)
+
+    np.save(path_out_fp32, depth_pred)
+    Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
+    depth_colored.save(path_out_vis)
+
+    return (
+        [path_out_16bit, path_out_vis],
+        [path_out_16bit, path_out_fp32, path_out_vis],
+    )
+
+
+def run_demo_server(pipe):
+    process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
+    # Video and BAS pipelines are not defined in this file; only the image tab below is wired up.
+    # process_pipe_video = spaces.GPU(functools.partial(process_video, pipe), duration=120)
+    # process_pipe_bas = spaces.GPU(functools.partial(process_bas, pipe))
+
+    gradio_theme = gr.themes.Default()
+
+    with gr.Blocks(
+        theme=gradio_theme,
+        title="GenPercept",
+        css="""
+            #download {
+                height: 118px;
+            }
+            .slider .inner {
+                width: 5px;
+                background: #FFF;
+            }
+            .viewport {
+                aspect-ratio: 4/3;
+            }
+            .tabs button.selected {
+                font-size: 20px !important;
+                color: crimson !important;
+            }
+            h1 {
+                text-align: center;
+                display: block;
+            }
+            h2 {
+                text-align: center;
+                display: block;
+            }
+            h3 {
+                text-align: center;
+                display: block;
+            }
+            .md_feedback li {
+                margin-bottom: 0px !important;
+            }
+        """,
+        head="""
+            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
+            <script>
+                window.dataLayer = window.dataLayer || [];
+                function gtag() {dataLayer.push(arguments);}
+                gtag('js', new Date());
+                gtag('config', 'G-1FWSVCGZTG');
+            </script>
+        """,
+    ) as demo:
+        gr.Markdown(
+            """
+            # GenPercept: Diffusion Models Trained with Large Data Are Transferable Visual Models
+            <p align="center">
+                <a title="arXiv" href="https://arxiv.org/abs/2403.06090" target="_blank" rel="noopener noreferrer"
+                   style="display: inline-block;">
+                    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
+                </a>
+                <a title="Github" href="https://github.com/aim-uofa/GenPercept" target="_blank" rel="noopener noreferrer"
+                   style="display: inline-block;">
+                    <img src="https://img.shields.io/github/stars/aim-uofa/GenPercept?label=GitHub%20%E2%98%85&logo=github&color=C8C"
+                         alt="badge-github-stars">
+                </a>
+            </p>
+            <p align="justify">
+                GenPercept leverages the prior knowledge of stable diffusion models to estimate detailed visual perception results.
+                It achieves remarkable transferable performance on fundamental vision perception tasks using a moderate amount of target data
+                (even synthetic data only). Compared to previous methods, our inference process requires only one step and therefore runs faster.
+            </p>
+            """
+        )
+
+        with gr.Tabs(elem_classes=["tabs"]):
+            with gr.Tab("Depth Estimation"):
+                with gr.Row():
+                    with gr.Column():
+                        image_input = gr.Image(
+                            label="Input Image",
+                            type="filepath",
+                        )
+                        with gr.Row():
+                            image_submit_btn = gr.Button(
+                                value="Estimate Depth", variant="primary"
+                            )
+                            image_reset_btn = gr.Button(value="Reset")
+                        with gr.Accordion("Advanced options", open=False):
+                            image_processing_res = gr.Radio(
+                                [
+                                    ("Native", 0),
+                                    ("Recommended", 768),
+                                ],
+                                label="Processing resolution",
+                                value=default_image_processing_res,
+                            )
+                    with gr.Column():
+                        image_output_slider = ImageSlider(
+                            label="Predicted depth, gray / color (red = near, blue = far)",
+                            type="filepath",
+                            show_download_button=True,
+                            show_share_button=True,
+                            interactive=False,
+                            elem_classes="slider",
+                            position=0.25,
+                        )
+                        image_output_files = gr.Files(
+                            label="Depth outputs",
+                            elem_id="download",
+                            interactive=False,
+                        )
+
+                # Example galleries: anime_1..7, line_1..6, real_1..24.
+                filenames = []
+                filenames.extend(["anime_%d.jpg" % (i + 1) for i in range(7)])
+                filenames.extend(["line_%d.jpg" % (i + 1) for i in range(6)])
+                filenames.extend(["real_%d.jpg" % (i + 1) for i in range(24)])
+                Examples(
+                    fn=process_pipe_image,
+                    examples=[
+                        os.path.join("images", "depth", name) for name in filenames
+                    ],
+                    inputs=[image_input],
+                    outputs=[image_output_slider, image_output_files],
+                    cache_examples=True,
+                    directory_name="examples_image",
+                )
+
+        ### Image tab
+        image_submit_btn.click(
+            fn=process_image_check,
+            inputs=image_input,
+            outputs=None,
+            preprocess=False,
+            queue=False,
+        ).success(
+            fn=process_pipe_image,
+            inputs=[
+                image_input,
+                image_processing_res,
+            ],
+            outputs=[image_output_slider, image_output_files],
+            concurrency_limit=1,
+        )
+
+        image_reset_btn.click(
+            fn=lambda: (
+                None,
+                None,
+                None,
+                default_image_processing_res,
+            ),
+            inputs=[],
+            outputs=[
+                image_input,
+                image_output_slider,
+                image_output_files,
+                image_processing_res,
+            ],
+            queue=False,
+        )
+
+    ### Server launch
+    demo.queue(
+        api_open=False,
+    ).launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+    )
+
+
+def main():
+    os.system("pip freeze")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = torch.float32  # inference precision; the checkpoint is loaded in fp32
+
+    vae = AutoencoderKL.from_pretrained("./", subfolder="vae")
+    unet = UNet2DConditionModel.from_pretrained("./", subfolder="unet")
+    empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None]  # [1, 77, 1024]
+
+    pipe = GenPerceptPipeline(
+        vae=vae,
+        unet=unet,
+        empty_text_embed=empty_text_embed,
+    )
+    try:
+        import xformers  # noqa: F401
+        pipe.enable_xformers_memory_efficient_attention()
+    except Exception:
+        pass  # run without xformers
+
+    pipe = pipe.to(device)
+    run_demo_server(pipe)
+
+
+if __name__ == "__main__":
+    main()
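A note on the files written by `process_image` above: the depth map is saved both as a float32 `.npy` array and as a 16-bit PNG quantized by a factor of 65535. A minimal sketch of recovering the float map from such a PNG (the file name is a placeholder, not an actual output path):

```python
# Invert the 16-bit quantization used in process_image: png = round(depth * 65535).
import numpy as np
from PIL import Image

depth_16bit = np.asarray(Image.open("example_depth_16bit.png"))  # placeholder path
depth = depth_16bit.astype(np.float32) / 65535.0                 # back to the [0, 1] range
```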
images/depth/.DS_Store ADDED
Binary file (6.15 kB)
 
images/depth/anime_1.jpg ADDED
images/depth/anime_2.jpg ADDED
images/depth/anime_3.jpg ADDED
images/depth/anime_4.jpg ADDED
images/depth/anime_5.jpg ADDED
images/depth/anime_6.jpg ADDED
images/depth/anime_7.jpg ADDED
images/depth/line_1.jpg ADDED
images/depth/line_2.jpg ADDED
images/depth/line_3.jpg ADDED
images/depth/line_4.jpg ADDED
images/depth/line_5.jpg ADDED
images/depth/line_6.jpg ADDED
images/depth/real_1.jpg ADDED
images/depth/real_10.jpg ADDED
images/depth/real_11.jpg ADDED
images/depth/real_12.jpg ADDED
images/depth/real_13.jpg ADDED
images/depth/real_14.jpg ADDED
images/depth/real_15.jpg ADDED
images/depth/real_16.jpg ADDED
images/depth/real_17.jpg ADDED
images/depth/real_18.jpg ADDED
images/depth/real_19.jpg ADDED
images/depth/real_2.jpg ADDED
images/depth/real_20.jpg ADDED
images/depth/real_21.jpg ADDED
images/depth/real_22.jpg ADDED
images/depth/real_23.jpg ADDED
images/depth/real_24.jpg ADDED
images/depth/real_3.jpg ADDED
images/depth/real_4.jpg ADDED
images/depth/real_5.jpg ADDED
images/depth/real_6.jpg ADDED
images/depth/real_7.jpg ADDED
images/depth/real_8.jpg ADDED
images/depth/real_9.jpg ADDED
pipeline_genpercept.py ADDED
@@ -0,0 +1,355 @@
+# --------------------------------------------------------
+# Diffusion Models Trained with Large Data Are Transferable Visual Models (https://arxiv.org/abs/2403.06090)
+# Github source: https://github.com/aim-uofa/GenPercept
+# Copyright (c) 2024 Zhejiang University
+# Licensed under The CC0 1.0 License [see LICENSE for details]
+# By Guangkai Xu
+# Based on the Marigold and diffusers codebases:
+# https://github.com/prs-eth/marigold
+# https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+from dataclasses import dataclass
+from typing import Union
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+from tqdm.auto import tqdm
+from PIL import Image
+from torch.utils.data import DataLoader, TensorDataset
+
+from diffusers import (
+    DiffusionPipeline,
+    UNet2DConditionModel,
+    AutoencoderKL,
+)
+from diffusers.utils import BaseOutput
+
+from util.image_util import chw2hwc, colorize_depth_maps, resize_max_res, norm_to_rgb, resize_res
+from util.batchsize import find_batch_size
+
+
+@dataclass
+class GenPerceptOutput(BaseOutput):
+    """Output of `GenPerceptPipeline`: the raw prediction and its colorized visualization."""
+
+    pred_np: np.ndarray
+    pred_colored: Image.Image
+
+
+class GenPerceptPipeline(DiffusionPipeline):
+    """One-step perception pipeline built on a Stable Diffusion UNet and VAE."""
+
+    vae_scale_factor = 0.18215
+    task_infos = {
+        'depth': dict(task_channel_num=1, interpolate='bilinear'),
+        'seg': dict(task_channel_num=3, interpolate='nearest'),
+        'sr': dict(task_channel_num=3, interpolate='nearest'),
+        'normal': dict(task_channel_num=3, interpolate='bilinear'),
+    }
+
+    def __init__(
+        self,
+        unet: UNet2DConditionModel,
+        vae: AutoencoderKL,
+        customized_head=None,
+        empty_text_embed=None,
+    ):
+        super().__init__()
+
+        self.empty_text_embed = empty_text_embed
+
+        # register
+        register_dict = dict(
+            unet=unet,
+            vae=vae,
+            customized_head=customized_head,
+        )
+        self.register_modules(**register_dict)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        input_image: Union[Image.Image, torch.Tensor],
+        mode: str = 'depth',
+        resize_hard: bool = False,
+        processing_res: int = 768,
+        match_input_res: bool = True,
+        batch_size: int = 0,
+        color_map: str = "Spectral",
+        show_progress_bar: bool = True,
+    ) -> GenPerceptOutput:
+        """
+        Function invoked when calling the pipeline.
+
+        Args:
+            input_image (Image):
+                Input RGB (or gray-scale) image.
+            mode (str, optional):
+                Perception task, one of 'depth', 'seg', 'sr', 'normal'.
+                Defaults to 'depth'.
+            processing_res (int, optional):
+                Maximum resolution of processing.
+                If set to 0: will not resize at all.
+                Defaults to 768.
+            match_input_res (bool, optional):
+                Resize the prediction to match the input resolution.
+                Only valid if `processing_res` > 0.
+                Defaults to True.
+            batch_size (int, optional):
+                Inference batch size.
+                If set to 0, the batch size is chosen automatically.
+                Defaults to 0.
+            show_progress_bar (bool, optional):
+                Display a progress bar over inference batches.
+                Defaults to True.
+            color_map (str, optional):
+                Colormap used to colorize the depth map.
+                Defaults to "Spectral".
+        Returns:
+            `GenPerceptOutput`
+        """
+
+        device = self.device
+
+        task_channel_num = self.task_infos[mode]['task_channel_num']
+
+        if not match_input_res:
+            assert (
+                processing_res is not None
+            ), "Value error: `match_input_res=False` requires a valid `processing_res`."
+        assert processing_res >= 0
+
+        # ----------------- Image Preprocess -----------------
+
+        if type(input_image) == torch.Tensor:  # [B, 3, H, W]
+            rgb_norm = input_image.to(device)
+            input_size = input_image.shape[2:]
+            bs_imgs = rgb_norm.shape[0]
+            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
+            rgb_norm = rgb_norm.to(self.dtype)
+        else:
+            # if len(rgb_paths) > 0 and 'kitti' in rgb_paths[0]:
+            #     # kb crop
+            #     height = input_image.size[1]
+            #     width = input_image.size[0]
+            #     top_margin = int(height - 352)
+            #     left_margin = int((width - 1216) / 2)
+            #     input_image = input_image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+
+            # TODO: check the kitti evaluation resolution here.
+            input_size = (input_image.size[1], input_image.size[0])
+            # Resize image
+            if processing_res > 0:
+                if resize_hard:
+                    input_image = resize_res(
+                        input_image, max_edge_resolution=processing_res
+                    )
+                else:
+                    input_image = resize_max_res(
+                        input_image, max_edge_resolution=processing_res
+                    )
+            input_image = input_image.convert("RGB")
+            image = np.asarray(input_image)
+
+            # Normalize rgb values
+            rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
+            rgb_norm = rgb / 255.0 * 2.0 - 1.0
+            rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
+            rgb_norm = rgb_norm[None].to(device)
+            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
+            bs_imgs = 1
+
+        # ----------------- Predicting depth -----------------
+
+        single_rgb_dataset = TensorDataset(rgb_norm)
+        if batch_size > 0:
+            _bs = batch_size
+        else:
+            _bs = find_batch_size(
+                ensemble_size=1,
+                input_res=max(rgb_norm.shape[1:]),
+                dtype=self.dtype,
+            )
+
+        single_rgb_loader = DataLoader(
+            single_rgb_dataset, batch_size=_bs, shuffle=False
+        )
+
+        # Predict depth maps (batched)
+        pred_list = []
+        if show_progress_bar:
+            iterable = tqdm(
+                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
+            )
+        else:
+            iterable = single_rgb_loader
+
+        for batch in iterable:
+            (batched_img,) = batch
+            pred = self.single_infer(
+                rgb_in=batched_img,
+                mode=mode,
+            )
+            pred_list.append(pred.detach().clone())
+        preds = torch.concat(pred_list, axis=0).squeeze()
+        preds = preds.view(bs_imgs, task_channel_num, preds.shape[-2], preds.shape[-1])  # [bs_imgs, task_channel_num, H, W]
+
+        if match_input_res:
+            preds = F.interpolate(preds, input_size, mode=self.task_infos[mode]['interpolate'])
+
+        # ----------------- Post processing -----------------
+        if mode == 'depth':
+            if len(preds.shape) == 4:
+                preds = preds[:, 0]  # [bs_imgs, H, W]
+            # Scale prediction to [0, 1]
+            min_d = preds.view(bs_imgs, -1).min(dim=1)[0]
+            max_d = preds.view(bs_imgs, -1).max(dim=1)[0]
+            preds = (preds - min_d[:, None, None]) / (max_d[:, None, None] - min_d[:, None, None])
+            preds = preds.cpu().numpy().astype(np.float32)
+            # Colorize
+            pred_colored_img_list = []
+            for i in range(bs_imgs):
+                pred_colored_chw = colorize_depth_maps(
+                    preds[i], 0, 1, cmap=color_map
+                ).squeeze()  # [3, H, W], value in (0, 1)
+                pred_colored_chw = (pred_colored_chw * 255).astype(np.uint8)
+                pred_colored_hwc = chw2hwc(pred_colored_chw)
+                pred_colored_img = Image.fromarray(pred_colored_hwc)
+                pred_colored_img_list.append(pred_colored_img)
+
+            return GenPerceptOutput(
+                pred_np=np.squeeze(preds),
+                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
+            )
+
+        elif mode == 'seg' or mode == 'sr':
+            if not self.customized_head:
+                # shift to [0, 1]
+                preds = (preds + 1.0) / 2.0
+                # shift to [0, 255]
+                preds = preds * 255
+                # Clip output range
+                preds = preds.clip(0, 255).cpu().numpy().astype(np.uint8)
+            else:
+                raise NotImplementedError
+
+            pred_colored_img_list = []
+            for i in range(preds.shape[0]):
+                pred_colored_hwc = chw2hwc(preds[i])
+                pred_colored_img = Image.fromarray(pred_colored_hwc)
+                pred_colored_img_list.append(pred_colored_img)
+
+            return GenPerceptOutput(
+                pred_np=np.squeeze(preds),
+                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
+            )
+
+        elif mode == 'normal':
+            if not self.customized_head:
+                preds = preds.clip(-1, 1).cpu().numpy()  # [-1, 1]
+            else:
+                raise NotImplementedError
+
+            pred_colored_img_list = []
+            for i in range(preds.shape[0]):
+                pred_colored_chw = norm_to_rgb(preds[i])
+                pred_colored_hwc = chw2hwc(pred_colored_chw)
+                normal_colored_img_i = Image.fromarray(pred_colored_hwc)
+                pred_colored_img_list.append(normal_colored_img_i)
+
+            return GenPerceptOutput(
+                pred_np=np.squeeze(preds),
+                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
+            )
+
+        else:
+            raise NotImplementedError
+
+    @torch.no_grad()
+    def single_infer(
+        self,
+        rgb_in: torch.Tensor,
+        mode: str = 'depth',
+    ) -> torch.Tensor:
+        """
+        Perform an individual prediction without ensembling.
+
+        Args:
+            rgb_in (torch.Tensor):
+                Input RGB image.
+            mode (str):
+                Perception task, one of 'depth', 'seg', 'sr', 'normal'.
+
+        Returns:
+            torch.Tensor: Predicted map.
+        """
+        device = rgb_in.device
+        bs_imgs = rgb_in.shape[0]
+        timesteps = torch.tensor([1]).long().repeat(bs_imgs).to(device)
+
+        # Encode image
+        rgb_latent = self.encode_rgb(rgb_in)
+
+        batch_embed = self.empty_text_embed
+        batch_embed = batch_embed.repeat((rgb_latent.shape[0], 1, 1)).to(device)  # [bs_imgs, 77, 1024]
+
+        # Forward!
+        if self.customized_head:
+            unet_features = self.unet(rgb_latent, timesteps, encoder_hidden_states=batch_embed, return_feature_only=True)[0][::-1]
+            pred = self.customized_head(unet_features)
+        else:
+            unet_output = self.unet(
+                rgb_latent, timesteps, encoder_hidden_states=batch_embed
+            )  # [bs_imgs, 4, h, w]
+            unet_pred = unet_output.sample
+            pred_latent = -unet_pred
+            pred_latent = pred_latent.to(device)
+            pred = self.decode_pred(pred_latent)
+            if mode == 'depth':
+                # mean of output channels
+                pred = pred.mean(dim=1, keepdim=True)
+            # clip prediction
+            pred = torch.clip(pred, -1.0, 1.0)
+        return pred
+
+    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
+        """
+        Encode an RGB image into a latent.
+
+        Args:
+            rgb_in (torch.Tensor):
+                Input RGB image to be encoded.
+
+        Returns:
+            torch.Tensor: Image latent.
+        """
+        try:
+            # encode
+            h_temp = self.vae.encoder(rgb_in)
+            moments = self.vae.quant_conv(h_temp)
+        except Exception:
+            # retry in full precision
+            h_temp = self.vae.encoder(rgb_in.float())
+            moments = self.vae.quant_conv(h_temp.float())
+
+        mean, logvar = torch.chunk(moments, 2, dim=1)
+        # scale latent
+        rgb_latent = mean * self.vae_scale_factor
+        return rgb_latent
+
+    def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
+        """
+        Decode a prediction latent into the prediction label.
+
+        Args:
+            pred_latent (torch.Tensor):
+                Prediction latent to be decoded.
+
+        Returns:
+            torch.Tensor: Decoded prediction label.
+        """
+        # scale latent
+        pred_latent = pred_latent / self.vae_scale_factor
+        # decode
+        z = self.vae.post_quant_conv(pred_latent)
+        pred = self.vae.decoder(z)
+
+        return pred
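For reference, a minimal, hedged usage sketch of `GenPerceptPipeline` outside the Gradio demo, mirroring how `app.py` above constructs it. It assumes the same local checkpoint layout (`unet/` and `vae/` subfolders plus `empty_text_embed.npy` in the working directory) and uses one of the example images added in this commit:

```python
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel

from pipeline_genpercept import GenPerceptPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the one-step UNet and VAE from the local checkpoint directory (assumed layout).
vae = AutoencoderKL.from_pretrained("./", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("./", subfolder="unet")
empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).float()[None]  # [1, 77, 1024]

pipe = GenPerceptPipeline(unet=unet, vae=vae, empty_text_embed=empty_text_embed).to(device)

out = pipe(Image.open("images/depth/real_1.jpg"), mode="depth", processing_res=768)
out.pred_colored.save("real_1_depth_colored.png")  # colorized depth visualization
np.save("real_1_depth.npy", out.pred_np)           # raw depth in [0, 1]
```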