Stalbe-X commited on
Commit
b89eee2
·
0 Parent(s):

Initial Commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.stl filter=lfs diff=lfs merge=lfs -text
37
+ *.glb filter=lfs diff=lfs merge=lfs -text
38
+ *.jpg filter=lfs diff=lfs merge=lfs -text
39
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
40
+ *.png filter=lfs diff=lfs merge=lfs -text
41
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .idea
2
+ .DS_Store
3
+ __pycache__
4
+ weights
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stable Normal Estimation
3
+ emoji: 🏵️
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.32.2
8
+ app_file: app.py
9
+ pinned: true
10
+ license: cc-by-sa-4.0
11
+ models:
12
+ - Stable-X/stable-normal-v0-1
13
+ - Stable-X/yoso-normal-v0-1
14
+ hf_oauth: true
15
+ hf_oauth_expiration_minutes: 43200
16
+ ---
app.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+ from __future__ import annotations
20
+
21
+ import functools
22
+ import os
23
+ import tempfile
24
+
25
+ import diffusers
26
+ import gradio as gr
27
+ import imageio as imageio
28
+ import numpy as np
29
+ import spaces
30
+ import torch as torch
31
+ from PIL import Image
32
+ from gradio_imageslider import ImageSlider
33
+ from tqdm import tqdm
34
+
35
+ from pathlib import Path
36
+
37
+ import gradio
38
+ from gradio.utils import get_cache_folder
39
+ from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
40
+ from stablenormal.pipeline_stablenormal import StableNormalPipeline
41
+ from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler
42
+
43
+ class Examples(gradio.helpers.Examples):
44
+ def __init__(self, *args, directory_name=None, **kwargs):
45
+ super().__init__(*args, **kwargs, _initiated_directly=False)
46
+ if directory_name is not None:
47
+ self.cached_folder = get_cache_folder() / directory_name
48
+ self.cached_file = Path(self.cached_folder) / "log.csv"
49
+ self.create()
50
+
51
+
52
+ default_seed = 2024
53
+ default_batch_size = 1
54
+
55
+ default_image_processing_resolution = 768
56
+
57
+ default_video_num_inference_steps = 10
58
+ default_video_processing_resolution = 768
59
+ default_video_out_max_frames = 450
60
+
61
+ def process_image_check(path_input):
62
+ if path_input is None:
63
+ raise gr.Error(
64
+ "Missing image in the first pane: upload a file or use one from the gallery below."
65
+ )
66
+
67
+ def process_image(
68
+ pipe,
69
+ path_input,
70
+ ):
71
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
72
+ print(f"Processing image {name_base}{name_ext}")
73
+
74
+ path_output_dir = tempfile.mkdtemp()
75
+ # path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_normal_fp32.npy")
76
+ path_out_png = os.path.join(path_output_dir, f"{name_base}_normal_colored.png")
77
+
78
+ input_image = Image.open(path_input)
79
+ input_image = center_crop(input_image)
80
+
81
+ pipe_out = pipe(
82
+ input_image,
83
+ match_input_resolution=False,
84
+ return_intermediate_result=False
85
+ )
86
+
87
+ normal_pred = pipe_out.prediction[0, :, :]
88
+ normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction)
89
+ normal_colored[-1].save(path_out_png)
90
+ print(path_out_png)
91
+ # np.save(path_out_fp32, normal_pred)
92
+ # path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_refinement_process.gif")
93
+ # normal_colored[0].save(path_out_vis, save_all=True,
94
+ # append_images=normal_colored[1:],
95
+ # duration=400, loop=0)
96
+ return [input_image, path_out_png]
97
+
98
+ def center_crop(img):
99
+ # Open the image file
100
+ img_width, img_height = img.size
101
+ crop_width =min(img_width, img_height)
102
+ # Calculate the cropping box
103
+ left = (img_width - crop_width) / 2
104
+ top = (img_height - crop_width) / 2
105
+ right = (img_width + crop_width) / 2
106
+ bottom = (img_height + crop_width) / 2
107
+
108
+ # Crop the image
109
+ img_cropped = img.crop((left, top, right, bottom))
110
+ return img_cropped
111
+
112
+ def process_video(
113
+ pipe,
114
+ path_input,
115
+ out_max_frames=default_video_out_max_frames,
116
+ target_fps=3,
117
+ progress=gr.Progress(),
118
+ ):
119
+ if path_input is None:
120
+ raise gr.Error(
121
+ "Missing video in the first pane: upload a file or use one from the gallery below."
122
+ )
123
+
124
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
125
+ print(f"Processing video {name_base}{name_ext}")
126
+
127
+ path_output_dir = tempfile.mkdtemp()
128
+ path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_colored.mp4")
129
+
130
+ reader, writer = None, None
131
+ try:
132
+ reader = imageio.get_reader(path_input)
133
+
134
+ meta_data = reader.get_meta_data()
135
+ fps = meta_data["fps"]
136
+ size = meta_data["size"]
137
+ duration_sec = meta_data["duration"]
138
+
139
+ writer = imageio.get_writer(path_out_vis, fps=target_fps)
140
+
141
+ out_frame_id = 0
142
+ pbar = tqdm(desc="Processing Video", total=duration_sec)
143
+
144
+ for frame_id, frame in enumerate(reader):
145
+ if frame_id % (fps // target_fps) != 0:
146
+ continue
147
+ else:
148
+ out_frame_id += 1
149
+ pbar.update(1)
150
+ if out_frame_id > out_max_frames:
151
+ break
152
+
153
+ frame_pil = Image.fromarray(frame)
154
+ frame_pil = center_crop(frame_pil)
155
+ pipe_out = pipe(
156
+ frame_pil,
157
+ match_input_resolution=False,
158
+ return_intermediate_result=False
159
+ )
160
+
161
+ processed_frame = pipe.image_processor.visualize_normals( # noqa
162
+ pipe_out.prediction
163
+ )[0]
164
+ processed_frame = np.array(processed_frame)
165
+
166
+ _processed_frame = imageio.core.util.Array(processed_frame)
167
+ writer.append_data(_processed_frame)
168
+
169
+ yield (
170
+ [frame_pil, processed_frame],
171
+ None,
172
+ )
173
+ finally:
174
+
175
+ if writer is not None:
176
+ writer.close()
177
+
178
+ if reader is not None:
179
+ reader.close()
180
+
181
+ yield (
182
+ [frame_pil, processed_frame],
183
+ [path_out_vis,]
184
+ )
185
+
186
+
187
+ def run_demo_server(pipe):
188
+ process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
189
+ process_pipe_video = spaces.GPU(
190
+ functools.partial(process_video, pipe), duration=120
191
+ )
192
+
193
+ gradio_theme = gr.themes.Default()
194
+
195
+ with gr.Blocks(
196
+ theme=gradio_theme,
197
+ title="Stable Normal Estimation",
198
+ css="""
199
+ #download {
200
+ height: 118px;
201
+ }
202
+ .slider .inner {
203
+ width: 5px;
204
+ background: #FFF;
205
+ }
206
+ .viewport {
207
+ aspect-ratio: 4/3;
208
+ }
209
+ .tabs button.selected {
210
+ font-size: 20px !important;
211
+ color: crimson !important;
212
+ }
213
+ h1 {
214
+ text-align: center;
215
+ display: block;
216
+ }
217
+ h2 {
218
+ text-align: center;
219
+ display: block;
220
+ }
221
+ h3 {
222
+ text-align: center;
223
+ display: block;
224
+ }
225
+ .md_feedback li {
226
+ margin-bottom: 0px !important;
227
+ }
228
+ """,
229
+ head="""
230
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
231
+ <script>
232
+ window.dataLayer = window.dataLayer || [];
233
+ function gtag() {dataLayer.push(arguments);}
234
+ gtag('js', new Date());
235
+ gtag('config', 'G-1FWSVCGZTG');
236
+ </script>
237
+ """,
238
+ ) as demo:
239
+ gr.Markdown(
240
+ """
241
+ # StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal
242
+ <p align="center">
243
+ """
244
+ )
245
+
246
+ with gr.Tabs(elem_classes=["tabs"]):
247
+ with gr.Tab("Image"):
248
+ with gr.Row():
249
+ with gr.Column():
250
+ image_input = gr.Image(
251
+ label="Input Image",
252
+ type="filepath",
253
+ )
254
+ with gr.Row():
255
+ image_submit_btn = gr.Button(
256
+ value="Compute Normal", variant="primary"
257
+ )
258
+ image_reset_btn = gr.Button(value="Reset")
259
+ with gr.Column():
260
+ image_output_slider = ImageSlider(
261
+ label="Normal outputs",
262
+ type="filepath",
263
+ show_download_button=True,
264
+ show_share_button=True,
265
+ interactive=False,
266
+ elem_classes="slider",
267
+ position=0.25,
268
+ )
269
+
270
+ Examples(
271
+ fn=process_pipe_image,
272
+ examples=sorted([
273
+ os.path.join("files", "image", name)
274
+ for name in os.listdir(os.path.join("files", "image"))
275
+ ]),
276
+ inputs=[image_input],
277
+ outputs=[image_output_slider],
278
+ cache_examples=True,
279
+ directory_name="examples_image",
280
+ )
281
+
282
+ with gr.Tab("Video"):
283
+ with gr.Row():
284
+ with gr.Column():
285
+ video_input = gr.Video(
286
+ label="Input Video",
287
+ sources=["upload", "webcam"],
288
+ )
289
+ with gr.Row():
290
+ video_submit_btn = gr.Button(
291
+ value="Compute Normal", variant="primary"
292
+ )
293
+ video_reset_btn = gr.Button(value="Reset")
294
+ with gr.Column():
295
+ processed_frames = ImageSlider(
296
+ label="Realtime Visualization",
297
+ type="filepath",
298
+ show_download_button=True,
299
+ show_share_button=True,
300
+ interactive=False,
301
+ elem_classes="slider",
302
+ position=0.25,
303
+ )
304
+ video_output_files = gr.Files(
305
+ label="Normal outputs",
306
+ elem_id="download",
307
+ interactive=False,
308
+ )
309
+ Examples(
310
+ fn=process_pipe_video,
311
+ examples=sorted([
312
+ os.path.join("files", "video", name)
313
+ for name in os.listdir(os.path.join("files", "video"))
314
+ ]),
315
+ inputs=[video_input],
316
+ outputs=[processed_frames, video_output_files],
317
+ directory_name="examples_video",
318
+ cache_examples=True,
319
+ )
320
+
321
+ with gr.Tab("Panorama"):
322
+ with gr.Column():
323
+ gr.Markdown("Functionality coming soon on June.10th")
324
+
325
+ with gr.Tab("4K Image"):
326
+ with gr.Column():
327
+ gr.Markdown("Functionality coming soon on June.17th")
328
+
329
+ with gr.Tab("Normal Mapping"):
330
+ with gr.Column():
331
+ gr.Markdown("Functionality coming soon on June.24th")
332
+
333
+ with gr.Tab("Normal SuperResolution"):
334
+ with gr.Column():
335
+ gr.Markdown("Functionality coming soon on June.30th")
336
+
337
+ ### Image tab
338
+ image_submit_btn.click(
339
+ fn=process_image_check,
340
+ inputs=image_input,
341
+ outputs=None,
342
+ preprocess=False,
343
+ queue=False,
344
+ ).success(
345
+ fn=process_pipe_image,
346
+ inputs=[
347
+ image_input,
348
+ ],
349
+ outputs=[image_output_slider],
350
+ concurrency_limit=1,
351
+ )
352
+
353
+ image_reset_btn.click(
354
+ fn=lambda: (
355
+ None,
356
+ None,
357
+ None,
358
+ ),
359
+ inputs=[],
360
+ outputs=[
361
+ image_input,
362
+ image_output_slider,
363
+ ],
364
+ queue=False,
365
+ )
366
+
367
+ ### Video tab
368
+
369
+ video_submit_btn.click(
370
+ fn=process_pipe_video,
371
+ inputs=[video_input],
372
+ outputs=[processed_frames, video_output_files],
373
+ concurrency_limit=1,
374
+ )
375
+
376
+ video_reset_btn.click(
377
+ fn=lambda: (None, None, None),
378
+ inputs=[],
379
+ outputs=[video_input, processed_frames, video_output_files],
380
+ concurrency_limit=1,
381
+ )
382
+
383
+ ### Server launch
384
+
385
+ demo.queue(
386
+ api_open=False,
387
+ ).launch(
388
+ server_name="0.0.0.0",
389
+ server_port=7860,
390
+ )
391
+
392
+ from einops import rearrange
393
+ class DINOv2_Encoder:
394
+ IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
395
+ IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
396
+
397
+ def __init__(
398
+ self,
399
+ model_name = 'dinov2_vitl14',
400
+ freeze = True,
401
+ antialias=True,
402
+ device="cuda",
403
+ size = 448,
404
+ ):
405
+
406
+ super(DINOv2_Encoder).__init__()
407
+
408
+ self.model = torch.hub.load('facebookresearch/dinov2', model_name)
409
+ self.model.eval()
410
+ self.device = device
411
+ self.antialias = antialias
412
+ self.dtype = torch.float32
413
+
414
+ self.mean = torch.Tensor(self.IMAGENET_DEFAULT_MEAN)
415
+ self.std = torch.Tensor(self.IMAGENET_DEFAULT_STD)
416
+ self.size = size
417
+ if freeze:
418
+ self.freeze()
419
+
420
+
421
+ def freeze(self):
422
+ for param in self.model.parameters():
423
+ param.requires_grad = False
424
+
425
+ @torch.no_grad()
426
+ def encoder(self, x):
427
+ '''
428
+ x: [b h w c], range from (-1, 1), rbg
429
+ '''
430
+
431
+ x = self.preprocess(x).to(self.device, self.dtype)
432
+
433
+ b, c, h, w = x.shape
434
+ patch_h, patch_w = h // 14, w // 14
435
+
436
+ embeddings = self.model.forward_features(x)['x_norm_patchtokens']
437
+ embeddings = rearrange(embeddings, 'b (h w) c -> b h w c', h = patch_h, w = patch_w)
438
+
439
+ return rearrange(embeddings, 'b h w c -> b c h w')
440
+
441
+ def preprocess(self, x):
442
+ ''' x
443
+ '''
444
+ # normalize to [0,1],
445
+ x = torch.nn.functional.interpolate(
446
+ x,
447
+ size=(self.size, self.size),
448
+ mode='bicubic',
449
+ align_corners=True,
450
+ antialias=self.antialias,
451
+ )
452
+
453
+ x = (x + 1.0) / 2.0
454
+ # renormalize according to dino
455
+ mean = self.mean.view(1, 3, 1, 1).to(x.device)
456
+ std = self.std.view(1, 3, 1, 1).to(x.device)
457
+ x = (x - mean) / std
458
+
459
+ return x
460
+
461
+ def to(self, device, dtype=None):
462
+ if dtype is not None:
463
+ self.dtype = dtype
464
+ self.model.to(device, dtype)
465
+ self.mean.to(device, dtype)
466
+ self.std.to(device, dtype)
467
+ else:
468
+ self.model.to(device)
469
+ self.mean.to(device)
470
+ self.std.to(device)
471
+ return self
472
+
473
+ def __call__(self, x, **kwargs):
474
+ return self.encoder(x, **kwargs)
475
+
476
+ def main():
477
+ os.system("pip freeze")
478
+
479
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
480
+
481
+ x_start_pipeline = YOSONormalsPipeline.from_pretrained(
482
+ 'weights/yoso-normal-v0-1', trust_remote_code=True,
483
+ t_start=300).to(device, torch.float16)
484
+ dinov2_prior = DINOv2_Encoder(size=672)
485
+ dinov2_prior.to(device, torch.float16)
486
+
487
+ pipe = StableNormalPipeline.from_pretrained('weights/stable-normal-v0-1', t_start=300, trust_remote_code=True,
488
+ scheduler=HEURI_DDIMScheduler(prediction_type='sample',
489
+ beta_start=0.00085, beta_end=0.0120,
490
+ beta_schedule = "scaled_linear"))
491
+ # two stage concat
492
+ pipe.x_start_pipeline = x_start_pipeline
493
+ pipe.prior = dinov2_prior
494
+ pipe.to(device, torch.float16)
495
+
496
+ try:
497
+ import xformers
498
+ pipe.enable_xformers_memory_efficient_attention()
499
+ except:
500
+ pass # run without xformers
501
+
502
+ run_demo_server(pipe)
503
+
504
+
505
+ if __name__ == "__main__":
506
+ main()
requirements.txt ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.30.1
2
+ aiofiles==23.2.1
3
+ aiohttp==3.9.5
4
+ aiosignal==1.3.1
5
+ altair==5.3.0
6
+ annotated-types==0.7.0
7
+ anyio==4.4.0
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ Authlib==1.3.0
11
+ certifi==2024.2.2
12
+ cffi==1.16.0
13
+ charset-normalizer==3.3.2
14
+ click==8.0.4
15
+ contourpy==1.2.1
16
+ cryptography==42.0.7
17
+ cycler==0.12.1
18
+ dataclasses-json==0.6.6
19
+ datasets==2.19.1
20
+ Deprecated==1.2.14
21
+ diffusers==0.28.0
22
+ dill==0.3.8
23
+ dnspython==2.6.1
24
+ email_validator==2.1.1
25
+ exceptiongroup==1.2.1
26
+ fastapi==0.111.0
27
+ fastapi-cli==0.0.4
28
+ ffmpy==0.3.2
29
+ filelock==3.14.0
30
+ fonttools==4.53.0
31
+ frozenlist==1.4.1
32
+ fsspec==2024.3.1
33
+ gradio==4.32.2
34
+ gradio_client==0.17.0
35
+ gradio_imageslider==0.0.20
36
+ h11==0.14.0
37
+ httpcore==1.0.5
38
+ httptools==0.6.1
39
+ httpx==0.27.0
40
+ huggingface-hub==0.23.0
41
+ idna==3.7
42
+ imageio==2.34.1
43
+ imageio-ffmpeg==0.5.0
44
+ importlib_metadata==7.1.0
45
+ importlib_resources==6.4.0
46
+ itsdangerous==2.2.0
47
+ Jinja2==3.1.4
48
+ jsonschema==4.22.0
49
+ jsonschema-specifications==2023.12.1
50
+ kiwisolver==1.4.5
51
+ markdown-it-py==3.0.0
52
+ MarkupSafe==2.1.5
53
+ marshmallow==3.21.2
54
+ matplotlib==3.8.2
55
+ mdurl==0.1.2
56
+ mpmath==1.3.0
57
+ multidict==6.0.5
58
+ multiprocess==0.70.16
59
+ mypy-extensions==1.0.0
60
+ networkx==3.3
61
+ numpy==1.26.4
62
+ nvidia-cublas-cu12==12.1.3.1
63
+ nvidia-cuda-cupti-cu12==12.1.105
64
+ nvidia-cuda-nvrtc-cu12==12.1.105
65
+ nvidia-cuda-runtime-cu12==12.1.105
66
+ nvidia-cudnn-cu12==8.9.2.26
67
+ nvidia-cufft-cu12==11.0.2.54
68
+ nvidia-curand-cu12==10.3.2.106
69
+ nvidia-cusolver-cu12==11.4.5.107
70
+ nvidia-cusparse-cu12==12.1.0.106
71
+ nvidia-nccl-cu12==2.19.3
72
+ nvidia-nvjitlink-cu12==12.5.40
73
+ nvidia-nvtx-cu12==12.1.105
74
+ orjson==3.10.3
75
+ packaging==24.0
76
+ pandas==2.2.2
77
+ pillow==10.3.0
78
+ protobuf==3.20.3
79
+ psutil==5.9.8
80
+ pyarrow==16.0.0
81
+ pyarrow-hotfix==0.6
82
+ pycparser==2.22
83
+ pydantic==2.7.2
84
+ pydantic_core==2.18.3
85
+ pydub==0.25.1
86
+ pygltflib==1.16.1
87
+ Pygments==2.18.0
88
+ pyparsing==3.1.2
89
+ python-dateutil==2.9.0.post0
90
+ python-dotenv==1.0.1
91
+ python-multipart==0.0.9
92
+ pytz==2024.1
93
+ PyYAML==6.0.1
94
+ referencing==0.35.1
95
+ regex==2024.5.15
96
+ requests==2.31.0
97
+ rich==13.7.1
98
+ rpds-py==0.18.1
99
+ ruff==0.4.7
100
+ safetensors==0.4.3
101
+ scipy==1.11.4
102
+ semantic-version==2.10.0
103
+ shellingham==1.5.4
104
+ six==1.16.0
105
+ sniffio==1.3.1
106
+ spaces==0.28.3
107
+ starlette==0.37.2
108
+ sympy==1.12.1
109
+ tokenizers==0.15.2
110
+ tomlkit==0.12.0
111
+ toolz==0.12.1
112
+ torch==2.2.0
113
+ tqdm==4.66.4
114
+ transformers==4.36.1
115
+ trimesh==4.0.5
116
+ triton==2.2.0
117
+ typer==0.12.3
118
+ typing-inspect==0.9.0
119
+ typing_extensions==4.11.0
120
+ tzdata==2024.1
121
+ ujson==5.10.0
122
+ urllib3==2.2.1
123
+ uvicorn==0.30.0
124
+ uvloop==0.19.0
125
+ watchfiles==0.22.0
126
+ websockets==11.0.3
127
+ wrapt==1.16.0
128
+ xformers==0.0.24
129
+ xxhash==3.4.1
130
+ yarl==1.9.4
131
+ zipp==3.19.1
132
+ einops==0.7.0
requirements_min.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.32.1
2
+ gradio-imageslider>=0.0.20
3
+ pygltflib==1.16.1
4
+ trimesh==4.0.5
5
+ imageio
6
+ imageio-ffmpeg
7
+ Pillow
8
+ einops==0.7.0
9
+
10
+ spaces
11
+ accelerate
12
+ diffusers>=0.28.0
13
+ matplotlib==3.8.2
14
+ scipy==1.11.4
15
+ torch==2.0.1
16
+ transformers==4.36.1
17
+ xformers==0.0.21
setup.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from setuptools import setup, find_packages
3
+
4
+ setup_path = Path(__file__).parent
5
+
6
+ setup(
7
+ name = "stablenormal",
8
+ packages=find_packages()
9
+ )
stablenormal/__init__.py ADDED
File without changes
stablenormal/pipeline_stablenormal.py ADDED
@@ -0,0 +1,1201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the
17
+ # --------------------------------------------------------------------------
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
+
27
+
28
+ from diffusers.image_processor import PipelineImageInput
29
+ from diffusers.models import (
30
+ AutoencoderKL,
31
+ UNet2DConditionModel,
32
+ ControlNetModel,
33
+ )
34
+ from diffusers.schedulers import (
35
+ DDIMScheduler
36
+ )
37
+
38
+ from diffusers.utils import (
39
+ BaseOutput,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+
44
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
45
+
46
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
47
+
48
+
49
+
50
+ from diffusers.utils.torch_utils import randn_tensor
51
+ from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
52
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
53
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
54
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
55
+ import torch.nn.functional as F
56
+
57
+ import pdb
58
+
59
+
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ EXAMPLE_DOC_STRING = """
65
+ Examples:
66
+ ```py
67
+ >>> import diffusers
68
+ >>> import torch
69
+
70
+ >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
71
+ ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
72
+ ... ).to("cuda")
73
+
74
+ >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
75
+ >>> normals = pipe(image)
76
+
77
+ >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
78
+ >>> vis[0].save("einstein_normals.png")
79
+ ```
80
+ """
81
+
82
+
83
+ @dataclass
84
+ class StableNormalOutput(BaseOutput):
85
+ """
86
+ Output class for Marigold monocular normals prediction pipeline.
87
+
88
+ Args:
89
+ prediction (`np.ndarray`, `torch.Tensor`):
90
+ Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
91
+ \times width$, regardless of whether the images were passed as a 4D array or a list.
92
+ uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
93
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
94
+ \times 1 \times height \times width$.
95
+ latent (`None`, `torch.Tensor`):
96
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
97
+ The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
98
+ """
99
+
100
+ prediction: Union[np.ndarray, torch.Tensor]
101
+ latent: Union[None, torch.Tensor]
102
+
103
+
104
+ class StableNormalPipeline(StableDiffusionControlNetPipeline):
105
+ """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
106
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
107
+
108
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
109
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
110
+
111
+ The pipeline also inherits the following loading methods:
112
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
113
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
114
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
115
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
116
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
117
+
118
+ Args:
119
+ vae ([`AutoencoderKL`]):
120
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
121
+ text_encoder ([`~transformers.CLIPTextModel`]):
122
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
123
+ tokenizer ([`~transformers.CLIPTokenizer`]):
124
+ A `CLIPTokenizer` to tokenize text.
125
+ unet ([`UNet2DConditionModel`]):
126
+ A `UNet2DConditionModel` to denoise the encoded image latents.
127
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
128
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
129
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
130
+ additional conditioning.
131
+ scheduler ([`SchedulerMixin`]):
132
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
133
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
134
+ safety_checker ([`StableDiffusionSafetyChecker`]):
135
+ Classification module that estimates whether generated images could be considered offensive or harmful.
136
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
137
+ about a model's potential harms.
138
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
139
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
140
+ """
141
+
142
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
143
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
144
+ _exclude_from_cpu_offload = ["safety_checker"]
145
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
146
+
147
+
148
+
149
+ def __init__(
150
+ self,
151
+ vae: AutoencoderKL,
152
+ text_encoder: CLIPTextModel,
153
+ tokenizer: CLIPTokenizer,
154
+ unet: UNet2DConditionModel,
155
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
156
+ dino_controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
157
+ scheduler: Union[DDIMScheduler],
158
+ safety_checker: StableDiffusionSafetyChecker,
159
+ feature_extractor: CLIPImageProcessor,
160
+ image_encoder: CLIPVisionModelWithProjection = None,
161
+ requires_safety_checker: bool = True,
162
+ default_denoising_steps: Optional[int] = 10,
163
+ default_processing_resolution: Optional[int] = 768,
164
+ prompt="The normal map",
165
+ empty_text_embedding=None,
166
+ t_start: Optional[int] = 401,
167
+ ):
168
+ super().__init__(
169
+ vae,
170
+ text_encoder,
171
+ tokenizer,
172
+ unet,
173
+ controlnet,
174
+ scheduler,
175
+ safety_checker,
176
+ feature_extractor,
177
+ image_encoder,
178
+ requires_safety_checker,
179
+ )
180
+
181
+ self.register_modules(
182
+ dino_controlnet=dino_controlnet,
183
+ )
184
+
185
+ self.vae_scale_factor = 768
186
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
187
+ self.dino_image_processor = lambda x: x / 127.5 -1.
188
+
189
+ self.default_denoising_steps = default_denoising_steps
190
+ self.default_processing_resolution = default_processing_resolution
191
+ self.prompt = prompt
192
+ self.prompt_embeds = None
193
+ self.empty_text_embedding = empty_text_embedding
194
+ self.t_start= torch.tensor(t_start) # target_out latents
195
+
196
+
197
+ def check_inputs(
198
+ self,
199
+ image: PipelineImageInput,
200
+ num_inference_steps: int,
201
+ ensemble_size: int,
202
+ processing_resolution: int,
203
+ resample_method_input: str,
204
+ resample_method_output: str,
205
+ batch_size: int,
206
+ ensembling_kwargs: Optional[Dict[str, Any]],
207
+ latents: Optional[torch.Tensor],
208
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
209
+ output_type: str,
210
+ output_uncertainty: bool,
211
+ ) -> int:
212
+ if num_inference_steps is None:
213
+ raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
214
+ if num_inference_steps < 1:
215
+ raise ValueError("`num_inference_steps` must be positive.")
216
+ if ensemble_size < 1:
217
+ raise ValueError("`ensemble_size` must be positive.")
218
+ if ensemble_size == 2:
219
+ logger.warning(
220
+ "`ensemble_size` == 2 results are similar to no ensembling (1); "
221
+ "consider increasing the value to at least 3."
222
+ )
223
+ if ensemble_size == 1 and output_uncertainty:
224
+ raise ValueError(
225
+ "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
226
+ "greater than 1."
227
+ )
228
+ if processing_resolution is None:
229
+ raise ValueError(
230
+ "`processing_resolution` is not specified and could not be resolved from the model config."
231
+ )
232
+ if processing_resolution < 0:
233
+ raise ValueError(
234
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
235
+ "downsampled processing."
236
+ )
237
+ if processing_resolution % self.vae_scale_factor != 0:
238
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
239
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
240
+ raise ValueError(
241
+ "`resample_method_input` takes string values compatible with PIL library: "
242
+ "nearest, nearest-exact, bilinear, bicubic, area."
243
+ )
244
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
245
+ raise ValueError(
246
+ "`resample_method_output` takes string values compatible with PIL library: "
247
+ "nearest, nearest-exact, bilinear, bicubic, area."
248
+ )
249
+ if batch_size < 1:
250
+ raise ValueError("`batch_size` must be positive.")
251
+ if output_type not in ["pt", "np"]:
252
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
253
+ if latents is not None and generator is not None:
254
+ raise ValueError("`latents` and `generator` cannot be used together.")
255
+ if ensembling_kwargs is not None:
256
+ if not isinstance(ensembling_kwargs, dict):
257
+ raise ValueError("`ensembling_kwargs` must be a dictionary.")
258
+ if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
259
+ raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
260
+
261
+ # image checks
262
+ num_images = 0
263
+ W, H = None, None
264
+ if not isinstance(image, list):
265
+ image = [image]
266
+ for i, img in enumerate(image):
267
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
268
+ if img.ndim not in (2, 3, 4):
269
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
270
+ H_i, W_i = img.shape[-2:]
271
+ N_i = 1
272
+ if img.ndim == 4:
273
+ N_i = img.shape[0]
274
+ elif isinstance(img, Image.Image):
275
+ W_i, H_i = img.size
276
+ N_i = 1
277
+ else:
278
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
279
+ if W is None:
280
+ W, H = W_i, H_i
281
+ elif (W, H) != (W_i, H_i):
282
+ raise ValueError(
283
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
284
+ )
285
+ num_images += N_i
286
+
287
+ # latents checks
288
+ if latents is not None:
289
+ if not torch.is_tensor(latents):
290
+ raise ValueError("`latents` must be a torch.Tensor.")
291
+ if latents.dim() != 4:
292
+ raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
293
+
294
+ if processing_resolution > 0:
295
+ max_orig = max(H, W)
296
+ new_H = H * processing_resolution // max_orig
297
+ new_W = W * processing_resolution // max_orig
298
+ if new_H == 0 or new_W == 0:
299
+ raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
300
+ W, H = new_W, new_H
301
+ w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
302
+ h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
303
+ shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
304
+
305
+ if latents.shape != shape_expected:
306
+ raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
307
+
308
+ # generator checks
309
+ if generator is not None:
310
+ if isinstance(generator, list):
311
+ if len(generator) != num_images * ensemble_size:
312
+ raise ValueError(
313
+ "The number of generators must match the total number of ensemble members for all input images."
314
+ )
315
+ if not all(g.device.type == generator[0].device.type for g in generator):
316
+ raise ValueError("`generator` device placement is not consistent in the list.")
317
+ elif not isinstance(generator, torch.Generator):
318
+ raise ValueError(f"Unsupported generator type: {type(generator)}.")
319
+
320
+ return num_images
321
+
322
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
323
+ if not hasattr(self, "_progress_bar_config"):
324
+ self._progress_bar_config = {}
325
+ elif not isinstance(self._progress_bar_config, dict):
326
+ raise ValueError(
327
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
328
+ )
329
+
330
+ progress_bar_config = dict(**self._progress_bar_config)
331
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
332
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
333
+ if iterable is not None:
334
+ return tqdm(iterable, **progress_bar_config)
335
+ elif total is not None:
336
+ return tqdm(total=total, **progress_bar_config)
337
+ else:
338
+ raise ValueError("Either `total` or `iterable` has to be defined.")
339
+
340
+ @torch.no_grad()
341
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
342
+ def __call__(
343
+ self,
344
+ image: PipelineImageInput,
345
+ prompt: Union[str, List[str]] = None,
346
+ negative_prompt: Optional[Union[str, List[str]]] = None,
347
+ num_inference_steps: Optional[int] = None,
348
+ ensemble_size: int = 1,
349
+ processing_resolution: Optional[int] = None,
350
+ return_intermediate_result: bool = False,
351
+ match_input_resolution: bool = True,
352
+ resample_method_input: str = "bilinear",
353
+ resample_method_output: str = "bilinear",
354
+ batch_size: int = 1,
355
+ ensembling_kwargs: Optional[Dict[str, Any]] = None,
356
+ latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
357
+ prompt_embeds: Optional[torch.Tensor] = None,
358
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
359
+ num_images_per_prompt: Optional[int] = 1,
360
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
361
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
362
+ output_type: str = "np",
363
+ output_uncertainty: bool = False,
364
+ output_latent: bool = False,
365
+ return_dict: bool = True,
366
+ ):
367
+ """
368
+ Function invoked when calling the pipeline.
369
+
370
+ Args:
371
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
372
+ `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
373
+ arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
374
+ by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
375
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
376
+ same width and height.
377
+ num_inference_steps (`int`, *optional*, defaults to `None`):
378
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
379
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
380
+ for Marigold-LCM models.
381
+ ensemble_size (`int`, defaults to `1`):
382
+ Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
383
+ faster inference.
384
+ processing_resolution (`int`, *optional*, defaults to `None`):
385
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
386
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
387
+ value `None` resolves to the optimal value from the model config.
388
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
389
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
390
+ side of the output will equal to `processing_resolution`.
391
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
392
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
393
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
394
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
395
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
396
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
397
+ batch_size (`int`, *optional*, defaults to `1`):
398
+ Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
399
+ ensembling_kwargs (`dict`, *optional*, defaults to `None`)
400
+ Extra dictionary with arguments for precise ensembling control. The following options are available:
401
+ - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
402
+ every pixel location, can be either `"closest"` or `"mean"`.
403
+ latents (`torch.Tensor`, *optional*, defaults to `None`):
404
+ Latent noise tensors to replace the random initialization. These can be taken from the previous
405
+ function call's output.
406
+ generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
407
+ Random number generator object to ensure reproducibility.
408
+ output_type (`str`, *optional*, defaults to `"np"`):
409
+ Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
410
+ values are: `"np"` (numpy array) or `"pt"` (torch tensor).
411
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
412
+ When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
413
+ the `ensemble_size` argument is set to a value above 2.
414
+ output_latent (`bool`, *optional*, defaults to `False`):
415
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
416
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
417
+ `latents` argument.
418
+ return_dict (`bool`, *optional*, defaults to `True`):
419
+ Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
420
+
421
+ Examples:
422
+
423
+ Returns:
424
+ [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
425
+ If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
426
+ `tuple` is returned where the first element is the prediction, the second element is the uncertainty
427
+ (or `None`), and the third is the latent (or `None`).
428
+ """
429
+
430
+ # 0. Resolving variables.
431
+ device = self._execution_device
432
+ dtype = self.dtype
433
+
434
+ # Model-specific optimal default values leading to fast and reasonable results.
435
+ if num_inference_steps is None:
436
+ num_inference_steps = self.default_denoising_steps
437
+ if processing_resolution is None:
438
+ processing_resolution = self.default_processing_resolution
439
+
440
+
441
+ image, padding, original_resolution = self.image_processor.preprocess(
442
+ image, processing_resolution, resample_method_input, device, dtype
443
+ ) # [N,3,PPH,PPW]
444
+
445
+ # 0. X_start latent obtain
446
+ predictor = self.x_start_pipeline(image, skip_preprocess=True)
447
+ x_start_latent = predictor.latent
448
+ gauss_latent = predictor.gauss_latent
449
+
450
+ # 1. Check inputs.
451
+ num_images = self.check_inputs(
452
+ image,
453
+ num_inference_steps,
454
+ ensemble_size,
455
+ processing_resolution,
456
+ resample_method_input,
457
+ resample_method_output,
458
+ batch_size,
459
+ ensembling_kwargs,
460
+ latents,
461
+ generator,
462
+ output_type,
463
+ output_uncertainty,
464
+ )
465
+
466
+
467
+ # 2. Prepare empty text conditioning.
468
+ # Model invocation: self.tokenizer, self.text_encoder.
469
+ if self.empty_text_embedding is None:
470
+ prompt = ""
471
+ text_inputs = self.tokenizer(
472
+ prompt,
473
+ padding="do_not_pad",
474
+ max_length=self.tokenizer.model_max_length,
475
+ truncation=True,
476
+ return_tensors="pt",
477
+ )
478
+ text_input_ids = text_inputs.input_ids.to(device)
479
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
480
+
481
+
482
+
483
+ # 3. prepare prompt
484
+ if self.prompt_embeds is None:
485
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
486
+ self.prompt,
487
+ device,
488
+ num_images_per_prompt,
489
+ False,
490
+ negative_prompt,
491
+ prompt_embeds=prompt_embeds,
492
+ negative_prompt_embeds=None,
493
+ lora_scale=None,
494
+ clip_skip=None,
495
+ )
496
+ self.prompt_embeds = prompt_embeds
497
+ self.negative_prompt_embeds = negative_prompt_embeds
498
+
499
+
500
+
501
+ # 5. dino guider features obtaining
502
+ ## TODO different case-1
503
+ dino_features = self.prior(image)
504
+ dino_features = self.dino_controlnet.dino_controlnet_cond_embedding(dino_features)
505
+ dino_features = self.match_noisy(dino_features, x_start_latent)
506
+
507
+ # 6. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
508
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
509
+ # Latents of each such predictions across all input images and all ensemble members are represented in the
510
+ # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
511
+ # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
512
+ # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
513
+ # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
514
+ # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
515
+ # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
516
+ # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
517
+ # Model invocation: self.vae.encoder.
518
+ image_latent, pred_latent = self.prepare_latents(
519
+ image, latents, generator, ensemble_size, batch_size
520
+ ) # [N*E,4,h,w], [N*E,4,h,w]
521
+
522
+
523
+ del (
524
+ image,
525
+ )
526
+
527
+ # 7. denoise sampling, using heuritic sampling proposed by Ye.
528
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
529
+
530
+ cond_scale =controlnet_conditioning_scale
531
+ pred_latent = x_start_latent
532
+
533
+ cur_step = 0
534
+
535
+ # dino controlnet
536
+ dino_down_block_res_samples, dino_mid_block_res_sample = self.dino_controlnet(
537
+ dino_features.detach(),
538
+ 0, # not depend on time steps
539
+ encoder_hidden_states=self.prompt_embeds,
540
+ conditioning_scale=cond_scale,
541
+ guess_mode=False,
542
+ return_dict=False,
543
+ )
544
+ assert dino_mid_block_res_sample == None
545
+
546
+ pred_latents = []
547
+
548
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
549
+ image_latent.detach(),
550
+ self.t_start,
551
+ encoder_hidden_states=self.prompt_embeds,
552
+ conditioning_scale=cond_scale,
553
+ guess_mode=False,
554
+ return_dict=False,
555
+ )
556
+ last_pred_latent = pred_latent
557
+ for i in range(4):
558
+ _dino_down_block_res_samples = [dino_down_block_res_sample for dino_down_block_res_sample in dino_down_block_res_samples] # copy, avoid repeat quiery
559
+
560
+ model_output = self.dino_unet_forward(
561
+ self.unet,
562
+ pred_latent,
563
+ self.t_start,
564
+ encoder_hidden_states=self.prompt_embeds,
565
+ down_block_additional_residuals=down_block_res_samples,
566
+ mid_block_additional_residual=mid_block_res_sample,
567
+ dino_down_block_additional_residuals= _dino_down_block_res_samples,
568
+ return_dict=False,
569
+ )[0] # [B,4,h,w]
570
+ pred_latents.append(model_output)
571
+ pred_latent = self.scheduler.add_noise(model_output, gauss_latent, self.t_start)
572
+ pred_latent = 0.4 * pred_latent + 0.6 * last_pred_latent
573
+ last_pred_latent = pred_latent
574
+ pred_latents = torch.cat(pred_latents, dim=0)
575
+ del (
576
+ image_latent,
577
+ dino_features,
578
+ )
579
+
580
+
581
+ # decoder
582
+ if return_intermediate_result:
583
+ prediction = []
584
+ for _pred_latent in pred_latents:
585
+ _prediction = self.decode_prediction(_pred_latent.unsqueeze(dim=0))
586
+ prediction.append(_prediction)
587
+ prediction = torch.cat(prediction, dim=0)
588
+ else:
589
+ prediction = self.decode_prediction(pred_latents[-1].unsqueeze(dim=0))
590
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
591
+
592
+ if match_input_resolution:
593
+ prediction = self.image_processor.resize_antialias(
594
+ prediction, original_resolution, resample_method_output, is_aa=False
595
+ ) # [N,3,H,W]
596
+ prediction = self.normalize_normals(prediction) # [N,3,H,W]
597
+
598
+ if output_type == "np":
599
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
600
+ prediction = prediction.clip(min=-1, max=1)
601
+
602
+ # 11. Offload all models
603
+ self.maybe_free_model_hooks()
604
+
605
+ return StableNormalOutput(
606
+ prediction=prediction,
607
+ latent=pred_latent,
608
+ )
609
+
610
+ # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
611
+ def prepare_latents(
612
+ self,
613
+ image: torch.Tensor,
614
+ latents: Optional[torch.Tensor],
615
+ generator: Optional[torch.Generator],
616
+ ensemble_size: int,
617
+ batch_size: int,
618
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
619
+ def retrieve_latents(encoder_output):
620
+ if hasattr(encoder_output, "latent_dist"):
621
+ return encoder_output.latent_dist.mode()
622
+ elif hasattr(encoder_output, "latents"):
623
+ return encoder_output.latents
624
+ else:
625
+ raise AttributeError("Could not access latents of provided encoder_output")
626
+
627
+
628
+
629
+ image_latent = torch.cat(
630
+ [
631
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
632
+ for i in range(0, image.shape[0], batch_size)
633
+ ],
634
+ dim=0,
635
+ ) # [N,4,h,w]
636
+ image_latent = image_latent * self.vae.config.scaling_factor
637
+ image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
638
+
639
+ pred_latent = latents
640
+ if pred_latent is None:
641
+
642
+
643
+ pred_latent = randn_tensor(
644
+ image_latent.shape,
645
+ generator=generator,
646
+ device=image_latent.device,
647
+ dtype=image_latent.dtype,
648
+ ) # [N*E,4,h,w]
649
+
650
+ return image_latent, pred_latent
651
+
652
+ def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
653
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
654
+ raise ValueError(
655
+ f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
656
+ )
657
+
658
+ prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
659
+
660
+ return prediction # [B,3,H,W]
661
+
662
+ @staticmethod
663
+ def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
664
+ if normals.dim() != 4 or normals.shape[1] != 3:
665
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
666
+
667
+ norm = torch.norm(normals, dim=1, keepdim=True)
668
+ normals /= norm.clamp(min=eps)
669
+
670
+ return normals
671
+
672
+ @staticmethod
673
+ def match_noisy(dino, noisy):
674
+ _, __, dino_h, dino_w = dino.shape
675
+ _, __, h, w = noisy.shape
676
+
677
+ if h == dino_h and w == dino_w:
678
+ return dino
679
+ else:
680
+ return F.interpolate(dino, (h, w), mode='bilinear')
681
+
682
+
683
+
684
+
685
+
686
+
687
+
688
+
689
+
690
+
691
+ @staticmethod
692
+ def dino_unet_forward(
693
+ self, # NOTE that repurpose to UNet
694
+ sample: torch.Tensor,
695
+ timestep: Union[torch.Tensor, float, int],
696
+ encoder_hidden_states: torch.Tensor,
697
+ class_labels: Optional[torch.Tensor] = None,
698
+ timestep_cond: Optional[torch.Tensor] = None,
699
+ attention_mask: Optional[torch.Tensor] = None,
700
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
701
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
702
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
703
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
704
+ dino_down_block_additional_residuals: Optional[torch.Tensor] = None,
705
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
706
+ encoder_attention_mask: Optional[torch.Tensor] = None,
707
+ return_dict: bool = True,
708
+ ) -> Union[UNet2DConditionOutput, Tuple]:
709
+ r"""
710
+ The [`UNet2DConditionModel`] forward method.
711
+
712
+ Args:
713
+ sample (`torch.Tensor`):
714
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
715
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
716
+ encoder_hidden_states (`torch.Tensor`):
717
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
718
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
719
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
720
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
721
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
722
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
723
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
724
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
725
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
726
+ negative values to the attention scores corresponding to "discard" tokens.
727
+ cross_attention_kwargs (`dict`, *optional*):
728
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
729
+ `self.processor` in
730
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
731
+ added_cond_kwargs: (`dict`, *optional*):
732
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
733
+ are passed along to the UNet blocks.
734
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
735
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
736
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
737
+ A tensor that if specified is added to the residual of the middle unet block.
738
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
739
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
740
+ encoder_attention_mask (`torch.Tensor`):
741
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
742
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
743
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
744
+ return_dict (`bool`, *optional*, defaults to `True`):
745
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
746
+ tuple.
747
+
748
+ Returns:
749
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
750
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
751
+ otherwise a `tuple` is returned where the first element is the sample tensor.
752
+ """
753
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
754
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
755
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
756
+ # on the fly if necessary.
757
+
758
+
759
+ default_overall_up_factor = 2**self.num_upsamplers
760
+
761
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
762
+ forward_upsample_size = False
763
+ upsample_size = None
764
+
765
+ for dim in sample.shape[-2:]:
766
+ if dim % default_overall_up_factor != 0:
767
+ # Forward upsample size to force interpolation output size.
768
+ forward_upsample_size = True
769
+ break
770
+
771
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
772
+ # expects mask of shape:
773
+ # [batch, key_tokens]
774
+ # adds singleton query_tokens dimension:
775
+ # [batch, 1, key_tokens]
776
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
777
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
778
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
779
+ if attention_mask is not None:
780
+ # assume that mask is expressed as:
781
+ # (1 = keep, 0 = discard)
782
+ # convert mask into a bias that can be added to attention scores:
783
+ # (keep = +0, discard = -10000.0)
784
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
785
+ attention_mask = attention_mask.unsqueeze(1)
786
+
787
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
788
+ if encoder_attention_mask is not None:
789
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
790
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
791
+
792
+ # 0. center input if necessary
793
+ if self.config.center_input_sample:
794
+ sample = 2 * sample - 1.0
795
+
796
+ # 1. time
797
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
798
+ emb = self.time_embedding(t_emb, timestep_cond)
799
+ aug_emb = None
800
+
801
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
802
+ if class_emb is not None:
803
+ if self.config.class_embeddings_concat:
804
+ emb = torch.cat([emb, class_emb], dim=-1)
805
+ else:
806
+ emb = emb + class_emb
807
+
808
+ aug_emb = self.get_aug_embed(
809
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
810
+ )
811
+ if self.config.addition_embed_type == "image_hint":
812
+ aug_emb, hint = aug_emb
813
+ sample = torch.cat([sample, hint], dim=1)
814
+
815
+ emb = emb + aug_emb if aug_emb is not None else emb
816
+
817
+ if self.time_embed_act is not None:
818
+ emb = self.time_embed_act(emb)
819
+
820
+ encoder_hidden_states = self.process_encoder_hidden_states(
821
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
822
+ )
823
+
824
+ # 2. pre-process
825
+ sample = self.conv_in(sample)
826
+
827
+ # 2.5 GLIGEN position net
828
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
829
+ cross_attention_kwargs = cross_attention_kwargs.copy()
830
+ gligen_args = cross_attention_kwargs.pop("gligen")
831
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
832
+
833
+ # 3. down
834
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
835
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
836
+ if cross_attention_kwargs is not None:
837
+ cross_attention_kwargs = cross_attention_kwargs.copy()
838
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
839
+ else:
840
+ lora_scale = 1.0
841
+
842
+ if USE_PEFT_BACKEND:
843
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
844
+ scale_lora_layers(self, lora_scale)
845
+
846
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
847
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
848
+ is_adapter = down_intrablock_additional_residuals is not None
849
+ # maintain backward compatibility for legacy usage, where
850
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
851
+ # but can only use one or the other
852
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
853
+ deprecate(
854
+ "T2I should not use down_block_additional_residuals",
855
+ "1.3.0",
856
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
857
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
858
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
859
+ standard_warn=False,
860
+ )
861
+ down_intrablock_additional_residuals = down_block_additional_residuals
862
+ is_adapter = True
863
+
864
+
865
+
866
+ def residual_downforward(
867
+ self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None,
868
+ additional_residuals: Optional[torch.Tensor] = None,
869
+ *args, **kwargs,
870
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
871
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
872
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
873
+ deprecate("scale", "1.0.0", deprecation_message)
874
+
875
+ output_states = ()
876
+
877
+ for resnet in self.resnets:
878
+ if self.training and self.gradient_checkpointing:
879
+
880
+ def create_custom_forward(module):
881
+ def custom_forward(*inputs):
882
+ return module(*inputs)
883
+
884
+ return custom_forward
885
+
886
+ if is_torch_version(">=", "1.11.0"):
887
+ hidden_states = torch.utils.checkpoint.checkpoint(
888
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
889
+ )
890
+ else:
891
+ hidden_states = torch.utils.checkpoint.checkpoint(
892
+ create_custom_forward(resnet), hidden_states, temb
893
+ )
894
+ else:
895
+ hidden_states = resnet(hidden_states, temb)
896
+ hidden_states += additional_residuals.pop(0)
897
+
898
+
899
+ output_states = output_states + (hidden_states,)
900
+
901
+ if self.downsamplers is not None:
902
+ for downsampler in self.downsamplers:
903
+ hidden_states = downsampler(hidden_states)
904
+ hidden_states += additional_residuals.pop(0)
905
+
906
+ output_states = output_states + (hidden_states,)
907
+
908
+ return hidden_states, output_states
909
+
910
+
911
+ def residual_blockforward(
912
+ self, ## NOTE that repurpose to unet_blocks
913
+ hidden_states: torch.Tensor,
914
+ temb: Optional[torch.Tensor] = None,
915
+ encoder_hidden_states: Optional[torch.Tensor] = None,
916
+ attention_mask: Optional[torch.Tensor] = None,
917
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
918
+ encoder_attention_mask: Optional[torch.Tensor] = None,
919
+ additional_residuals: Optional[torch.Tensor] = None,
920
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
921
+ if cross_attention_kwargs is not None:
922
+ if cross_attention_kwargs.get("scale", None) is not None:
923
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
924
+
925
+
926
+
927
+ output_states = ()
928
+
929
+ blocks = list(zip(self.resnets, self.attentions))
930
+
931
+ for i, (resnet, attn) in enumerate(blocks):
932
+ if self.training and self.gradient_checkpointing:
933
+
934
+ def create_custom_forward(module, return_dict=None):
935
+ def custom_forward(*inputs):
936
+ if return_dict is not None:
937
+ return module(*inputs, return_dict=return_dict)
938
+ else:
939
+ return module(*inputs)
940
+
941
+ return custom_forward
942
+
943
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
944
+ hidden_states = torch.utils.checkpoint.checkpoint(
945
+ create_custom_forward(resnet),
946
+ hidden_states,
947
+ temb,
948
+ **ckpt_kwargs,
949
+ )
950
+ hidden_states = attn(
951
+ hidden_states,
952
+ encoder_hidden_states=encoder_hidden_states,
953
+ cross_attention_kwargs=cross_attention_kwargs,
954
+ attention_mask=attention_mask,
955
+ encoder_attention_mask=encoder_attention_mask,
956
+ return_dict=False,
957
+ )[0]
958
+ else:
959
+ hidden_states = resnet(hidden_states, temb)
960
+ hidden_states = attn(
961
+ hidden_states,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ cross_attention_kwargs=cross_attention_kwargs,
964
+ attention_mask=attention_mask,
965
+ encoder_attention_mask=encoder_attention_mask,
966
+ return_dict=False,
967
+ )[0]
968
+
969
+ hidden_states += additional_residuals.pop(0)
970
+
971
+ output_states = output_states + (hidden_states,)
972
+
973
+ if self.downsamplers is not None:
974
+ for downsampler in self.downsamplers:
975
+ hidden_states = downsampler(hidden_states)
976
+ hidden_states += additional_residuals.pop(0)
977
+
978
+ output_states = output_states + (hidden_states,)
979
+
980
+ return hidden_states, output_states
981
+
982
+
983
+ down_intrablock_additional_residuals = dino_down_block_additional_residuals
984
+
985
+ sample += down_intrablock_additional_residuals.pop(0)
986
+ down_block_res_samples = (sample,)
987
+
988
+ for downsample_block in self.down_blocks:
989
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
990
+
991
+ sample, res_samples = residual_blockforward(
992
+ downsample_block,
993
+ hidden_states=sample,
994
+ temb=emb,
995
+ encoder_hidden_states=encoder_hidden_states,
996
+ attention_mask=attention_mask,
997
+ cross_attention_kwargs=cross_attention_kwargs,
998
+ encoder_attention_mask=encoder_attention_mask,
999
+ additional_residuals = down_intrablock_additional_residuals,
1000
+ )
1001
+
1002
+ else:
1003
+ sample, res_samples = residual_downforward(
1004
+ downsample_block,
1005
+ hidden_states=sample,
1006
+ temb=emb,
1007
+ additional_residuals = down_intrablock_additional_residuals,
1008
+ )
1009
+
1010
+
1011
+ down_block_res_samples += res_samples
1012
+
1013
+
1014
+ if is_controlnet:
1015
+ new_down_block_res_samples = ()
1016
+
1017
+ for down_block_res_sample, down_block_additional_residual in zip(
1018
+ down_block_res_samples, down_block_additional_residuals
1019
+ ):
1020
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1021
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1022
+
1023
+ down_block_res_samples = new_down_block_res_samples
1024
+
1025
+ # 4. mid
1026
+ if self.mid_block is not None:
1027
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1028
+ sample = self.mid_block(
1029
+ sample,
1030
+ emb,
1031
+ encoder_hidden_states=encoder_hidden_states,
1032
+ attention_mask=attention_mask,
1033
+ cross_attention_kwargs=cross_attention_kwargs,
1034
+ encoder_attention_mask=encoder_attention_mask,
1035
+ )
1036
+ else:
1037
+ sample = self.mid_block(sample, emb)
1038
+
1039
+ # To support T2I-Adapter-XL
1040
+ if (
1041
+ is_adapter
1042
+ and len(down_intrablock_additional_residuals) > 0
1043
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1044
+ ):
1045
+ sample += down_intrablock_additional_residuals.pop(0)
1046
+
1047
+ if is_controlnet:
1048
+ sample = sample + mid_block_additional_residual
1049
+
1050
+ # 5. up
1051
+ for i, upsample_block in enumerate(self.up_blocks):
1052
+ is_final_block = i == len(self.up_blocks) - 1
1053
+
1054
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1055
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1056
+
1057
+ # if we have not reached the final block and need to forward the
1058
+ # upsample size, we do it here
1059
+ if not is_final_block and forward_upsample_size:
1060
+ upsample_size = down_block_res_samples[-1].shape[2:]
1061
+
1062
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1063
+ sample = upsample_block(
1064
+ hidden_states=sample,
1065
+ temb=emb,
1066
+ res_hidden_states_tuple=res_samples,
1067
+ encoder_hidden_states=encoder_hidden_states,
1068
+ cross_attention_kwargs=cross_attention_kwargs,
1069
+ upsample_size=upsample_size,
1070
+ attention_mask=attention_mask,
1071
+ encoder_attention_mask=encoder_attention_mask,
1072
+ )
1073
+ else:
1074
+ sample = upsample_block(
1075
+ hidden_states=sample,
1076
+ temb=emb,
1077
+ res_hidden_states_tuple=res_samples,
1078
+ upsample_size=upsample_size,
1079
+ )
1080
+
1081
+ # 6. post-process
1082
+ if self.conv_norm_out:
1083
+ sample = self.conv_norm_out(sample)
1084
+ sample = self.conv_act(sample)
1085
+ sample = self.conv_out(sample)
1086
+
1087
+ if USE_PEFT_BACKEND:
1088
+ # remove `lora_scale` from each PEFT layer
1089
+ unscale_lora_layers(self, lora_scale)
1090
+
1091
+ if not return_dict:
1092
+ return (sample,)
1093
+
1094
+ return UNet2DConditionOutput(sample=sample)
1095
+
1096
+
1097
+
1098
+ @staticmethod
1099
+ def ensemble_normals(
1100
+ normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
1101
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1102
+ """
1103
+ Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
1104
+ the number of ensemble members for a given prediction of size `(H x W)`.
1105
+
1106
+ Args:
1107
+ normals (`torch.Tensor`):
1108
+ Input ensemble normals maps.
1109
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
1110
+ Whether to output uncertainty map.
1111
+ reduction (`str`, *optional*, defaults to `"closest"`):
1112
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
1113
+ `"mean"`.
1114
+
1115
+ Returns:
1116
+ A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
1117
+ uncertainties of shape `(1, 1, H, W)`.
1118
+ """
1119
+ if normals.dim() != 4 or normals.shape[1] != 3:
1120
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
1121
+ if reduction not in ("closest", "mean"):
1122
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
1123
+
1124
+ mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
1125
+ mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W]
1126
+
1127
+ sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
1128
+ sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
1129
+
1130
+ uncertainty = None
1131
+ if output_uncertainty:
1132
+ uncertainty = sim_cos.arccos() # [E,1,H,W]
1133
+ uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
1134
+
1135
+ if reduction == "mean":
1136
+ return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
1137
+
1138
+ closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
1139
+ closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
1140
+ closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
1141
+
1142
+ return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
1143
+
1144
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
1145
+ def retrieve_timesteps(
1146
+ scheduler,
1147
+ num_inference_steps: Optional[int] = None,
1148
+ device: Optional[Union[str, torch.device]] = None,
1149
+ timesteps: Optional[List[int]] = None,
1150
+ sigmas: Optional[List[float]] = None,
1151
+ **kwargs,
1152
+ ):
1153
+ """
1154
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
1155
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
1156
+
1157
+ Args:
1158
+ scheduler (`SchedulerMixin`):
1159
+ The scheduler to get timesteps from.
1160
+ num_inference_steps (`int`):
1161
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
1162
+ must be `None`.
1163
+ device (`str` or `torch.device`, *optional*):
1164
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
1165
+ timesteps (`List[int]`, *optional*):
1166
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
1167
+ `num_inference_steps` and `sigmas` must be `None`.
1168
+ sigmas (`List[float]`, *optional*):
1169
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
1170
+ `num_inference_steps` and `timesteps` must be `None`.
1171
+
1172
+ Returns:
1173
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
1174
+ second element is the number of inference steps.
1175
+ """
1176
+ if timesteps is not None and sigmas is not None:
1177
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
1178
+ if timesteps is not None:
1179
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1180
+ if not accepts_timesteps:
1181
+ raise ValueError(
1182
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1183
+ f" timestep schedules. Please check whether you are using the correct scheduler."
1184
+ )
1185
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
1186
+ timesteps = scheduler.timesteps
1187
+ num_inference_steps = len(timesteps)
1188
+ elif sigmas is not None:
1189
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1190
+ if not accept_sigmas:
1191
+ raise ValueError(
1192
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1193
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
1194
+ )
1195
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
1196
+ timesteps = scheduler.timesteps
1197
+ num_inference_steps = len(timesteps)
1198
+ else:
1199
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
1200
+ timesteps = scheduler.timesteps
1201
+ return timesteps, num_inference_steps
stablenormal/pipeline_yoso_normal.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the
17
+ # --------------------------------------------------------------------------
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
+
27
+
28
+ from diffusers.image_processor import PipelineImageInput
29
+ from diffusers.models import (
30
+ AutoencoderKL,
31
+ UNet2DConditionModel,
32
+ ControlNetModel,
33
+ )
34
+ from diffusers.schedulers import (
35
+ DDIMScheduler
36
+ )
37
+
38
+ from diffusers.utils import (
39
+ BaseOutput,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+
44
+
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
49
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
50
+
51
+ import pdb
52
+
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ EXAMPLE_DOC_STRING = """
59
+ Examples:
60
+ ```py
61
+ >>> import diffusers
62
+ >>> import torch
63
+
64
+ >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
65
+ ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
66
+ ... ).to("cuda")
67
+
68
+ >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
69
+ >>> normals = pipe(image)
70
+
71
+ >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
72
+ >>> vis[0].save("einstein_normals.png")
73
+ ```
74
+ """
75
+
76
+
77
+ @dataclass
78
+ class YosoNormalsOutput(BaseOutput):
79
+ """
80
+ Output class for Marigold monocular normals prediction pipeline.
81
+
82
+ Args:
83
+ prediction (`np.ndarray`, `torch.Tensor`):
84
+ Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
85
+ \times width$, regardless of whether the images were passed as a 4D array or a list.
86
+ uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
87
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
88
+ \times 1 \times height \times width$.
89
+ latent (`None`, `torch.Tensor`):
90
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
91
+ The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
92
+ """
93
+
94
+ prediction: Union[np.ndarray, torch.Tensor]
95
+ latent: Union[None, torch.Tensor]
96
+ gauss_latent: Union[None, torch.Tensor]
97
+
98
+
99
+ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
100
+ """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
101
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
102
+
103
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
104
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
105
+
106
+ The pipeline also inherits the following loading methods:
107
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
108
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
109
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
110
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
111
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
112
+
113
+ Args:
114
+ vae ([`AutoencoderKL`]):
115
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
116
+ text_encoder ([`~transformers.CLIPTextModel`]):
117
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
118
+ tokenizer ([`~transformers.CLIPTokenizer`]):
119
+ A `CLIPTokenizer` to tokenize text.
120
+ unet ([`UNet2DConditionModel`]):
121
+ A `UNet2DConditionModel` to denoise the encoded image latents.
122
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
123
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
124
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
125
+ additional conditioning.
126
+ scheduler ([`SchedulerMixin`]):
127
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
128
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
129
+ safety_checker ([`StableDiffusionSafetyChecker`]):
130
+ Classification module that estimates whether generated images could be considered offensive or harmful.
131
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
132
+ about a model's potential harms.
133
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
134
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
135
+ """
136
+
137
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
138
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
139
+ _exclude_from_cpu_offload = ["safety_checker"]
140
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
141
+
142
+
143
+
144
+ def __init__(
145
+ self,
146
+ vae: AutoencoderKL,
147
+ text_encoder: CLIPTextModel,
148
+ tokenizer: CLIPTokenizer,
149
+ unet: UNet2DConditionModel,
150
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
151
+ scheduler: Union[DDIMScheduler],
152
+ safety_checker: StableDiffusionSafetyChecker,
153
+ feature_extractor: CLIPImageProcessor,
154
+ image_encoder: CLIPVisionModelWithProjection = None,
155
+ requires_safety_checker: bool = True,
156
+ default_denoising_steps: Optional[int] = 1,
157
+ default_processing_resolution: Optional[int] = 768,
158
+ prompt="",
159
+ empty_text_embedding=None,
160
+ t_start: Optional[int] = 401,
161
+ ):
162
+ super().__init__(
163
+ vae,
164
+ text_encoder,
165
+ tokenizer,
166
+ unet,
167
+ controlnet,
168
+ scheduler,
169
+ safety_checker,
170
+ feature_extractor,
171
+ image_encoder,
172
+ requires_safety_checker,
173
+ )
174
+
175
+ # TODO yoso ImageProcessor
176
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
177
+ self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
178
+ self.default_denoising_steps = default_denoising_steps
179
+ self.default_processing_resolution = default_processing_resolution
180
+ self.prompt = prompt
181
+ self.prompt_embeds = None
182
+ self.empty_text_embedding = empty_text_embedding
183
+ self.t_start= t_start # target_out latents
184
+ self.gauss_latent = None
185
+
186
+ def check_inputs(
187
+ self,
188
+ image: PipelineImageInput,
189
+ num_inference_steps: int,
190
+ ensemble_size: int,
191
+ processing_resolution: int,
192
+ resample_method_input: str,
193
+ resample_method_output: str,
194
+ batch_size: int,
195
+ ensembling_kwargs: Optional[Dict[str, Any]],
196
+ latents: Optional[torch.Tensor],
197
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
198
+ output_type: str,
199
+ output_uncertainty: bool,
200
+ ) -> int:
201
+ if num_inference_steps is None:
202
+ raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
203
+ if num_inference_steps < 1:
204
+ raise ValueError("`num_inference_steps` must be positive.")
205
+ if ensemble_size < 1:
206
+ raise ValueError("`ensemble_size` must be positive.")
207
+ if ensemble_size == 2:
208
+ logger.warning(
209
+ "`ensemble_size` == 2 results are similar to no ensembling (1); "
210
+ "consider increasing the value to at least 3."
211
+ )
212
+ if ensemble_size == 1 and output_uncertainty:
213
+ raise ValueError(
214
+ "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
215
+ "greater than 1."
216
+ )
217
+ if processing_resolution is None:
218
+ raise ValueError(
219
+ "`processing_resolution` is not specified and could not be resolved from the model config."
220
+ )
221
+ if processing_resolution < 0:
222
+ raise ValueError(
223
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
224
+ "downsampled processing."
225
+ )
226
+ if processing_resolution % self.vae_scale_factor != 0:
227
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
228
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
229
+ raise ValueError(
230
+ "`resample_method_input` takes string values compatible with PIL library: "
231
+ "nearest, nearest-exact, bilinear, bicubic, area."
232
+ )
233
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
234
+ raise ValueError(
235
+ "`resample_method_output` takes string values compatible with PIL library: "
236
+ "nearest, nearest-exact, bilinear, bicubic, area."
237
+ )
238
+ if batch_size < 1:
239
+ raise ValueError("`batch_size` must be positive.")
240
+ if output_type not in ["pt", "np"]:
241
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
242
+ if latents is not None and generator is not None:
243
+ raise ValueError("`latents` and `generator` cannot be used together.")
244
+ if ensembling_kwargs is not None:
245
+ if not isinstance(ensembling_kwargs, dict):
246
+ raise ValueError("`ensembling_kwargs` must be a dictionary.")
247
+ if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
248
+ raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
249
+
250
+ # image checks
251
+ num_images = 0
252
+ W, H = None, None
253
+ if not isinstance(image, list):
254
+ image = [image]
255
+ for i, img in enumerate(image):
256
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
257
+ if img.ndim not in (2, 3, 4):
258
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
259
+ H_i, W_i = img.shape[-2:]
260
+ N_i = 1
261
+ if img.ndim == 4:
262
+ N_i = img.shape[0]
263
+ elif isinstance(img, Image.Image):
264
+ W_i, H_i = img.size
265
+ N_i = 1
266
+ else:
267
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
268
+ if W is None:
269
+ W, H = W_i, H_i
270
+ elif (W, H) != (W_i, H_i):
271
+ raise ValueError(
272
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
273
+ )
274
+ num_images += N_i
275
+
276
+ # latents checks
277
+ if latents is not None:
278
+ if not torch.is_tensor(latents):
279
+ raise ValueError("`latents` must be a torch.Tensor.")
280
+ if latents.dim() != 4:
281
+ raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
282
+
283
+ if processing_resolution > 0:
284
+ max_orig = max(H, W)
285
+ new_H = H * processing_resolution // max_orig
286
+ new_W = W * processing_resolution // max_orig
287
+ if new_H == 0 or new_W == 0:
288
+ raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
289
+ W, H = new_W, new_H
290
+ w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
291
+ h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
292
+ shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
293
+
294
+ if latents.shape != shape_expected:
295
+ raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
296
+
297
+ # generator checks
298
+ if generator is not None:
299
+ if isinstance(generator, list):
300
+ if len(generator) != num_images * ensemble_size:
301
+ raise ValueError(
302
+ "The number of generators must match the total number of ensemble members for all input images."
303
+ )
304
+ if not all(g.device.type == generator[0].device.type for g in generator):
305
+ raise ValueError("`generator` device placement is not consistent in the list.")
306
+ elif not isinstance(generator, torch.Generator):
307
+ raise ValueError(f"Unsupported generator type: {type(generator)}.")
308
+
309
+ return num_images
310
+
311
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
312
+ if not hasattr(self, "_progress_bar_config"):
313
+ self._progress_bar_config = {}
314
+ elif not isinstance(self._progress_bar_config, dict):
315
+ raise ValueError(
316
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
317
+ )
318
+
319
+ progress_bar_config = dict(**self._progress_bar_config)
320
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
321
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
322
+ if iterable is not None:
323
+ return tqdm(iterable, **progress_bar_config)
324
+ elif total is not None:
325
+ return tqdm(total=total, **progress_bar_config)
326
+ else:
327
+ raise ValueError("Either `total` or `iterable` has to be defined.")
328
+
329
+ @torch.no_grad()
330
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
331
+ def __call__(
332
+ self,
333
+ image: PipelineImageInput,
334
+ prompt: Union[str, List[str]] = None,
335
+ negative_prompt: Optional[Union[str, List[str]]] = None,
336
+ num_inference_steps: Optional[int] = None,
337
+ ensemble_size: int = 1,
338
+ processing_resolution: Optional[int] = None,
339
+ match_input_resolution: bool = True,
340
+ resample_method_input: str = "bilinear",
341
+ resample_method_output: str = "bilinear",
342
+ batch_size: int = 1,
343
+ ensembling_kwargs: Optional[Dict[str, Any]] = None,
344
+ latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
345
+ prompt_embeds: Optional[torch.Tensor] = None,
346
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
347
+ num_images_per_prompt: Optional[int] = 1,
348
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
349
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
350
+ output_type: str = "np",
351
+ output_uncertainty: bool = False,
352
+ output_latent: bool = False,
353
+ skip_preprocess: bool = False,
354
+ return_dict: bool = True,
355
+ **kwargs,
356
+ ):
357
+ """
358
+ Function invoked when calling the pipeline.
359
+
360
+ Args:
361
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
362
+ `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
363
+ arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
364
+ by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
365
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
366
+ same width and height.
367
+ num_inference_steps (`int`, *optional*, defaults to `None`):
368
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
369
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
370
+ for Marigold-LCM models.
371
+ ensemble_size (`int`, defaults to `1`):
372
+ Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
373
+ faster inference.
374
+ processing_resolution (`int`, *optional*, defaults to `None`):
375
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
376
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
377
+ value `None` resolves to the optimal value from the model config.
378
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
379
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
380
+ side of the output will equal to `processing_resolution`.
381
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
382
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
383
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
384
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
385
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
386
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
387
+ batch_size (`int`, *optional*, defaults to `1`):
388
+ Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
389
+ ensembling_kwargs (`dict`, *optional*, defaults to `None`)
390
+ Extra dictionary with arguments for precise ensembling control. The following options are available:
391
+ - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
392
+ every pixel location, can be either `"closest"` or `"mean"`.
393
+ latents (`torch.Tensor`, *optional*, defaults to `None`):
394
+ Latent noise tensors to replace the random initialization. These can be taken from the previous
395
+ function call's output.
396
+ generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
397
+ Random number generator object to ensure reproducibility.
398
+ output_type (`str`, *optional*, defaults to `"np"`):
399
+ Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
400
+ values are: `"np"` (numpy array) or `"pt"` (torch tensor).
401
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
402
+ When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
403
+ the `ensemble_size` argument is set to a value above 2.
404
+ output_latent (`bool`, *optional*, defaults to `False`):
405
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
406
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
407
+ `latents` argument.
408
+ return_dict (`bool`, *optional*, defaults to `True`):
409
+ Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
410
+
411
+ Examples:
412
+
413
+ Returns:
414
+ [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
415
+ If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
416
+ `tuple` is returned where the first element is the prediction, the second element is the uncertainty
417
+ (or `None`), and the third is the latent (or `None`).
418
+ """
419
+
420
+ # 0. Resolving variables.
421
+ device = self._execution_device
422
+ dtype = self.dtype
423
+
424
+ # Model-specific optimal default values leading to fast and reasonable results.
425
+ if num_inference_steps is None:
426
+ num_inference_steps = self.default_denoising_steps
427
+ if processing_resolution is None:
428
+ processing_resolution = self.default_processing_resolution
429
+
430
+ # 1. Check inputs.
431
+ num_images = self.check_inputs(
432
+ image,
433
+ num_inference_steps,
434
+ ensemble_size,
435
+ processing_resolution,
436
+ resample_method_input,
437
+ resample_method_output,
438
+ batch_size,
439
+ ensembling_kwargs,
440
+ latents,
441
+ generator,
442
+ output_type,
443
+ output_uncertainty,
444
+ )
445
+
446
+
447
+ # 2. Prepare empty text conditioning.
448
+ # Model invocation: self.tokenizer, self.text_encoder.
449
+ if self.empty_text_embedding is None:
450
+ prompt = ""
451
+ text_inputs = self.tokenizer(
452
+ prompt,
453
+ padding="do_not_pad",
454
+ max_length=self.tokenizer.model_max_length,
455
+ truncation=True,
456
+ return_tensors="pt",
457
+ )
458
+ text_input_ids = text_inputs.input_ids.to(device)
459
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
460
+
461
+
462
+
463
+ # 3. prepare prompt
464
+ if self.prompt_embeds is None:
465
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
466
+ self.prompt,
467
+ device,
468
+ num_images_per_prompt,
469
+ False,
470
+ negative_prompt,
471
+ prompt_embeds=prompt_embeds,
472
+ negative_prompt_embeds=None,
473
+ lora_scale=None,
474
+ clip_skip=None,
475
+ )
476
+ self.prompt_embeds = prompt_embeds
477
+ self.negative_prompt_embeds = negative_prompt_embeds
478
+
479
+
480
+
481
+ # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
482
+ # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
483
+ # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
484
+ # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
485
+ # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
486
+ # operation and leads to the most reasonable results. Using the native image resolution or any other processing
487
+ # resolution can lead to loss of either fine details or global context in the output predictions.
488
+ if not skip_preprocess:
489
+ image, padding, original_resolution = self.image_processor.preprocess(
490
+ image, processing_resolution, resample_method_input, device, dtype
491
+ ) # [N,3,PPH,PPW]
492
+ else:
493
+ padding = (0, 0)
494
+ original_resolution = image.shape[2:]
495
+ # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
496
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
497
+ # Latents of each such predictions across all input images and all ensemble members are represented in the
498
+ # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
499
+ # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
500
+ # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
501
+ # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
502
+ # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
503
+ # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
504
+ # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
505
+ # Model invocation: self.vae.encoder.
506
+ image_latent, gauss_latent = self.prepare_latents(
507
+ image, latents, generator, ensemble_size, batch_size
508
+ ) # [N*E,4,h,w], [N*E,4,h,w]
509
+
510
+ del image
511
+
512
+
513
+ # 6. obtain control_output
514
+
515
+ cond_scale =controlnet_conditioning_scale
516
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
517
+ image_latent.detach(),
518
+ self.t_start,
519
+ encoder_hidden_states=self.prompt_embeds,
520
+ conditioning_scale=cond_scale,
521
+ guess_mode=False,
522
+ return_dict=False,
523
+ )
524
+
525
+
526
+ # 7. YOSO sampling
527
+ latent_x_t = self.unet(
528
+ gauss_latent,
529
+ self.t_start,
530
+ encoder_hidden_states=self.prompt_embeds,
531
+ down_block_additional_residuals=down_block_res_samples,
532
+ mid_block_additional_residual=mid_block_res_sample,
533
+ return_dict=False,
534
+ )[0]
535
+
536
+
537
+ del (
538
+ image_latent,
539
+ )
540
+
541
+ # decoder
542
+ prediction = self.decode_prediction(latent_x_t)
543
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
544
+
545
+ prediction = self.image_processor.resize_antialias(
546
+ prediction, original_resolution, resample_method_output, is_aa=False
547
+ ) # [N,3,H,W]
548
+ prediction = self.normalize_normals(prediction) # [N,3,H,W]
549
+
550
+ if output_type == "np":
551
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
552
+
553
+ # 11. Offload all models
554
+ self.maybe_free_model_hooks()
555
+
556
+ return YosoNormalsOutput(
557
+ prediction=prediction,
558
+ latent=latent_x_t,
559
+ gauss_latent=gauss_latent,
560
+ )
561
+
562
+ # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
563
+ def prepare_latents(
564
+ self,
565
+ image: torch.Tensor,
566
+ latents: Optional[torch.Tensor],
567
+ generator: Optional[torch.Generator],
568
+ ensemble_size: int,
569
+ batch_size: int,
570
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
571
+ def retrieve_latents(encoder_output):
572
+ if hasattr(encoder_output, "latent_dist"):
573
+ return encoder_output.latent_dist.mode()
574
+ elif hasattr(encoder_output, "latents"):
575
+ return encoder_output.latents
576
+ else:
577
+ raise AttributeError("Could not access latents of provided encoder_output")
578
+
579
+
580
+
581
+ image_latent = torch.cat(
582
+ [
583
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
584
+ for i in range(0, image.shape[0], batch_size)
585
+ ],
586
+ dim=0,
587
+ ) # [N,4,h,w]
588
+ image_latent = image_latent * self.vae.config.scaling_factor
589
+ image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
590
+
591
+ pred_latent = self.gauss_latent
592
+ if pred_latent is None:
593
+ self.gauss_latent = torch.randn_like(image_latent)
594
+ pred_latent = self.gauss_latent
595
+
596
+ return image_latent, pred_latent
597
+
598
+ def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
599
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
600
+ raise ValueError(
601
+ f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
602
+ )
603
+
604
+ prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
605
+
606
+ prediction = self.normalize_normals(prediction) # [B,3,H,W]
607
+
608
+ return prediction # [B,3,H,W]
609
+
610
+ @staticmethod
611
+ def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
612
+ if normals.dim() != 4 or normals.shape[1] != 3:
613
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
614
+
615
+ norm = torch.norm(normals, dim=1, keepdim=True)
616
+ normals /= norm.clamp(min=eps)
617
+
618
+ return normals
619
+
620
+ @staticmethod
621
+ def ensemble_normals(
622
+ normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
623
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
624
+ """
625
+ Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
626
+ the number of ensemble members for a given prediction of size `(H x W)`.
627
+
628
+ Args:
629
+ normals (`torch.Tensor`):
630
+ Input ensemble normals maps.
631
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
632
+ Whether to output uncertainty map.
633
+ reduction (`str`, *optional*, defaults to `"closest"`):
634
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
635
+ `"mean"`.
636
+
637
+ Returns:
638
+ A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
639
+ uncertainties of shape `(1, 1, H, W)`.
640
+ """
641
+ if normals.dim() != 4 or normals.shape[1] != 3:
642
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
643
+ if reduction not in ("closest", "mean"):
644
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
645
+
646
+ mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
647
+ mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W]
648
+
649
+ sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
650
+ sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
651
+
652
+ uncertainty = None
653
+ if output_uncertainty:
654
+ uncertainty = sim_cos.arccos() # [E,1,H,W]
655
+ uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
656
+
657
+ if reduction == "mean":
658
+ return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
659
+
660
+ closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
661
+ closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
662
+ closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
663
+
664
+ return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
665
+
666
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
667
+ def retrieve_timesteps(
668
+ scheduler,
669
+ num_inference_steps: Optional[int] = None,
670
+ device: Optional[Union[str, torch.device]] = None,
671
+ timesteps: Optional[List[int]] = None,
672
+ sigmas: Optional[List[float]] = None,
673
+ **kwargs,
674
+ ):
675
+ """
676
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
677
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
678
+
679
+ Args:
680
+ scheduler (`SchedulerMixin`):
681
+ The scheduler to get timesteps from.
682
+ num_inference_steps (`int`):
683
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
684
+ must be `None`.
685
+ device (`str` or `torch.device`, *optional*):
686
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
687
+ timesteps (`List[int]`, *optional*):
688
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
689
+ `num_inference_steps` and `sigmas` must be `None`.
690
+ sigmas (`List[float]`, *optional*):
691
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
692
+ `num_inference_steps` and `timesteps` must be `None`.
693
+
694
+ Returns:
695
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
696
+ second element is the number of inference steps.
697
+ """
698
+ if timesteps is not None and sigmas is not None:
699
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
700
+ if timesteps is not None:
701
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
702
+ if not accepts_timesteps:
703
+ raise ValueError(
704
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
705
+ f" timestep schedules. Please check whether you are using the correct scheduler."
706
+ )
707
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
708
+ timesteps = scheduler.timesteps
709
+ num_inference_steps = len(timesteps)
710
+ elif sigmas is not None:
711
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
712
+ if not accept_sigmas:
713
+ raise ValueError(
714
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
715
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
716
+ )
717
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
718
+ timesteps = scheduler.timesteps
719
+ num_inference_steps = len(timesteps)
720
+ else:
721
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
722
+ timesteps = scheduler.timesteps
723
+ return timesteps, num_inference_steps
stablenormal/scheduler/__init__.py ADDED
File without changes
stablenormal/stablecontrolnet.py ADDED
@@ -0,0 +1,1354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
27
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
28
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
29
+ from ...models.lora import adjust_lora_scale_text_encoder
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+ from ...utils import (
32
+ USE_PEFT_BACKEND,
33
+ deprecate,
34
+ logging,
35
+ replace_example_docstring,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
40
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
+ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
42
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
+ from .multicontrolnet import MultiControlNetModel
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ EXAMPLE_DOC_STRING = """
50
+ Examples:
51
+ ```py
52
+ >>> # !pip install opencv-python transformers accelerate
53
+ >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
54
+ >>> from diffusers.utils import load_image
55
+ >>> import numpy as np
56
+ >>> import torch
57
+
58
+ >>> import cv2
59
+ >>> from PIL import Image
60
+
61
+ >>> # download an image
62
+ >>> image = load_image(
63
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
64
+ ... )
65
+ >>> image = np.array(image)
66
+
67
+ >>> # get canny image
68
+ >>> image = cv2.Canny(image, 100, 200)
69
+ >>> image = image[:, :, None]
70
+ >>> image = np.concatenate([image, image, image], axis=2)
71
+ >>> canny_image = Image.fromarray(image)
72
+
73
+ >>> # load control net and stable diffusion v1-5
74
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
75
+ >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
76
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
77
+ ... )
78
+
79
+ >>> # speed up diffusion process with faster scheduler and memory optimization
80
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
81
+ >>> # remove following line if xformers is not installed
82
+ >>> pipe.enable_xformers_memory_efficient_attention()
83
+
84
+ >>> pipe.enable_model_cpu_offload()
85
+
86
+ >>> # generate image
87
+ >>> generator = torch.manual_seed(0)
88
+ >>> image = pipe(
89
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
90
+ ... ).images[0]
91
+ ```
92
+ """
93
+
94
+
95
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
96
+ def retrieve_timesteps(
97
+ scheduler,
98
+ num_inference_steps: Optional[int] = None,
99
+ device: Optional[Union[str, torch.device]] = None,
100
+ timesteps: Optional[List[int]] = None,
101
+ sigmas: Optional[List[float]] = None,
102
+ **kwargs,
103
+ ):
104
+ """
105
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
106
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107
+
108
+ Args:
109
+ scheduler (`SchedulerMixin`):
110
+ The scheduler to get timesteps from.
111
+ num_inference_steps (`int`):
112
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
113
+ must be `None`.
114
+ device (`str` or `torch.device`, *optional*):
115
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
116
+ timesteps (`List[int]`, *optional*):
117
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
118
+ `num_inference_steps` and `sigmas` must be `None`.
119
+ sigmas (`List[float]`, *optional*):
120
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
121
+ `num_inference_steps` and `timesteps` must be `None`.
122
+
123
+ Returns:
124
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
125
+ second element is the number of inference steps.
126
+ """
127
+ if timesteps is not None and sigmas is not None:
128
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
129
+ if timesteps is not None:
130
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
131
+ if not accepts_timesteps:
132
+ raise ValueError(
133
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
134
+ f" timestep schedules. Please check whether you are using the correct scheduler."
135
+ )
136
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
137
+ timesteps = scheduler.timesteps
138
+ num_inference_steps = len(timesteps)
139
+ elif sigmas is not None:
140
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
141
+ if not accept_sigmas:
142
+ raise ValueError(
143
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
144
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
145
+ )
146
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ num_inference_steps = len(timesteps)
149
+ else:
150
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ return timesteps, num_inference_steps
153
+
154
+
155
+ class StableDiffusionControlNetPipeline(
156
+ DiffusionPipeline,
157
+ StableDiffusionMixin,
158
+ TextualInversionLoaderMixin,
159
+ LoraLoaderMixin,
160
+ IPAdapterMixin,
161
+ FromSingleFileMixin,
162
+ ):
163
+ r"""
164
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
165
+
166
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
167
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
168
+
169
+ The pipeline also inherits the following loading methods:
170
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
171
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
172
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
173
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
174
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
175
+
176
+ Args:
177
+ vae ([`AutoencoderKL`]):
178
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
179
+ text_encoder ([`~transformers.CLIPTextModel`]):
180
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
181
+ tokenizer ([`~transformers.CLIPTokenizer`]):
182
+ A `CLIPTokenizer` to tokenize text.
183
+ unet ([`UNet2DConditionModel`]):
184
+ A `UNet2DConditionModel` to denoise the encoded image latents.
185
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
186
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
187
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
188
+ additional conditioning.
189
+ scheduler ([`SchedulerMixin`]):
190
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
191
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
192
+ safety_checker ([`StableDiffusionSafetyChecker`]):
193
+ Classification module that estimates whether generated images could be considered offensive or harmful.
194
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
195
+ about a model's potential harms.
196
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
197
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
198
+ """
199
+
200
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
201
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
202
+ _exclude_from_cpu_offload = ["safety_checker"]
203
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
204
+
205
+ def __init__(
206
+ self,
207
+ vae: AutoencoderKL,
208
+ text_encoder: CLIPTextModel,
209
+ tokenizer: CLIPTokenizer,
210
+ unet: UNet2DConditionModel,
211
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
212
+ scheduler: KarrasDiffusionSchedulers,
213
+ safety_checker: StableDiffusionSafetyChecker,
214
+ feature_extractor: CLIPImageProcessor,
215
+ image_encoder: CLIPVisionModelWithProjection = None,
216
+ requires_safety_checker: bool = True,
217
+ ):
218
+ super().__init__()
219
+
220
+ if safety_checker is None and requires_safety_checker:
221
+ logger.warning(
222
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
223
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
224
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
225
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
226
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
227
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
228
+ )
229
+
230
+ if safety_checker is not None and feature_extractor is None:
231
+ raise ValueError(
232
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
233
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
234
+ )
235
+
236
+ if isinstance(controlnet, (list, tuple)):
237
+ controlnet = MultiControlNetModel(controlnet)
238
+
239
+ self.register_modules(
240
+ vae=vae,
241
+ text_encoder=text_encoder,
242
+ tokenizer=tokenizer,
243
+ unet=unet,
244
+ controlnet=controlnet,
245
+ scheduler=scheduler,
246
+ safety_checker=safety_checker,
247
+ feature_extractor=feature_extractor,
248
+ image_encoder=image_encoder,
249
+ )
250
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
251
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
252
+ self.control_image_processor = VaeImageProcessor(
253
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
254
+ )
255
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
256
+
257
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
258
+ def _encode_prompt(
259
+ self,
260
+ prompt,
261
+ device,
262
+ num_images_per_prompt,
263
+ do_classifier_free_guidance,
264
+ negative_prompt=None,
265
+ prompt_embeds: Optional[torch.Tensor] = None,
266
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
267
+ lora_scale: Optional[float] = None,
268
+ **kwargs,
269
+ ):
270
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
271
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
272
+
273
+ prompt_embeds_tuple = self.encode_prompt(
274
+ prompt=prompt,
275
+ device=device,
276
+ num_images_per_prompt=num_images_per_prompt,
277
+ do_classifier_free_guidance=do_classifier_free_guidance,
278
+ negative_prompt=negative_prompt,
279
+ prompt_embeds=prompt_embeds,
280
+ negative_prompt_embeds=negative_prompt_embeds,
281
+ lora_scale=lora_scale,
282
+ **kwargs,
283
+ )
284
+
285
+ # concatenate for backwards comp
286
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
287
+
288
+ return prompt_embeds
289
+
290
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
291
+ def encode_prompt(
292
+ self,
293
+ prompt,
294
+ device,
295
+ num_images_per_prompt,
296
+ do_classifier_free_guidance,
297
+ negative_prompt=None,
298
+ prompt_embeds: Optional[torch.Tensor] = None,
299
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
300
+ lora_scale: Optional[float] = None,
301
+ clip_skip: Optional[int] = None,
302
+ ):
303
+ r"""
304
+ Encodes the prompt into text encoder hidden states.
305
+
306
+ Args:
307
+ prompt (`str` or `List[str]`, *optional*):
308
+ prompt to be encoded
309
+ device: (`torch.device`):
310
+ torch device
311
+ num_images_per_prompt (`int`):
312
+ number of images that should be generated per prompt
313
+ do_classifier_free_guidance (`bool`):
314
+ whether to use classifier free guidance or not
315
+ negative_prompt (`str` or `List[str]`, *optional*):
316
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
317
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
318
+ less than `1`).
319
+ prompt_embeds (`torch.Tensor`, *optional*):
320
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
321
+ provided, text embeddings will be generated from `prompt` input argument.
322
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
323
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
324
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
325
+ argument.
326
+ lora_scale (`float`, *optional*):
327
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
328
+ clip_skip (`int`, *optional*):
329
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
330
+ the output of the pre-final layer will be used for computing the prompt embeddings.
331
+ """
332
+ # set lora scale so that monkey patched LoRA
333
+ # function of text encoder can correctly access it
334
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
335
+ self._lora_scale = lora_scale
336
+
337
+ # dynamically adjust the LoRA scale
338
+ if not USE_PEFT_BACKEND:
339
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
340
+ else:
341
+ scale_lora_layers(self.text_encoder, lora_scale)
342
+
343
+ if prompt is not None and isinstance(prompt, str):
344
+ batch_size = 1
345
+ elif prompt is not None and isinstance(prompt, list):
346
+ batch_size = len(prompt)
347
+ else:
348
+ batch_size = prompt_embeds.shape[0]
349
+
350
+ if prompt_embeds is None:
351
+ # textual inversion: process multi-vector tokens if necessary
352
+ if isinstance(self, TextualInversionLoaderMixin):
353
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
354
+
355
+ text_inputs = self.tokenizer(
356
+ prompt,
357
+ padding="max_length",
358
+ max_length=self.tokenizer.model_max_length,
359
+ truncation=True,
360
+ return_tensors="pt",
361
+ )
362
+ text_input_ids = text_inputs.input_ids
363
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
364
+
365
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
366
+ text_input_ids, untruncated_ids
367
+ ):
368
+ removed_text = self.tokenizer.batch_decode(
369
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
370
+ )
371
+ logger.warning(
372
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
373
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
374
+ )
375
+
376
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
377
+ attention_mask = text_inputs.attention_mask.to(device)
378
+ else:
379
+ attention_mask = None
380
+
381
+ if clip_skip is None:
382
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
383
+ prompt_embeds = prompt_embeds[0]
384
+ else:
385
+ prompt_embeds = self.text_encoder(
386
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
387
+ )
388
+ # Access the `hidden_states` first, that contains a tuple of
389
+ # all the hidden states from the encoder layers. Then index into
390
+ # the tuple to access the hidden states from the desired layer.
391
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
392
+ # We also need to apply the final LayerNorm here to not mess with the
393
+ # representations. The `last_hidden_states` that we typically use for
394
+ # obtaining the final prompt representations passes through the LayerNorm
395
+ # layer.
396
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
397
+
398
+ if self.text_encoder is not None:
399
+ prompt_embeds_dtype = self.text_encoder.dtype
400
+ elif self.unet is not None:
401
+ prompt_embeds_dtype = self.unet.dtype
402
+ else:
403
+ prompt_embeds_dtype = prompt_embeds.dtype
404
+
405
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
406
+
407
+ bs_embed, seq_len, _ = prompt_embeds.shape
408
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
409
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
410
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
411
+
412
+ # get unconditional embeddings for classifier free guidance
413
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
414
+ uncond_tokens: List[str]
415
+ if negative_prompt is None:
416
+ uncond_tokens = [""] * batch_size
417
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
418
+ raise TypeError(
419
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
420
+ f" {type(prompt)}."
421
+ )
422
+ elif isinstance(negative_prompt, str):
423
+ uncond_tokens = [negative_prompt]
424
+ elif batch_size != len(negative_prompt):
425
+ raise ValueError(
426
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
427
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
428
+ " the batch size of `prompt`."
429
+ )
430
+ else:
431
+ uncond_tokens = negative_prompt
432
+
433
+ # textual inversion: process multi-vector tokens if necessary
434
+ if isinstance(self, TextualInversionLoaderMixin):
435
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
436
+
437
+ max_length = prompt_embeds.shape[1]
438
+ uncond_input = self.tokenizer(
439
+ uncond_tokens,
440
+ padding="max_length",
441
+ max_length=max_length,
442
+ truncation=True,
443
+ return_tensors="pt",
444
+ )
445
+
446
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
447
+ attention_mask = uncond_input.attention_mask.to(device)
448
+ else:
449
+ attention_mask = None
450
+
451
+ negative_prompt_embeds = self.text_encoder(
452
+ uncond_input.input_ids.to(device),
453
+ attention_mask=attention_mask,
454
+ )
455
+ negative_prompt_embeds = negative_prompt_embeds[0]
456
+
457
+ if do_classifier_free_guidance:
458
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
459
+ seq_len = negative_prompt_embeds.shape[1]
460
+
461
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
462
+
463
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
464
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
465
+
466
+ if self.text_encoder is not None:
467
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
468
+ # Retrieve the original scale by scaling back the LoRA layers
469
+ unscale_lora_layers(self.text_encoder, lora_scale)
470
+
471
+ return prompt_embeds, negative_prompt_embeds
472
+
473
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
474
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
475
+ dtype = next(self.image_encoder.parameters()).dtype
476
+
477
+ if not isinstance(image, torch.Tensor):
478
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
479
+
480
+ image = image.to(device=device, dtype=dtype)
481
+ if output_hidden_states:
482
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
483
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
484
+ uncond_image_enc_hidden_states = self.image_encoder(
485
+ torch.zeros_like(image), output_hidden_states=True
486
+ ).hidden_states[-2]
487
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
488
+ num_images_per_prompt, dim=0
489
+ )
490
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
491
+ else:
492
+ image_embeds = self.image_encoder(image).image_embeds
493
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
494
+ uncond_image_embeds = torch.zeros_like(image_embeds)
495
+
496
+ return image_embeds, uncond_image_embeds
497
+
498
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
499
+ def prepare_ip_adapter_image_embeds(
500
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
501
+ ):
502
+ if ip_adapter_image_embeds is None:
503
+ if not isinstance(ip_adapter_image, list):
504
+ ip_adapter_image = [ip_adapter_image]
505
+
506
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
507
+ raise ValueError(
508
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
509
+ )
510
+
511
+ image_embeds = []
512
+ for single_ip_adapter_image, image_proj_layer in zip(
513
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
514
+ ):
515
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
516
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
517
+ single_ip_adapter_image, device, 1, output_hidden_state
518
+ )
519
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
520
+ single_negative_image_embeds = torch.stack(
521
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
522
+ )
523
+
524
+ if do_classifier_free_guidance:
525
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
526
+ single_image_embeds = single_image_embeds.to(device)
527
+
528
+ image_embeds.append(single_image_embeds)
529
+ else:
530
+ repeat_dims = [1]
531
+ image_embeds = []
532
+ for single_image_embeds in ip_adapter_image_embeds:
533
+ if do_classifier_free_guidance:
534
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
535
+ single_image_embeds = single_image_embeds.repeat(
536
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
537
+ )
538
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
539
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
540
+ )
541
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
542
+ else:
543
+ single_image_embeds = single_image_embeds.repeat(
544
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
545
+ )
546
+ image_embeds.append(single_image_embeds)
547
+
548
+ return image_embeds
549
+
550
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
551
+ def run_safety_checker(self, image, device, dtype):
552
+ if self.safety_checker is None:
553
+ has_nsfw_concept = None
554
+ else:
555
+ if torch.is_tensor(image):
556
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
557
+ else:
558
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
559
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
560
+ image, has_nsfw_concept = self.safety_checker(
561
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
562
+ )
563
+ return image, has_nsfw_concept
564
+
565
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
566
+ def decode_latents(self, latents):
567
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
568
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
569
+
570
+ latents = 1 / self.vae.config.scaling_factor * latents
571
+ image = self.vae.decode(latents, return_dict=False)[0]
572
+ image = (image / 2 + 0.5).clamp(0, 1)
573
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
574
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
575
+ return image
576
+
577
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
578
+ def prepare_extra_step_kwargs(self, generator, eta):
579
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
580
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
581
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
582
+ # and should be between [0, 1]
583
+
584
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
585
+ extra_step_kwargs = {}
586
+ if accepts_eta:
587
+ extra_step_kwargs["eta"] = eta
588
+
589
+ # check if the scheduler accepts generator
590
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
591
+ if accepts_generator:
592
+ extra_step_kwargs["generator"] = generator
593
+ return extra_step_kwargs
594
+
595
+ def check_inputs(
596
+ self,
597
+ prompt,
598
+ image,
599
+ callback_steps,
600
+ negative_prompt=None,
601
+ prompt_embeds=None,
602
+ negative_prompt_embeds=None,
603
+ ip_adapter_image=None,
604
+ ip_adapter_image_embeds=None,
605
+ controlnet_conditioning_scale=1.0,
606
+ control_guidance_start=0.0,
607
+ control_guidance_end=1.0,
608
+ callback_on_step_end_tensor_inputs=None,
609
+ ):
610
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
611
+ raise ValueError(
612
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
613
+ f" {type(callback_steps)}."
614
+ )
615
+
616
+ if callback_on_step_end_tensor_inputs is not None and not all(
617
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
618
+ ):
619
+ raise ValueError(
620
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
621
+ )
622
+
623
+ if prompt is not None and prompt_embeds is not None:
624
+ raise ValueError(
625
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
626
+ " only forward one of the two."
627
+ )
628
+ elif prompt is None and prompt_embeds is None:
629
+ raise ValueError(
630
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
631
+ )
632
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
633
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
634
+
635
+ if negative_prompt is not None and negative_prompt_embeds is not None:
636
+ raise ValueError(
637
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
638
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
639
+ )
640
+
641
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
642
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
643
+ raise ValueError(
644
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
645
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
646
+ f" {negative_prompt_embeds.shape}."
647
+ )
648
+
649
+ # Check `image`
650
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
651
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
652
+ )
653
+ if (
654
+ isinstance(self.controlnet, ControlNetModel)
655
+ or is_compiled
656
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
657
+ ):
658
+ self.check_image(image, prompt, prompt_embeds)
659
+ elif (
660
+ isinstance(self.controlnet, MultiControlNetModel)
661
+ or is_compiled
662
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
663
+ ):
664
+ if not isinstance(image, list):
665
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
666
+
667
+ # When `image` is a nested list:
668
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
669
+ elif any(isinstance(i, list) for i in image):
670
+ transposed_image = [list(t) for t in zip(*image)]
671
+ if len(transposed_image) != len(self.controlnet.nets):
672
+ raise ValueError(
673
+ f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
674
+ )
675
+ for image_ in transposed_image:
676
+ self.check_image(image_, prompt, prompt_embeds)
677
+ elif len(image) != len(self.controlnet.nets):
678
+ raise ValueError(
679
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
680
+ )
681
+ else:
682
+ for image_ in image:
683
+ self.check_image(image_, prompt, prompt_embeds)
684
+ else:
685
+ assert False
686
+
687
+ # Check `controlnet_conditioning_scale`
688
+ if (
689
+ isinstance(self.controlnet, ControlNetModel)
690
+ or is_compiled
691
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
692
+ ):
693
+ if not isinstance(controlnet_conditioning_scale, float):
694
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
695
+ elif (
696
+ isinstance(self.controlnet, MultiControlNetModel)
697
+ or is_compiled
698
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
699
+ ):
700
+ if isinstance(controlnet_conditioning_scale, list):
701
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
702
+ raise ValueError(
703
+ "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
704
+ "The conditioning scale must be fixed across the batch."
705
+ )
706
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
707
+ self.controlnet.nets
708
+ ):
709
+ raise ValueError(
710
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
711
+ " the same length as the number of controlnets"
712
+ )
713
+ else:
714
+ assert False
715
+
716
+ if not isinstance(control_guidance_start, (tuple, list)):
717
+ control_guidance_start = [control_guidance_start]
718
+
719
+ if not isinstance(control_guidance_end, (tuple, list)):
720
+ control_guidance_end = [control_guidance_end]
721
+
722
+ if len(control_guidance_start) != len(control_guidance_end):
723
+ raise ValueError(
724
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
725
+ )
726
+
727
+ if isinstance(self.controlnet, MultiControlNetModel):
728
+ if len(control_guidance_start) != len(self.controlnet.nets):
729
+ raise ValueError(
730
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
731
+ )
732
+
733
+ for start, end in zip(control_guidance_start, control_guidance_end):
734
+ if start >= end:
735
+ raise ValueError(
736
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
737
+ )
738
+ if start < 0.0:
739
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
740
+ if end > 1.0:
741
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
742
+
743
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
744
+ raise ValueError(
745
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
746
+ )
747
+
748
+ if ip_adapter_image_embeds is not None:
749
+ if not isinstance(ip_adapter_image_embeds, list):
750
+ raise ValueError(
751
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
752
+ )
753
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
754
+ raise ValueError(
755
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
756
+ )
757
+
758
+ def check_image(self, image, prompt, prompt_embeds):
759
+ image_is_pil = isinstance(image, PIL.Image.Image)
760
+ image_is_tensor = isinstance(image, torch.Tensor)
761
+ image_is_np = isinstance(image, np.ndarray)
762
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
763
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
764
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
765
+
766
+ if (
767
+ not image_is_pil
768
+ and not image_is_tensor
769
+ and not image_is_np
770
+ and not image_is_pil_list
771
+ and not image_is_tensor_list
772
+ and not image_is_np_list
773
+ ):
774
+ raise TypeError(
775
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
776
+ )
777
+
778
+ if image_is_pil:
779
+ image_batch_size = 1
780
+ else:
781
+ image_batch_size = len(image)
782
+
783
+ if prompt is not None and isinstance(prompt, str):
784
+ prompt_batch_size = 1
785
+ elif prompt is not None and isinstance(prompt, list):
786
+ prompt_batch_size = len(prompt)
787
+ elif prompt_embeds is not None:
788
+ prompt_batch_size = prompt_embeds.shape[0]
789
+
790
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
791
+ raise ValueError(
792
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
793
+ )
794
+
795
+ def prepare_image(
796
+ self,
797
+ image,
798
+ width,
799
+ height,
800
+ batch_size,
801
+ num_images_per_prompt,
802
+ device,
803
+ dtype,
804
+ do_classifier_free_guidance=False,
805
+ guess_mode=False,
806
+ ):
807
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
808
+ image_batch_size = image.shape[0]
809
+
810
+ if image_batch_size == 1:
811
+ repeat_by = batch_size
812
+ else:
813
+ # image batch size is the same as prompt batch size
814
+ repeat_by = num_images_per_prompt
815
+
816
+ image = image.repeat_interleave(repeat_by, dim=0)
817
+
818
+ image = image.to(device=device, dtype=dtype)
819
+
820
+ if do_classifier_free_guidance and not guess_mode:
821
+ image = torch.cat([image] * 2)
822
+
823
+ return image
824
+
825
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
826
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
827
+ shape = (
828
+ batch_size,
829
+ num_channels_latents,
830
+ int(height) // self.vae_scale_factor,
831
+ int(width) // self.vae_scale_factor,
832
+ )
833
+ if isinstance(generator, list) and len(generator) != batch_size:
834
+ raise ValueError(
835
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
836
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
837
+ )
838
+
839
+ if latents is None:
840
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
841
+ else:
842
+ latents = latents.to(device)
843
+
844
+ # scale the initial noise by the standard deviation required by the scheduler
845
+ latents = latents * self.scheduler.init_noise_sigma
846
+ return latents
847
+
848
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
849
+ def get_guidance_scale_embedding(
850
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
851
+ ) -> torch.Tensor:
852
+ """
853
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
854
+
855
+ Args:
856
+ w (`torch.Tensor`):
857
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
858
+ embedding_dim (`int`, *optional*, defaults to 512):
859
+ Dimension of the embeddings to generate.
860
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
861
+ Data type of the generated embeddings.
862
+
863
+ Returns:
864
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
865
+ """
866
+ assert len(w.shape) == 1
867
+ w = w * 1000.0
868
+
869
+ half_dim = embedding_dim // 2
870
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
871
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
872
+ emb = w.to(dtype)[:, None] * emb[None, :]
873
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
874
+ if embedding_dim % 2 == 1: # zero pad
875
+ emb = torch.nn.functional.pad(emb, (0, 1))
876
+ assert emb.shape == (w.shape[0], embedding_dim)
877
+ return emb
878
+
879
+ @property
880
+ def guidance_scale(self):
881
+ return self._guidance_scale
882
+
883
+ @property
884
+ def clip_skip(self):
885
+ return self._clip_skip
886
+
887
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
888
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
889
+ # corresponds to doing no classifier free guidance.
890
+ @property
891
+ def do_classifier_free_guidance(self):
892
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
893
+
894
+ @property
895
+ def cross_attention_kwargs(self):
896
+ return self._cross_attention_kwargs
897
+
898
+ @property
899
+ def num_timesteps(self):
900
+ return self._num_timesteps
901
+
902
+ @torch.no_grad()
903
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
904
+ def __call__(
905
+ self,
906
+ prompt: Union[str, List[str]] = None,
907
+ image: PipelineImageInput = None,
908
+ height: Optional[int] = None,
909
+ width: Optional[int] = None,
910
+ num_inference_steps: int = 50,
911
+ timesteps: List[int] = None,
912
+ sigmas: List[float] = None,
913
+ guidance_scale: float = 7.5,
914
+ negative_prompt: Optional[Union[str, List[str]]] = None,
915
+ num_images_per_prompt: Optional[int] = 1,
916
+ eta: float = 0.0,
917
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
918
+ latents: Optional[torch.Tensor] = None,
919
+ prompt_embeds: Optional[torch.Tensor] = None,
920
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
921
+ ip_adapter_image: Optional[PipelineImageInput] = None,
922
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
923
+ output_type: Optional[str] = "pil",
924
+ return_dict: bool = True,
925
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
926
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
927
+ guess_mode: bool = False,
928
+ control_guidance_start: Union[float, List[float]] = 0.0,
929
+ control_guidance_end: Union[float, List[float]] = 1.0,
930
+ clip_skip: Optional[int] = None,
931
+ callback_on_step_end: Optional[
932
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
933
+ ] = None,
934
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
935
+ **kwargs,
936
+ ):
937
+ r"""
938
+ The call function to the pipeline for generation.
939
+
940
+ Args:
941
+ prompt (`str` or `List[str]`, *optional*):
942
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
943
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
944
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
945
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
946
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
947
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
948
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
949
+ images must be passed as a list such that each element of the list can be correctly batched for input
950
+ to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
951
+ ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
952
+ ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
953
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
954
+ The height in pixels of the generated image.
955
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
956
+ The width in pixels of the generated image.
957
+ num_inference_steps (`int`, *optional*, defaults to 50):
958
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
959
+ expense of slower inference.
960
+ timesteps (`List[int]`, *optional*):
961
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
962
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
963
+ passed will be used. Must be in descending order.
964
+ sigmas (`List[float]`, *optional*):
965
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
966
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
967
+ will be used.
968
+ guidance_scale (`float`, *optional*, defaults to 7.5):
969
+ A higher guidance scale value encourages the model to generate images closely linked to the text
970
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
971
+ negative_prompt (`str` or `List[str]`, *optional*):
972
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
973
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
974
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
975
+ The number of images to generate per prompt.
976
+ eta (`float`, *optional*, defaults to 0.0):
977
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
978
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
979
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
980
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
981
+ generation deterministic.
982
+ latents (`torch.Tensor`, *optional*):
983
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
984
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
985
+ tensor is generated by sampling using the supplied random `generator`.
986
+ prompt_embeds (`torch.Tensor`, *optional*):
987
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
988
+ provided, text embeddings are generated from the `prompt` input argument.
989
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
990
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
991
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
992
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
993
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
994
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
995
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
996
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
997
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
998
+ output_type (`str`, *optional*, defaults to `"pil"`):
999
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1000
+ return_dict (`bool`, *optional*, defaults to `True`):
1001
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1002
+ plain tuple.
1003
+ callback (`Callable`, *optional*):
1004
+ A function that calls every `callback_steps` steps during inference. The function is called with the
1005
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
1006
+ callback_steps (`int`, *optional*, defaults to 1):
1007
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
1008
+ every step.
1009
+ cross_attention_kwargs (`dict`, *optional*):
1010
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1011
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1012
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1013
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1014
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1015
+ the corresponding scale as a list.
1016
+ guess_mode (`bool`, *optional*, defaults to `False`):
1017
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
1018
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
1019
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1020
+ The percentage of total steps at which the ControlNet starts applying.
1021
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1022
+ The percentage of total steps at which the ControlNet stops applying.
1023
+ clip_skip (`int`, *optional*):
1024
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1025
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1026
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1027
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1028
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1029
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1030
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1031
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1032
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1033
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1034
+ `._callback_tensor_inputs` attribute of your pipeline class.
1035
+
1036
+ Examples:
1037
+
1038
+ Returns:
1039
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1040
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1041
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1042
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1043
+ "not-safe-for-work" (nsfw) content.
1044
+ """
1045
+
1046
+ callback = kwargs.pop("callback", None)
1047
+ callback_steps = kwargs.pop("callback_steps", None)
1048
+
1049
+ if callback is not None:
1050
+ deprecate(
1051
+ "callback",
1052
+ "1.0.0",
1053
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1054
+ )
1055
+ if callback_steps is not None:
1056
+ deprecate(
1057
+ "callback_steps",
1058
+ "1.0.0",
1059
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1060
+ )
1061
+
1062
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1063
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1064
+
1065
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1066
+
1067
+ # align format for control guidance
1068
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1069
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1070
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1071
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1072
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1073
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1074
+ control_guidance_start, control_guidance_end = (
1075
+ mult * [control_guidance_start],
1076
+ mult * [control_guidance_end],
1077
+ )
1078
+
1079
+ # 1. Check inputs. Raise error if not correct
1080
+ self.check_inputs(
1081
+ prompt,
1082
+ image,
1083
+ callback_steps,
1084
+ negative_prompt,
1085
+ prompt_embeds,
1086
+ negative_prompt_embeds,
1087
+ ip_adapter_image,
1088
+ ip_adapter_image_embeds,
1089
+ controlnet_conditioning_scale,
1090
+ control_guidance_start,
1091
+ control_guidance_end,
1092
+ callback_on_step_end_tensor_inputs,
1093
+ )
1094
+
1095
+ self._guidance_scale = guidance_scale
1096
+ self._clip_skip = clip_skip
1097
+ self._cross_attention_kwargs = cross_attention_kwargs
1098
+
1099
+ # 2. Define call parameters
1100
+ if prompt is not None and isinstance(prompt, str):
1101
+ batch_size = 1
1102
+ elif prompt is not None and isinstance(prompt, list):
1103
+ batch_size = len(prompt)
1104
+ else:
1105
+ batch_size = prompt_embeds.shape[0]
1106
+
1107
+ device = self._execution_device
1108
+
1109
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1110
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1111
+
1112
+ global_pool_conditions = (
1113
+ controlnet.config.global_pool_conditions
1114
+ if isinstance(controlnet, ControlNetModel)
1115
+ else controlnet.nets[0].config.global_pool_conditions
1116
+ )
1117
+ guess_mode = guess_mode or global_pool_conditions
1118
+
1119
+ # 3. Encode input prompt
1120
+ text_encoder_lora_scale = (
1121
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1122
+ )
1123
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1124
+ prompt,
1125
+ device,
1126
+ num_images_per_prompt,
1127
+ self.do_classifier_free_guidance,
1128
+ negative_prompt,
1129
+ prompt_embeds=prompt_embeds,
1130
+ negative_prompt_embeds=negative_prompt_embeds,
1131
+ lora_scale=text_encoder_lora_scale,
1132
+ clip_skip=self.clip_skip,
1133
+ )
1134
+ # For classifier free guidance, we need to do two forward passes.
1135
+ # Here we concatenate the unconditional and text embeddings into a single batch
1136
+ # to avoid doing two forward passes
1137
+ if self.do_classifier_free_guidance:
1138
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1139
+
1140
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1141
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1142
+ ip_adapter_image,
1143
+ ip_adapter_image_embeds,
1144
+ device,
1145
+ batch_size * num_images_per_prompt,
1146
+ self.do_classifier_free_guidance,
1147
+ )
1148
+
1149
+ # 4. Prepare image
1150
+ if isinstance(controlnet, ControlNetModel):
1151
+ image = self.prepare_image(
1152
+ image=image,
1153
+ width=width,
1154
+ height=height,
1155
+ batch_size=batch_size * num_images_per_prompt,
1156
+ num_images_per_prompt=num_images_per_prompt,
1157
+ device=device,
1158
+ dtype=controlnet.dtype,
1159
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1160
+ guess_mode=guess_mode,
1161
+ )
1162
+ height, width = image.shape[-2:]
1163
+ elif isinstance(controlnet, MultiControlNetModel):
1164
+ images = []
1165
+
1166
+ # Nested lists as ControlNet condition
1167
+ if isinstance(image[0], list):
1168
+ # Transpose the nested image list
1169
+ image = [list(t) for t in zip(*image)]
1170
+
1171
+ for image_ in image:
1172
+ image_ = self.prepare_image(
1173
+ image=image_,
1174
+ width=width,
1175
+ height=height,
1176
+ batch_size=batch_size * num_images_per_prompt,
1177
+ num_images_per_prompt=num_images_per_prompt,
1178
+ device=device,
1179
+ dtype=controlnet.dtype,
1180
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1181
+ guess_mode=guess_mode,
1182
+ )
1183
+
1184
+ images.append(image_)
1185
+
1186
+ image = images
1187
+ height, width = image[0].shape[-2:]
1188
+ else:
1189
+ assert False
1190
+
1191
+ # 5. Prepare timesteps
1192
+ timesteps, num_inference_steps = retrieve_timesteps(
1193
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1194
+ )
1195
+ self._num_timesteps = len(timesteps)
1196
+
1197
+ # 6. Prepare latent variables
1198
+ num_channels_latents = self.unet.config.in_channels
1199
+ latents = self.prepare_latents(
1200
+ batch_size * num_images_per_prompt,
1201
+ num_channels_latents,
1202
+ height,
1203
+ width,
1204
+ prompt_embeds.dtype,
1205
+ device,
1206
+ generator,
1207
+ latents,
1208
+ )
1209
+
1210
+ # 6.5 Optionally get Guidance Scale Embedding
1211
+ timestep_cond = None
1212
+ if self.unet.config.time_cond_proj_dim is not None:
1213
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1214
+ timestep_cond = self.get_guidance_scale_embedding(
1215
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1216
+ ).to(device=device, dtype=latents.dtype)
1217
+
1218
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1219
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1220
+
1221
+ # 7.1 Add image embeds for IP-Adapter
1222
+ added_cond_kwargs = (
1223
+ {"image_embeds": image_embeds}
1224
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1225
+ else None
1226
+ )
1227
+
1228
+ # 7.2 Create tensor stating which controlnets to keep
1229
+ controlnet_keep = []
1230
+ for i in range(len(timesteps)):
1231
+ keeps = [
1232
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1233
+ for s, e in zip(control_guidance_start, control_guidance_end)
1234
+ ]
1235
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1236
+
1237
+ # 8. Denoising loop
1238
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1239
+ is_unet_compiled = is_compiled_module(self.unet)
1240
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
1241
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1242
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1243
+ for i, t in enumerate(timesteps):
1244
+ # Relevant thread:
1245
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1246
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1247
+ torch._inductor.cudagraph_mark_step_begin()
1248
+ # expand the latents if we are doing classifier free guidance
1249
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1250
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1251
+
1252
+ # controlnet(s) inference
1253
+ if guess_mode and self.do_classifier_free_guidance:
1254
+ # Infer ControlNet only for the conditional batch.
1255
+ control_model_input = latents
1256
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1257
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1258
+ else:
1259
+ control_model_input = latent_model_input
1260
+ controlnet_prompt_embeds = prompt_embeds
1261
+
1262
+ if isinstance(controlnet_keep[i], list):
1263
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1264
+ else:
1265
+ controlnet_cond_scale = controlnet_conditioning_scale
1266
+ if isinstance(controlnet_cond_scale, list):
1267
+ controlnet_cond_scale = controlnet_cond_scale[0]
1268
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1269
+
1270
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1271
+ control_model_input,
1272
+ t,
1273
+ encoder_hidden_states=controlnet_prompt_embeds,
1274
+ controlnet_cond=image,
1275
+ conditioning_scale=cond_scale,
1276
+ guess_mode=guess_mode,
1277
+ return_dict=False,
1278
+ )
1279
+
1280
+ if guess_mode and self.do_classifier_free_guidance:
1281
+ # Infered ControlNet only for the conditional batch.
1282
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1283
+ # add 0 to the unconditional batch to keep it unchanged.
1284
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1285
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1286
+
1287
+ # predict the noise residual
1288
+ noise_pred = self.unet(
1289
+ latent_model_input,
1290
+ t,
1291
+ encoder_hidden_states=prompt_embeds,
1292
+ timestep_cond=timestep_cond,
1293
+ cross_attention_kwargs=self.cross_attention_kwargs,
1294
+ down_block_additional_residuals=down_block_res_samples,
1295
+ mid_block_additional_residual=mid_block_res_sample,
1296
+ added_cond_kwargs=added_cond_kwargs,
1297
+ return_dict=False,
1298
+ )[0]
1299
+
1300
+ # perform guidance
1301
+ if self.do_classifier_free_guidance:
1302
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1303
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1304
+
1305
+ # compute the previous noisy sample x_t -> x_t-1
1306
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1307
+
1308
+ if callback_on_step_end is not None:
1309
+ callback_kwargs = {}
1310
+ for k in callback_on_step_end_tensor_inputs:
1311
+ callback_kwargs[k] = locals()[k]
1312
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1313
+
1314
+ latents = callback_outputs.pop("latents", latents)
1315
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1316
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1317
+
1318
+ # call the callback, if provided
1319
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1320
+ progress_bar.update()
1321
+ if callback is not None and i % callback_steps == 0:
1322
+ step_idx = i // getattr(self.scheduler, "order", 1)
1323
+ callback(step_idx, t, latents)
1324
+
1325
+ # If we do sequential model offloading, let's offload unet and controlnet
1326
+ # manually for max memory savings
1327
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1328
+ self.unet.to("cpu")
1329
+ self.controlnet.to("cpu")
1330
+ torch.cuda.empty_cache()
1331
+
1332
+ if not output_type == "latent":
1333
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1334
+ 0
1335
+ ]
1336
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1337
+ else:
1338
+ image = latents
1339
+ has_nsfw_concept = None
1340
+
1341
+ if has_nsfw_concept is None:
1342
+ do_denormalize = [True] * image.shape[0]
1343
+ else:
1344
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1345
+
1346
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1347
+
1348
+ # Offload all models
1349
+ self.maybe_free_model_hooks()
1350
+
1351
+ if not return_dict:
1352
+ return (image, has_nsfw_concept)
1353
+
1354
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)