Stable-X commited on
Commit
508279d
·
1 Parent(s): c0e046b

Update code

Browse files
README.md CHANGED
@@ -1,13 +1,15 @@
1
  ---
2
- title: StableDiffuse
3
- emoji: 👁
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
- pinned: false
10
- license: apache-2.0
 
 
 
 
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: StableDiffuse: Removing Reflections from Textured Surfaces in a Single Image
3
+ emoji: 🏵️
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.32.2
8
  app_file: app.py
9
+ pinned: true
10
+ license: cc-by-sa-4.0
11
+ models:
12
+ - Stable-X/yoso-diffuse-v0-2
13
+ hf_oauth: true
14
+ hf_oauth_expiration_minutes: 43200
15
  ---
 
 
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+ from __future__ import annotations
20
+
21
+ import functools
22
+ import os
23
+ import tempfile
24
+
25
+ import diffusers
26
+ import gradio as gr
27
+ import imageio as imageio
28
+ import numpy as np
29
+ import spaces
30
+ import torch as torch
31
+ torch.backends.cuda.matmul.allow_tf32 = True
32
+ from PIL import Image
33
+ from gradio_imageslider import ImageSlider
34
+ from tqdm import tqdm
35
+
36
+ from pathlib import Path
37
+ import gradio
38
+ from gradio.utils import get_cache_folder
39
+ from stablediffuse.pipeline_yoso_diffuse import YOSODiffusePipeline
40
+
41
+ class Examples(gradio.helpers.Examples):
42
+ def __init__(self, *args, directory_name=None, **kwargs):
43
+ super().__init__(*args, **kwargs, _initiated_directly=False)
44
+ if directory_name is not None:
45
+ self.cached_folder = get_cache_folder() / directory_name
46
+ self.cached_file = Path(self.cached_folder) / "log.csv"
47
+ self.create()
48
+
49
+
50
+ default_seed = 2024
51
+ default_batch_size = 1
52
+
53
+ default_image_processing_resolution = 768
54
+
55
+ default_video_num_inference_steps = 10
56
+ default_video_processing_resolution = 768
57
+ default_video_out_max_frames = 60
58
+
59
+ def process_image_check(path_input):
60
+ if path_input is None:
61
+ raise gr.Error(
62
+ "Missing image in the first pane: upload a file or use one from the gallery below."
63
+ )
64
+
65
+ def resize_image(input_image, resolution):
66
+ # Ensure input_image is a PIL Image object
67
+ if not isinstance(input_image, Image.Image):
68
+ raise ValueError("input_image should be a PIL Image object")
69
+
70
+ # Convert image to numpy array
71
+ input_image_np = np.asarray(input_image)
72
+
73
+ # Get image dimensions
74
+ H, W, C = input_image_np.shape
75
+ H = float(H)
76
+ W = float(W)
77
+
78
+ # Calculate the scaling factor
79
+ k = float(resolution) / min(H, W)
80
+
81
+ # Determine new dimensions
82
+ H *= k
83
+ W *= k
84
+ H = int(np.round(H / 64.0)) * 64
85
+ W = int(np.round(W / 64.0)) * 64
86
+
87
+ # Resize the image using PIL's resize method
88
+ img = input_image.resize((W, H), Image.Resampling.LANCZOS)
89
+
90
+ return img
91
+
92
+ def process_image(
93
+ pipe,
94
+ path_input,
95
+ ):
96
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
97
+ print(f"Processing image {name_base}{name_ext}")
98
+
99
+ path_output_dir = tempfile.mkdtemp()
100
+ path_out_png = os.path.join(path_output_dir, f"{name_base}_diffuse.png")
101
+ input_image = Image.open(path_input)
102
+ input_image = resize_image(input_image, default_image_processing_resolution)
103
+
104
+ pipe_out = pipe(
105
+ input_image,
106
+ match_input_resolution=False,
107
+ processing_resolution=max(input_image.size)
108
+ )
109
+
110
+ processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
111
+ processed_frame = (processed_frame[0] * 255).astype(np.uint8)
112
+ processed_frame = Image.fromarray(processed_frame)
113
+ processed_frame.save(path_out_png)
114
+ yield [input_image, path_out_png]
115
+
116
+ def center_crop(img):
117
+ # Open the image file
118
+ img_width, img_height = img.size
119
+ crop_width =min(img_width, img_height)
120
+ # Calculate the cropping box
121
+ left = (img_width - crop_width) / 2
122
+ top = (img_height - crop_width) / 2
123
+ right = (img_width + crop_width) / 2
124
+ bottom = (img_height + crop_width) / 2
125
+
126
+ # Crop the image
127
+ img_cropped = img.crop((left, top, right, bottom))
128
+ return img_cropped
129
+
130
+ def process_video(
131
+ pipe,
132
+ path_input,
133
+ out_max_frames=default_video_out_max_frames,
134
+ target_fps=10,
135
+ progress=gr.Progress(),
136
+ ):
137
+ if path_input is None:
138
+ raise gr.Error(
139
+ "Missing video in the first pane: upload a file or use one from the gallery below."
140
+ )
141
+
142
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
143
+ print(f"Processing video {name_base}{name_ext}")
144
+
145
+ path_output_dir = tempfile.mkdtemp()
146
+ path_out_vis = os.path.join(path_output_dir, f"{name_base}_diffuse_colored.mp4")
147
+
148
+ init_latents = None
149
+ reader, writer = None, None
150
+ try:
151
+ reader = imageio.get_reader(path_input)
152
+
153
+ meta_data = reader.get_meta_data()
154
+ fps = meta_data["fps"]
155
+ size = meta_data["size"]
156
+ duration_sec = meta_data["duration"]
157
+
158
+ writer = imageio.get_writer(path_out_vis, fps=target_fps)
159
+
160
+ out_frame_id = 0
161
+ pbar = tqdm(desc="Processing Video", total=duration_sec)
162
+
163
+ for frame_id, frame in enumerate(reader):
164
+ if frame_id % (fps // target_fps) != 0:
165
+ continue
166
+ else:
167
+ out_frame_id += 1
168
+ pbar.update(1)
169
+ if out_frame_id > out_max_frames:
170
+ break
171
+
172
+ frame_pil = Image.fromarray(frame)
173
+ # frame_pil = center_crop(frame_pil)
174
+ pipe_out = pipe(
175
+ frame_pil,
176
+ match_input_resolution=False,
177
+ latents=init_latents
178
+ )
179
+
180
+ if init_latents is None:
181
+ init_latents = pipe_out.gaus_noise
182
+ processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
183
+ processed_frame = processed_frame[0]
184
+ _processed_frame = imageio.core.util.Array(processed_frame)
185
+ writer.append_data(_processed_frame)
186
+
187
+ yield (
188
+ [frame_pil, processed_frame],
189
+ None,
190
+ )
191
+ finally:
192
+
193
+ if writer is not None:
194
+ writer.close()
195
+
196
+ if reader is not None:
197
+ reader.close()
198
+
199
+ yield (
200
+ [frame_pil, processed_frame],
201
+ [path_out_vis,]
202
+ )
203
+
204
+
205
+ def run_demo_server(pipe):
206
+ process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
207
+ process_pipe_video = spaces.GPU(
208
+ functools.partial(process_video, pipe), duration=120
209
+ )
210
+
211
+ gradio_theme = gr.themes.Default()
212
+
213
+ with gr.Blocks(
214
+ theme=gradio_theme,
215
+ title="Stable Diffuse Estimation",
216
+ css="""
217
+ #download {
218
+ height: 118px;
219
+ }
220
+ .slider .inner {
221
+ width: 5px;
222
+ background: #FFF;
223
+ }
224
+ .viewport {
225
+ aspect-ratio: 4/3;
226
+ }
227
+ .tabs button.selected {
228
+ font-size: 20px !important;
229
+ color: crimson !important;
230
+ }
231
+ h1 {
232
+ text-align: center;
233
+ display: block;
234
+ }
235
+ h2 {
236
+ text-align: center;
237
+ display: block;
238
+ }
239
+ h3 {
240
+ text-align: center;
241
+ display: block;
242
+ }
243
+ .md_feedback li {
244
+ margin-bottom: 0px !important;
245
+ }
246
+ """,
247
+ head="""
248
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
249
+ <script>
250
+ window.dataLayer = window.dataLayer || [];
251
+ function gtag() {dataLayer.push(arguments);}
252
+ gtag('js', new Date());
253
+ gtag('config', 'G-1FWSVCGZTG');
254
+ </script>
255
+ """,
256
+ ) as demo:
257
+ gr.Markdown(
258
+ """
259
+ # StableDiffuse: Removing Reflections from Textured Surfaces in a Single Image
260
+ <p align="center">
261
+ """
262
+ )
263
+
264
+ with gr.Tabs(elem_classes=["tabs"]):
265
+ with gr.Tab("Image"):
266
+ with gr.Row():
267
+ with gr.Column():
268
+ image_input = gr.Image(
269
+ label="Input Image",
270
+ type="filepath",
271
+ )
272
+ with gr.Row():
273
+ image_submit_btn = gr.Button(
274
+ value="Compute Diffuse", variant="primary"
275
+ )
276
+ image_reset_btn = gr.Button(value="Reset")
277
+ with gr.Column():
278
+ image_output_slider = ImageSlider(
279
+ label="Diffuse outputs",
280
+ type="filepath",
281
+ show_download_button=True,
282
+ show_share_button=True,
283
+ interactive=False,
284
+ elem_classes="slider",
285
+ position=0.25,
286
+ )
287
+
288
+ Examples(
289
+ fn=process_pipe_image,
290
+ examples=sorted([
291
+ os.path.join("files", "image", name)
292
+ for name in os.listdir(os.path.join("files", "image"))
293
+ ]),
294
+ inputs=[image_input],
295
+ outputs=[image_output_slider],
296
+ cache_examples=False,
297
+ directory_name="examples_image",
298
+ )
299
+
300
+ with gr.Tab("Video"):
301
+ with gr.Row():
302
+ with gr.Column():
303
+ video_input = gr.Video(
304
+ label="Input Video",
305
+ sources=["upload", "webcam"],
306
+ )
307
+ with gr.Row():
308
+ video_submit_btn = gr.Button(
309
+ value="Compute Diffuse", variant="primary"
310
+ )
311
+ video_reset_btn = gr.Button(value="Reset")
312
+ with gr.Column():
313
+ processed_frames = ImageSlider(
314
+ label="Realtime Visualization",
315
+ type="filepath",
316
+ show_download_button=True,
317
+ show_share_button=True,
318
+ interactive=False,
319
+ elem_classes="slider",
320
+ position=0.25,
321
+ )
322
+ video_output_files = gr.Files(
323
+ label="Diffuse outputs",
324
+ elem_id="download",
325
+ interactive=False,
326
+ )
327
+ Examples(
328
+ fn=process_pipe_video,
329
+ examples=sorted([
330
+ os.path.join("files", "video", name)
331
+ for name in os.listdir(os.path.join("files", "video"))
332
+ ]),
333
+ inputs=[video_input],
334
+ outputs=[processed_frames, video_output_files],
335
+ directory_name="examples_video",
336
+ cache_examples=False,
337
+ )
338
+
339
+ ### Image tab
340
+ image_submit_btn.click(
341
+ fn=process_image_check,
342
+ inputs=image_input,
343
+ outputs=None,
344
+ preprocess=False,
345
+ queue=False,
346
+ ).success(
347
+ fn=process_pipe_image,
348
+ inputs=[
349
+ image_input,
350
+ ],
351
+ outputs=[image_output_slider],
352
+ concurrency_limit=1,
353
+ )
354
+
355
+ image_reset_btn.click(
356
+ fn=lambda: (
357
+ None,
358
+ None,
359
+ None,
360
+ ),
361
+ inputs=[],
362
+ outputs=[
363
+ image_input,
364
+ image_output_slider,
365
+ ],
366
+ queue=False,
367
+ )
368
+
369
+ ### Video tab
370
+
371
+ video_submit_btn.click(
372
+ fn=process_pipe_video,
373
+ inputs=[video_input],
374
+ outputs=[processed_frames, video_output_files],
375
+ concurrency_limit=1,
376
+ )
377
+
378
+ video_reset_btn.click(
379
+ fn=lambda: (None, None, None),
380
+ inputs=[],
381
+ outputs=[video_input, processed_frames, video_output_files],
382
+ concurrency_limit=1,
383
+ )
384
+
385
+ ### Server launch
386
+
387
+ demo.queue(
388
+ api_open=False,
389
+ ).launch(
390
+ server_name="0.0.0.0",
391
+ server_port=7860,
392
+ )
393
+
394
+
395
+ def main():
396
+ os.system("pip freeze")
397
+
398
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
399
+
400
+ pipe = YOSODiffusePipeline.from_pretrained(
401
+ 'weights/yoso-diffuse-v0-2', trust_remote_code=True, variant="fp16",
402
+ torch_dtype=torch.float16, t_start=0).to(device)
403
+ try:
404
+ import xformers
405
+ pipe.enable_xformers_memory_efficient_attention()
406
+ except:
407
+ pass # run without xformers
408
+
409
+ run_demo_server(pipe)
410
+
411
+
412
+ if __name__ == "__main__":
413
+ main()
requirements.txt ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.30.1
2
+ aiofiles==23.2.1
3
+ aiohttp==3.9.5
4
+ aiosignal==1.3.1
5
+ altair==5.3.0
6
+ annotated-types==0.7.0
7
+ anyio==4.4.0
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ Authlib==1.3.0
11
+ certifi==2024.2.2
12
+ cffi==1.16.0
13
+ charset-normalizer==3.3.2
14
+ click==8.0.4
15
+ contourpy==1.2.1
16
+ cryptography==42.0.7
17
+ cycler==0.12.1
18
+ dataclasses-json==0.6.6
19
+ datasets==2.19.1
20
+ Deprecated==1.2.14
21
+ diffusers==0.28.0
22
+ dill==0.3.8
23
+ dnspython==2.6.1
24
+ email_validator==2.1.1
25
+ exceptiongroup==1.2.1
26
+ fastapi==0.111.0
27
+ fastapi-cli==0.0.4
28
+ ffmpy==0.3.2
29
+ filelock==3.14.0
30
+ fonttools==4.53.0
31
+ frozenlist==1.4.1
32
+ fsspec==2024.3.1
33
+ gradio==4.32.2
34
+ gradio_client==0.17.0
35
+ gradio_imageslider==0.0.20
36
+ h11==0.14.0
37
+ httpcore==1.0.5
38
+ httptools==0.6.1
39
+ httpx==0.27.0
40
+ huggingface-hub==0.23.0
41
+ idna==3.7
42
+ imageio==2.34.1
43
+ imageio-ffmpeg==0.5.0
44
+ importlib_metadata==7.1.0
45
+ importlib_resources==6.4.0
46
+ itsdangerous==2.2.0
47
+ Jinja2==3.1.4
48
+ jsonschema==4.22.0
49
+ jsonschema-specifications==2023.12.1
50
+ kiwisolver==1.4.5
51
+ markdown-it-py==3.0.0
52
+ MarkupSafe==2.1.5
53
+ marshmallow==3.21.2
54
+ matplotlib==3.8.2
55
+ mdurl==0.1.2
56
+ mpmath==1.3.0
57
+ multidict==6.0.5
58
+ multiprocess==0.70.16
59
+ mypy-extensions==1.0.0
60
+ networkx==3.3
61
+ numpy==1.26.4
62
+ nvidia-cublas-cu12==12.1.3.1
63
+ nvidia-cuda-cupti-cu12==12.1.105
64
+ nvidia-cuda-nvrtc-cu12==12.1.105
65
+ nvidia-cuda-runtime-cu12==12.1.105
66
+ nvidia-cudnn-cu12==8.9.2.26
67
+ nvidia-cufft-cu12==11.0.2.54
68
+ nvidia-curand-cu12==10.3.2.106
69
+ nvidia-cusolver-cu12==11.4.5.107
70
+ nvidia-cusparse-cu12==12.1.0.106
71
+ nvidia-nccl-cu12==2.19.3
72
+ nvidia-nvjitlink-cu12==12.5.40
73
+ nvidia-nvtx-cu12==12.1.105
74
+ orjson==3.10.3
75
+ packaging==24.0
76
+ pandas==2.2.2
77
+ pillow==10.3.0
78
+ protobuf==3.20.3
79
+ psutil==5.9.8
80
+ pyarrow==16.0.0
81
+ pyarrow-hotfix==0.6
82
+ pycparser==2.22
83
+ pydantic==2.7.2
84
+ pydantic_core==2.18.3
85
+ pydub==0.25.1
86
+ pygltflib==1.16.1
87
+ Pygments==2.18.0
88
+ pyparsing==3.1.2
89
+ python-dateutil==2.9.0.post0
90
+ python-dotenv==1.0.1
91
+ python-multipart==0.0.9
92
+ pytz==2024.1
93
+ PyYAML==6.0.1
94
+ referencing==0.35.1
95
+ regex==2024.5.15
96
+ requests==2.31.0
97
+ rich==13.7.1
98
+ rpds-py==0.18.1
99
+ ruff==0.4.7
100
+ safetensors==0.4.3
101
+ scipy==1.11.4
102
+ semantic-version==2.10.0
103
+ shellingham==1.5.4
104
+ six==1.16.0
105
+ sniffio==1.3.1
106
+ spaces==0.28.3
107
+ starlette==0.37.2
108
+ sympy==1.12.1
109
+ tokenizers==0.15.2
110
+ tomlkit==0.12.0
111
+ toolz==0.12.1
112
+ torch==2.2.0
113
+ tqdm==4.66.4
114
+ transformers==4.36.1
115
+ trimesh==4.0.5
116
+ triton==2.2.0
117
+ typer==0.12.3
118
+ typing-inspect==0.9.0
119
+ typing_extensions==4.11.0
120
+ tzdata==2024.1
121
+ ujson==5.10.0
122
+ urllib3==2.2.1
123
+ uvicorn==0.30.0
124
+ uvloop==0.19.0
125
+ watchfiles==0.22.0
126
+ websockets==11.0.3
127
+ wrapt==1.16.0
128
+ xformers==0.0.24
129
+ xxhash==3.4.1
130
+ yarl==1.9.4
131
+ zipp==3.19.1
132
+ einops==0.7.0
requirements_min.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.32.1
2
+ gradio-imageslider>=0.0.20
3
+ pygltflib==1.16.1
4
+ trimesh==4.0.5
5
+ imageio
6
+ imageio-ffmpeg
7
+ Pillow
8
+ einops==0.7.0
9
+
10
+ spaces
11
+ accelerate
12
+ diffusers>=0.28.0
13
+ matplotlib==3.8.2
14
+ scipy==1.11.4
15
+ torch==2.0.1
16
+ transformers==4.36.1
17
+ xformers==0.0.21
stablediffuse/__init__.py ADDED
File without changes
stablediffuse/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (162 Bytes). View file
 
stablediffuse/__pycache__/pipeline_yoso_diffuse.cpython-39.pyc ADDED
Binary file (24.3 kB). View file
 
stablediffuse/pipeline_yoso_diffuse.py ADDED
@@ -0,0 +1,724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the
17
+ # --------------------------------------------------------------------------
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
+
27
+
28
+ from diffusers.image_processor import PipelineImageInput
29
+ from diffusers.models import (
30
+ AutoencoderKL,
31
+ UNet2DConditionModel,
32
+ ControlNetModel,
33
+ )
34
+ from diffusers.schedulers import (
35
+ DDIMScheduler
36
+ )
37
+
38
+ from diffusers.utils import (
39
+ BaseOutput,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+
44
+
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
49
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
50
+
51
+ import pdb
52
+
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ EXAMPLE_DOC_STRING = """
59
+ Examples:
60
+ ```py
61
+ >>> import diffusers
62
+ >>> import torch
63
+
64
+ >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
65
+ ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
66
+ ... ).to("cuda")
67
+
68
+ >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
69
+ >>> normals = pipe(image)
70
+
71
+ >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
72
+ >>> vis[0].save("einstein_normals.png")
73
+ ```
74
+ """
75
+
76
+
77
+ @dataclass
78
+ class YOSODiffuseOutput(BaseOutput):
79
+ """
80
+ Output class for Marigold monocular normals prediction pipeline.
81
+
82
+ Args:
83
+ prediction (`np.ndarray`, `torch.Tensor`):
84
+ Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
85
+ \times width$, regardless of whether the images were passed as a 4D array or a list.
86
+ uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
87
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
88
+ \times 1 \times height \times width$.
89
+ latent (`None`, `torch.Tensor`):
90
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
91
+ The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
92
+ """
93
+
94
+ prediction: Union[np.ndarray, torch.Tensor]
95
+ latent: Union[None, torch.Tensor]
96
+ gaus_noise: Union[None, torch.Tensor]
97
+
98
+
99
+ class YOSODiffusePipeline(StableDiffusionControlNetPipeline):
100
+ """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
101
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
102
+
103
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
104
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
105
+
106
+ The pipeline also inherits the following loading methods:
107
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
108
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
109
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
110
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
111
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
112
+
113
+ Args:
114
+ vae ([`AutoencoderKL`]):
115
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
116
+ text_encoder ([`~transformers.CLIPTextModel`]):
117
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
118
+ tokenizer ([`~transformers.CLIPTokenizer`]):
119
+ A `CLIPTokenizer` to tokenize text.
120
+ unet ([`UNet2DConditionModel`]):
121
+ A `UNet2DConditionModel` to denoise the encoded image latents.
122
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
123
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
124
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
125
+ additional conditioning.
126
+ scheduler ([`SchedulerMixin`]):
127
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
128
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
129
+ safety_checker ([`StableDiffusionSafetyChecker`]):
130
+ Classification module that estimates whether generated images could be considered offensive or harmful.
131
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
132
+ about a model's potential harms.
133
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
134
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
135
+ """
136
+
137
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
138
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
139
+ _exclude_from_cpu_offload = ["safety_checker"]
140
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
141
+
142
+
143
+
144
+ def __init__(
145
+ self,
146
+ vae: AutoencoderKL,
147
+ text_encoder: CLIPTextModel,
148
+ tokenizer: CLIPTokenizer,
149
+ unet: UNet2DConditionModel,
150
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
151
+ scheduler: Union[DDIMScheduler],
152
+ safety_checker: StableDiffusionSafetyChecker,
153
+ feature_extractor: CLIPImageProcessor,
154
+ image_encoder: CLIPVisionModelWithProjection = None,
155
+ requires_safety_checker: bool = True,
156
+ default_denoising_steps: Optional[int] = 1,
157
+ default_processing_resolution: Optional[int] = 768,
158
+ prompt="",
159
+ empty_text_embedding=None,
160
+ t_start: Optional[int] = 401,
161
+ ):
162
+ super().__init__(
163
+ vae,
164
+ text_encoder,
165
+ tokenizer,
166
+ unet,
167
+ controlnet,
168
+ scheduler,
169
+ safety_checker,
170
+ feature_extractor,
171
+ image_encoder,
172
+ requires_safety_checker,
173
+ )
174
+
175
+ # TODO yoso ImageProcessor
176
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
177
+ self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
178
+ self.default_denoising_steps = default_denoising_steps
179
+ self.default_processing_resolution = default_processing_resolution
180
+ self.prompt = prompt
181
+ self.prompt_embeds = None
182
+ self.empty_text_embedding = empty_text_embedding
183
+ self.t_start= t_start # target_out latents
184
+
185
+ def check_inputs(
186
+ self,
187
+ image: PipelineImageInput,
188
+ num_inference_steps: int,
189
+ ensemble_size: int,
190
+ processing_resolution: int,
191
+ resample_method_input: str,
192
+ resample_method_output: str,
193
+ batch_size: int,
194
+ ensembling_kwargs: Optional[Dict[str, Any]],
195
+ latents: Optional[torch.Tensor],
196
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
197
+ output_type: str,
198
+ output_uncertainty: bool,
199
+ ) -> int:
200
+ if num_inference_steps is None:
201
+ raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
202
+ if num_inference_steps < 1:
203
+ raise ValueError("`num_inference_steps` must be positive.")
204
+ if ensemble_size < 1:
205
+ raise ValueError("`ensemble_size` must be positive.")
206
+ if ensemble_size == 2:
207
+ logger.warning(
208
+ "`ensemble_size` == 2 results are similar to no ensembling (1); "
209
+ "consider increasing the value to at least 3."
210
+ )
211
+ if ensemble_size == 1 and output_uncertainty:
212
+ raise ValueError(
213
+ "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
214
+ "greater than 1."
215
+ )
216
+ if processing_resolution is None:
217
+ raise ValueError(
218
+ "`processing_resolution` is not specified and could not be resolved from the model config."
219
+ )
220
+ if processing_resolution < 0:
221
+ raise ValueError(
222
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
223
+ "downsampled processing."
224
+ )
225
+ if processing_resolution % self.vae_scale_factor != 0:
226
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
227
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
228
+ raise ValueError(
229
+ "`resample_method_input` takes string values compatible with PIL library: "
230
+ "nearest, nearest-exact, bilinear, bicubic, area."
231
+ )
232
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
233
+ raise ValueError(
234
+ "`resample_method_output` takes string values compatible with PIL library: "
235
+ "nearest, nearest-exact, bilinear, bicubic, area."
236
+ )
237
+ if batch_size < 1:
238
+ raise ValueError("`batch_size` must be positive.")
239
+ if output_type not in ["pt", "np"]:
240
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
241
+ if latents is not None and generator is not None:
242
+ raise ValueError("`latents` and `generator` cannot be used together.")
243
+ if ensembling_kwargs is not None:
244
+ if not isinstance(ensembling_kwargs, dict):
245
+ raise ValueError("`ensembling_kwargs` must be a dictionary.")
246
+ if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
247
+ raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
248
+
249
+ # image checks
250
+ num_images = 0
251
+ W, H = None, None
252
+ if not isinstance(image, list):
253
+ image = [image]
254
+ for i, img in enumerate(image):
255
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
256
+ if img.ndim not in (2, 3, 4):
257
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
258
+ H_i, W_i = img.shape[-2:]
259
+ N_i = 1
260
+ if img.ndim == 4:
261
+ N_i = img.shape[0]
262
+ elif isinstance(img, Image.Image):
263
+ W_i, H_i = img.size
264
+ N_i = 1
265
+ else:
266
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
267
+ if W is None:
268
+ W, H = W_i, H_i
269
+ elif (W, H) != (W_i, H_i):
270
+ raise ValueError(
271
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
272
+ )
273
+ num_images += N_i
274
+
275
+ # latents checks
276
+ if latents is not None:
277
+ if not torch.is_tensor(latents):
278
+ raise ValueError("`latents` must be a torch.Tensor.")
279
+ if latents.dim() != 4:
280
+ raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
281
+
282
+ if processing_resolution > 0:
283
+ max_orig = max(H, W)
284
+ new_H = H * processing_resolution // max_orig
285
+ new_W = W * processing_resolution // max_orig
286
+ if new_H == 0 or new_W == 0:
287
+ raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
288
+ W, H = new_W, new_H
289
+ w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
290
+ h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
291
+ shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
292
+
293
+ if latents.shape != shape_expected:
294
+ raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
295
+
296
+ # generator checks
297
+ if generator is not None:
298
+ if isinstance(generator, list):
299
+ if len(generator) != num_images * ensemble_size:
300
+ raise ValueError(
301
+ "The number of generators must match the total number of ensemble members for all input images."
302
+ )
303
+ if not all(g.device.type == generator[0].device.type for g in generator):
304
+ raise ValueError("`generator` device placement is not consistent in the list.")
305
+ elif not isinstance(generator, torch.Generator):
306
+ raise ValueError(f"Unsupported generator type: {type(generator)}.")
307
+
308
+ return num_images
309
+
310
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
311
+ if not hasattr(self, "_progress_bar_config"):
312
+ self._progress_bar_config = {}
313
+ elif not isinstance(self._progress_bar_config, dict):
314
+ raise ValueError(
315
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
316
+ )
317
+
318
+ progress_bar_config = dict(**self._progress_bar_config)
319
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
320
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
321
+ if iterable is not None:
322
+ return tqdm(iterable, **progress_bar_config)
323
+ elif total is not None:
324
+ return tqdm(total=total, **progress_bar_config)
325
+ else:
326
+ raise ValueError("Either `total` or `iterable` has to be defined.")
327
+
328
+ @torch.no_grad()
329
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
330
+ def __call__(
331
+ self,
332
+ image: PipelineImageInput,
333
+ prompt: Union[str, List[str]] = None,
334
+ negative_prompt: Optional[Union[str, List[str]]] = None,
335
+ num_inference_steps: Optional[int] = None,
336
+ ensemble_size: int = 1,
337
+ processing_resolution: Optional[int] = None,
338
+ match_input_resolution: bool = True,
339
+ resample_method_input: str = "bilinear",
340
+ resample_method_output: str = "bilinear",
341
+ batch_size: int = 1,
342
+ ensembling_kwargs: Optional[Dict[str, Any]] = None,
343
+ latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
344
+ prompt_embeds: Optional[torch.Tensor] = None,
345
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
346
+ num_images_per_prompt: Optional[int] = 1,
347
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
348
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
349
+ output_type: str = "np",
350
+ output_uncertainty: bool = False,
351
+ output_latent: bool = False,
352
+ skip_preprocess: bool = False,
353
+ return_dict: bool = True,
354
+ **kwargs,
355
+ ):
356
+ """
357
+ Function invoked when calling the pipeline.
358
+
359
+ Args:
360
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
361
+ `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
362
+ arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
363
+ by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
364
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
365
+ same width and height.
366
+ num_inference_steps (`int`, *optional*, defaults to `None`):
367
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
368
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
369
+ for Marigold-LCM models.
370
+ ensemble_size (`int`, defaults to `1`):
371
+ Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
372
+ faster inference.
373
+ processing_resolution (`int`, *optional*, defaults to `None`):
374
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
375
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
376
+ value `None` resolves to the optimal value from the model config.
377
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
378
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
379
+ side of the output will equal to `processing_resolution`.
380
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
381
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
382
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
383
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
384
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
385
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
386
+ batch_size (`int`, *optional*, defaults to `1`):
387
+ Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
388
+ ensembling_kwargs (`dict`, *optional*, defaults to `None`)
389
+ Extra dictionary with arguments for precise ensembling control. The following options are available:
390
+ - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
391
+ every pixel location, can be either `"closest"` or `"mean"`.
392
+ latents (`torch.Tensor`, *optional*, defaults to `None`):
393
+ Latent noise tensors to replace the random initialization. These can be taken from the previous
394
+ function call's output.
395
+ generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
396
+ Random number generator object to ensure reproducibility.
397
+ output_type (`str`, *optional*, defaults to `"np"`):
398
+ Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
399
+ values are: `"np"` (numpy array) or `"pt"` (torch tensor).
400
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
401
+ When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
402
+ the `ensemble_size` argument is set to a value above 2.
403
+ output_latent (`bool`, *optional*, defaults to `False`):
404
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
405
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
406
+ `latents` argument.
407
+ return_dict (`bool`, *optional*, defaults to `True`):
408
+ Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
409
+
410
+ Examples:
411
+
412
+ Returns:
413
+ [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
414
+ If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
415
+ `tuple` is returned where the first element is the prediction, the second element is the uncertainty
416
+ (or `None`), and the third is the latent (or `None`).
417
+ """
418
+
419
+ # 0. Resolving variables.
420
+ device = self._execution_device
421
+ dtype = self.dtype
422
+
423
+ # Model-specific optimal default values leading to fast and reasonable results.
424
+ if num_inference_steps is None:
425
+ num_inference_steps = self.default_denoising_steps
426
+ if processing_resolution is None:
427
+ processing_resolution = self.default_processing_resolution
428
+
429
+ # 1. Check inputs.
430
+ num_images = self.check_inputs(
431
+ image,
432
+ num_inference_steps,
433
+ ensemble_size,
434
+ processing_resolution,
435
+ resample_method_input,
436
+ resample_method_output,
437
+ batch_size,
438
+ ensembling_kwargs,
439
+ latents,
440
+ generator,
441
+ output_type,
442
+ output_uncertainty,
443
+ )
444
+
445
+
446
+ # 2. Prepare empty text conditioning.
447
+ # Model invocation: self.tokenizer, self.text_encoder.
448
+ if self.empty_text_embedding is None:
449
+ prompt = ""
450
+ text_inputs = self.tokenizer(
451
+ prompt,
452
+ padding="do_not_pad",
453
+ max_length=self.tokenizer.model_max_length,
454
+ truncation=True,
455
+ return_tensors="pt",
456
+ )
457
+ text_input_ids = text_inputs.input_ids.to(device)
458
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
459
+
460
+
461
+
462
+ # 3. prepare prompt
463
+ if self.prompt_embeds is None:
464
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
465
+ self.prompt,
466
+ device,
467
+ num_images_per_prompt,
468
+ False,
469
+ negative_prompt,
470
+ prompt_embeds=prompt_embeds,
471
+ negative_prompt_embeds=None,
472
+ lora_scale=None,
473
+ clip_skip=None,
474
+ )
475
+ self.prompt_embeds = prompt_embeds
476
+ self.negative_prompt_embeds = negative_prompt_embeds
477
+
478
+
479
+
480
+ # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
481
+ # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
482
+ # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
483
+ # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
484
+ # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
485
+ # operation and leads to the most reasonable results. Using the native image resolution or any other processing
486
+ # resolution can lead to loss of either fine details or global context in the output predictions.
487
+ if not skip_preprocess:
488
+ image, padding, original_resolution = self.image_processor.preprocess(
489
+ image, processing_resolution, resample_method_input, device, dtype
490
+ ) # [N,3,PPH,PPW]
491
+ else:
492
+ padding = (0, 0)
493
+ original_resolution = image.shape[2:]
494
+ # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
495
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
496
+ # Latents of each such predictions across all input images and all ensemble members are represented in the
497
+ # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
498
+ # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
499
+ # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
500
+ # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
501
+ # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
502
+ # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
503
+ # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
504
+ # Model invocation: self.vae.encoder.
505
+ image_latent, pred_latent = self.prepare_latents(
506
+ image, latents, generator, ensemble_size, batch_size
507
+ ) # [N*E,4,h,w], [N*E,4,h,w]
508
+
509
+ gaus_noise = pred_latent.detach().clone()
510
+ del image
511
+
512
+
513
+ # 6. obtain control_output
514
+
515
+ cond_scale =controlnet_conditioning_scale
516
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
517
+ image_latent.detach(),
518
+ self.t_start,
519
+ encoder_hidden_states=self.prompt_embeds,
520
+ conditioning_scale=cond_scale,
521
+ guess_mode=False,
522
+ return_dict=False,
523
+ )
524
+
525
+ # 7. YOSO sampling
526
+ latent_x_t = self.unet(
527
+ pred_latent,
528
+ self.t_start,
529
+ encoder_hidden_states=self.prompt_embeds,
530
+ down_block_additional_residuals=down_block_res_samples,
531
+ mid_block_additional_residual=mid_block_res_sample,
532
+ return_dict=False,
533
+ )[0]
534
+
535
+
536
+ del (
537
+ pred_latent,
538
+ image_latent,
539
+ )
540
+
541
+ # decoder
542
+ prediction = self.decode_prediction(latent_x_t)
543
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
544
+
545
+ prediction = self.image_processor.resize_antialias(
546
+ prediction, original_resolution, resample_method_output, is_aa=False
547
+ ) # [N,3,H,W]
548
+
549
+ if output_type == "np":
550
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
551
+
552
+ # 11. Offload all models
553
+ self.maybe_free_model_hooks()
554
+
555
+ return YOSODiffuseOutput(
556
+ prediction=prediction,
557
+ latent=latent_x_t,
558
+ gaus_noise=gaus_noise,
559
+ )
560
+
561
+ # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
562
+ def prepare_latents(
563
+ self,
564
+ image: torch.Tensor,
565
+ latents: Optional[torch.Tensor],
566
+ generator: Optional[torch.Generator],
567
+ ensemble_size: int,
568
+ batch_size: int,
569
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
570
+ def retrieve_latents(encoder_output):
571
+ if hasattr(encoder_output, "latent_dist"):
572
+ return encoder_output.latent_dist.mode()
573
+ elif hasattr(encoder_output, "latents"):
574
+ return encoder_output.latents
575
+ else:
576
+ raise AttributeError("Could not access latents of provided encoder_output")
577
+
578
+
579
+
580
+ image_latent = torch.cat(
581
+ [
582
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
583
+ for i in range(0, image.shape[0], batch_size)
584
+ ],
585
+ dim=0,
586
+ ) # [N,4,h,w]
587
+ image_latent = image_latent * self.vae.config.scaling_factor
588
+ image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
589
+
590
+ pred_latent = latents
591
+ if pred_latent is None:
592
+ pred_latent = randn_tensor(
593
+ image_latent.shape,
594
+ generator=generator,
595
+ device=image_latent.device,
596
+ dtype=image_latent.dtype,
597
+ ) # [N*E,4,h,w]
598
+
599
+ return image_latent, pred_latent
600
+
601
+ def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
602
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
603
+ raise ValueError(
604
+ f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
605
+ )
606
+
607
+ prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
608
+
609
+ return prediction # [B,3,H,W]
610
+
611
+ @staticmethod
612
+ def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
613
+ if normals.dim() != 4 or normals.shape[1] != 3:
614
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
615
+
616
+ norm = torch.norm(normals, dim=1, keepdim=True)
617
+ normals /= norm.clamp(min=eps)
618
+
619
+ return normals
620
+
621
+ @staticmethod
622
+ def ensemble_normals(
623
+ normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
624
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
625
+ """
626
+ Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
627
+ the number of ensemble members for a given prediction of size `(H x W)`.
628
+
629
+ Args:
630
+ normals (`torch.Tensor`):
631
+ Input ensemble normals maps.
632
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
633
+ Whether to output uncertainty map.
634
+ reduction (`str`, *optional*, defaults to `"closest"`):
635
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
636
+ `"mean"`.
637
+
638
+ Returns:
639
+ A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
640
+ uncertainties of shape `(1, 1, H, W)`.
641
+ """
642
+ if normals.dim() != 4 or normals.shape[1] != 3:
643
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
644
+ if reduction not in ("closest", "mean"):
645
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
646
+
647
+ mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
648
+ mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W]
649
+
650
+ sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
651
+ sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
652
+
653
+ uncertainty = None
654
+ if output_uncertainty:
655
+ uncertainty = sim_cos.arccos() # [E,1,H,W]
656
+ uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
657
+
658
+ if reduction == "mean":
659
+ return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
660
+
661
+ closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
662
+ closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
663
+ closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
664
+
665
+ return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
666
+
667
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
668
+ def retrieve_timesteps(
669
+ scheduler,
670
+ num_inference_steps: Optional[int] = None,
671
+ device: Optional[Union[str, torch.device]] = None,
672
+ timesteps: Optional[List[int]] = None,
673
+ sigmas: Optional[List[float]] = None,
674
+ **kwargs,
675
+ ):
676
+ """
677
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
678
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
679
+
680
+ Args:
681
+ scheduler (`SchedulerMixin`):
682
+ The scheduler to get timesteps from.
683
+ num_inference_steps (`int`):
684
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
685
+ must be `None`.
686
+ device (`str` or `torch.device`, *optional*):
687
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
688
+ timesteps (`List[int]`, *optional*):
689
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
690
+ `num_inference_steps` and `sigmas` must be `None`.
691
+ sigmas (`List[float]`, *optional*):
692
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
693
+ `num_inference_steps` and `timesteps` must be `None`.
694
+
695
+ Returns:
696
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
697
+ second element is the number of inference steps.
698
+ """
699
+ if timesteps is not None and sigmas is not None:
700
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
701
+ if timesteps is not None:
702
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
703
+ if not accepts_timesteps:
704
+ raise ValueError(
705
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
706
+ f" timestep schedules. Please check whether you are using the correct scheduler."
707
+ )
708
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
709
+ timesteps = scheduler.timesteps
710
+ num_inference_steps = len(timesteps)
711
+ elif sigmas is not None:
712
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
713
+ if not accept_sigmas:
714
+ raise ValueError(
715
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
716
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
717
+ )
718
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
719
+ timesteps = scheduler.timesteps
720
+ num_inference_steps = len(timesteps)
721
+ else:
722
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
723
+ timesteps = scheduler.timesteps
724
+ return timesteps, num_inference_steps