diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..fb71045c38df3fd98fcb0f3d6170905e91dca55a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,27 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_basic.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_cut_for_inpaint.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_image_batch.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_inpaint.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_SDXL_basic.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_SDXL_upscale.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_with_CN.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_with_ELLA.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_with_IPA.png filter=lfs diff=lfs merge=lfs -text
+example/BrushNet_with_LoRA.png filter=lfs diff=lfs merge=lfs -text
+example/goblin_toy.png filter=lfs diff=lfs merge=lfs -text
+example/object_removal_fail.png filter=lfs diff=lfs merge=lfs -text
+example/object_removal.png filter=lfs diff=lfs merge=lfs -text
+example/params1.png filter=lfs diff=lfs merge=lfs -text
+example/params13.png filter=lfs diff=lfs merge=lfs -text
+example/PowerPaint_object_removal.png filter=lfs diff=lfs merge=lfs -text
+example/PowerPaint_outpaint.png filter=lfs diff=lfs merge=lfs -text
+example/RAUNet1.png filter=lfs diff=lfs merge=lfs -text
+example/RAUNet2.png filter=lfs diff=lfs merge=lfs -text
+example/sleeping_cat_inpaint1.png filter=lfs diff=lfs merge=lfs -text
+example/sleeping_cat_inpaint3.png filter=lfs diff=lfs merge=lfs -text
+example/sleeping_cat_inpaint5.png filter=lfs diff=lfs merge=lfs -text
+example/sleeping_cat_inpaint6.png filter=lfs diff=lfs merge=lfs -text
+example/test_image3.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2c6f4909b38fa4a3db11059993fe9de1d1434cf0
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,22 @@
+name: Publish to Comfy registry
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ - master
+ paths:
+ - "pyproject.toml"
+
+jobs:
+ publish-node:
+ name: Publish Custom Node to registry
+ runs-on: ubuntu-latest
+ steps:
+ - name: Check out code
+ uses: actions/checkout@v4
+ - name: Publish Custom Node
+ uses: Comfy-Org/publish-node-action@main
+ with:
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
+ personal_access_token: ${{ secrets.COMFY_REGISTRY_KEY }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8705b9691ed1e2309fdf6aab5da9c7c512365e87
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+**/__pycache__/
+.vscode/
+*.tmp
+*.dblite
+*.log
+*.part
+
+Dockerfile
diff --git a/BIG_IMAGE.md b/BIG_IMAGE.md
new file mode 100644
index 0000000000000000000000000000000000000000..84f735aae9605a80aa852ca8a923f3cf32269160
--- /dev/null
+++ b/BIG_IMAGE.md
@@ -0,0 +1,6 @@
+![example workflow](example/BrushNet_cut_for_inpaint.png?raw=true)
+
+[workflow](example/BrushNet_cut_for_inpaint.json)
+
+When you work with big image and your inpaint mask is small it is better to cut part of the image, work with it and then blend it back.
+I created a node for such workflow, see example.
diff --git a/CN.md b/CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..88263a29be0ff26b2a3706e2be9f6da44fa92614
--- /dev/null
+++ b/CN.md
@@ -0,0 +1,39 @@
+## ControlNet Canny Edge
+
+Let's take the pestered cake and try to inpaint it again. Now I would like to use a sleeping cat for it:
+
+![sleeping cat](example/sleeping_cat.png?raw=true)
+
+I use Canny Edge node from [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux). Don't forget to resize canny edge mask to 512 pixels:
+
+![sleeping cat inpaint](example/sleeping_cat_inpaint1.png?raw=true)
+
+Let's look at the result:
+
+![sleeping cat inpaint](example/sleeping_cat_inpaint2.png?raw=true)
+
+The first problem I see here is some kind of object behind the cat. Such objects appear since the inpainting mask strictly aligns with the removed object, the cake in our case. To remove such artifact we should expand our mask a little:
+
+![sleeping cat inpaint](example/sleeping_cat_inpaint3.png?raw=true)
+
+Now. what's up with cat back and tail? Let's see the inpainting mask and canny edge mask side to side:
+
+![masks](example/sleeping_cat_inpaint4.png?raw=true)
+
+The inpainting works (mostly) only in masked (white) area, so we cut off cat's back. **The ControlNet mask should be inside the inpaint mask.**
+
+To address the issue I resized the mask to 256 pixels:
+
+![sleeping cat inpaint](example/sleeping_cat_inpaint5.png?raw=true)
+
+This is better but still have a room for improvement. The problem with edge mask downsampling is that edge lines tend to be broken and after some size we will got a mess:
+
+![sleeping cat inpaint](example/sleeping_cat_inpaint6.png?raw=true)
+
+Look at the edge mask, at this resolution it is so broken:
+
+![masks](example/sleeping_cat_mask.png?raw=true)
+
+
+
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/PARAMS.md b/PARAMS.md
new file mode 100644
index 0000000000000000000000000000000000000000..1be01bed6ee1bdbfa5a1431ea9b8265cd0c47ccc
--- /dev/null
+++ b/PARAMS.md
@@ -0,0 +1,47 @@
+## Start At and End At parameters usage
+
+### start_at
+
+Let's start with a ELLA outpaint [workflow](example/BrushNet_with_ELLA.json) and switch off Blend Inpaint node:
+
+![example workflow](example/params1.png?raw=true)
+
+For this example I use "wargaming shop showcase" prompt, `dpmpp_2m` deterministic sampler and `karras` scheduler with 15 steps. This is the result:
+
+![goblin in the shop](example/params2.png?raw=true)
+
+The `start_at` BrushNet node parameter allows us to delay BrushNet inference for some steps, so the base model will do all the job. Let's see what the result will be without BrushNet. For this I set up `start_at` parameter to 20 - it should be more then `steps` in KSampler node:
+
+![the shop](example/params3.png?raw=true)
+
+So, if we apply BrushNet from the beginning (`start_at` equals 0), the resulting scene will be heavily influenced by BrushNet image. The more we increase this parameter, the more scene will be based on prompt. Let's compare:
+
+| `start_at` = 1 | `start_at` = 2 | `start_at` = 3 |
+|:--------------:|:--------------:|:--------------:|
+| ![p1](example/params4.png?raw=true) | ![p2](example/params5.png?raw=true) | ![p3](example/params6.png?raw=true) |
+| `start_at` = 4 | `start_at` = 5 | `start_at` = 6 |
+| ![p1](example/params7.png?raw=true) | ![p2](example/params8.png?raw=true) | ![p3](example/params9.png?raw=true) |
+| `start_at` = 7 | `start_at` = 8 | `start_at` = 9 |
+| ![p1](example/params10.png?raw=true) | ![p2](example/params11.png?raw=true) | ![p3](example/params12.png?raw=true) |
+
+Look how the floor is aligned with toy's base - at some step it looses consistency. The results will depend on type of sampler and number of KSampler steps, of course.
+
+### end_at
+
+The `end_at` parameter switches off BrushNet at the last steps. If you use deterministic sampler it will only influences details on last steps, but stochastic samplers can change the whole scene. For a description of samplers see, for example, Matteo Spinelli's [video on ComfyUI basics](https://youtu.be/_C7kR2TFIX0?t=516).
+
+Here I use basic BrushNet inpaint [example](example/BrushNet_basic.json), with "intricate teapot" prompt, `dpmpp_2m` deterministic sampler and `karras` scheduler with 15 steps:
+
+![example workflow](example/params13.png?raw=true)
+
+There are almost no changes when we set 'end_at' paramter to 10, but starting from it:
+
+| `end_at` = 10 | `end_at` = 9 | `end_at` = 8 |
+|:--------------:|:--------------:|:--------------:|
+| ![p1](example/params14.png?raw=true) | ![p2](example/params15.png?raw=true) | ![p3](example/params16.png?raw=true) |
+| `end_at` = 7 | `end_at` = 6 | `end_at` = 5 |
+| ![p1](example/params17.png?raw=true) | ![p2](example/params18.png?raw=true) | ![p3](example/params19.png?raw=true) |
+| `end_at` = 4 | `end_at` = 3 | `end_at` = 2 |
+| ![p1](example/params20.png?raw=true) | ![p2](example/params21.png?raw=true) | ![p3](example/params22.png?raw=true) |
+
+You can see how the scene was completely redrawn.
diff --git a/RAUNET.md b/RAUNET.md
new file mode 100644
index 0000000000000000000000000000000000000000..05a75aeb8cbbeae3457e3b690809a569217205f8
--- /dev/null
+++ b/RAUNET.md
@@ -0,0 +1,39 @@
+During investigation of compatibility issues with [WASasquatch's FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced/tree/main) and [blepping's jank HiDiffusion](https://github.com/blepping/comfyui_jankhidiffusion) nodes I stumbled upon some quite hard problems. There are `FreeU` nodes in ComfyUI, but no such for HiDiffusion, so I decided to implement RAUNet on base of my BrushNet implementation. **blepping**, I am sorry. :)
+
+### RAUNet
+
+What is RAUNet? I know many of you saw and generate images with a lot of limbs, fingers and faces all morphed together.
+
+The authors of HiDiffusion invent simple, yet efficient trick to alleviate this problem. Here is an example:
+
+![example workflow](example/RAUNet1.png?raw=true)
+
+[workflow](example/RAUNet_basic.json)
+
+The left picture is created using ZavyChromaXL checkpoint on 2048x2048 canvas. The right one uses RAUNet.
+
+In my experience the node is helpful but quite sensitive to its parameters. And there is no universal solution - you should adjust them for every new image you generate. It also lowers model's imagination, you usually get only what you described in the prompt. Look at the example: in first you have a forest in the background, but RAUNet deleted all except fox which is described in the prompt.
+
+From the [paper](https://arxiv.org/abs/2311.17528): Diffusion models denoise from structures to details. RAU-Net introduces additional downsampling and upsampling operations, leading to a certain degree of information loss. In the early stages of denoising, RAU-Net can generate reasonable structures with minimal impact from information loss. However, in the later stages of denoising when generating fine details, the information loss in RAU-Net results in the loss of image details and a degradation in quality.
+
+### Parameters
+
+There are two independent parts in this node: DU (Downsample/Upsample) and XA (CrossAttention). The four parameters are the start and end steps for applying these parts.
+
+The Downsample/Upsample part lowers models degrees of freedom. If you apply it a lot (for more steps) the resulting images will have a lot of symmetries.
+
+The CrossAttension part lowers number of objects which model tracks in image.
+
+Usually you apply DU and after several steps apply XA, sometimes you will need only XA, you should try it yourself.
+
+### Compatibility
+
+It is compatible with BrushNet and most other nodes.
+
+This is ControlNet example. The lower image is pure model, the upper is after using RAUNet. You can see small fox and two tails in lower image.
+
+![example workflow](example/RAUNet2.png?raw=true)
+
+[workflow](example/RAUNet_with_CN.json)
+
+The node can be implemented for any model. Right now it can be applied to SD15 and SDXL models.
\ No newline at end of file
diff --git a/README.md b/README.md
index 5567f0dcfaa042303b5d44e070bb5faddf72f5b8..33e64587963bf35d7eba6e7b6c19a8f3df1fe8d7 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,262 @@
----
-title: BrushNET
-emoji: 🐢
-colorFrom: gray
-colorTo: indigo
-sdk: docker
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## ComfyUI-BrushNet
+
+These are custom nodes for ComfyUI native implementation of
+
+- Brushnet: ["BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"](https://arxiv.org/abs/2403.06976)
+- PowerPaint: [A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting](https://arxiv.org/abs/2312.03594)
+- HiDiffusion: [HiDiffusion: Unlocking Higher-Resolution Creativity and Efficiency in Pretrained Diffusion Models](https://arxiv.org/abs/2311.17528)
+
+My contribution is limited to the ComfyUI adaptation, and all credit goes to the authors of the papers.
+
+## Updates
+
+May 16, 2024. Internal rework to improve compatibility with other nodes. [RAUNet](RAUNET.md) is implemented.
+
+May 12, 2024. CutForInpaint node, see [example](BIG_IMAGE.md).
+
+May 11, 2024. Image batch is implemented. You can even add BrushNet to AnimateDiff vid2vid workflow, but they don't work together - they are different models and both try to patch UNet. Added some more examples.
+
+May 6, 2024. PowerPaint v2 model is implemented. After update your workflow probably will not work. Don't panic! Check `end_at` parameter of BrushNode, if it equals 1, change it to some big number. Read about parameters in Usage section below.
+
+May 2, 2024. BrushNet SDXL is live. It needs positive and negative conditioning though, so workflow changes a little, see example.
+
+Apr 28, 2024. Another rework, sorry for inconvenience. But now BrushNet is native to ComfyUI. Famous cubiq's [IPAdapter Plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) is now working with BrushNet! I hope... :) Please, report any bugs you found.
+
+Apr 18, 2024. Complete rework, no more custom `diffusers` library. It is possible to use LoRA models.
+
+Apr 11, 2024. Initial commit.
+
+## Plans
+
+- [x] BrushNet SDXL
+- [x] PowerPaint v2
+- [x] Image batch
+
+## Installation
+
+Clone the repo into the `custom_nodes` directory and install the requirements:
+
+```
+git clone https://github.com/nullquant/ComfyUI-BrushNet.git
+pip install -r requirements.txt
+```
+
+Checkpoints of BrushNet can be downloaded from [here](https://drive.google.com/drive/folders/1fqmS1CEOvXCxNWFrsSYd_jHYXxrydh1n?usp=drive_link).
+
+The checkpoint in `segmentation_mask_brushnet_ckpt` provides checkpoints trained on BrushData, which has segmentation prior (mask are with the same shape of objects). The `random_mask_brushnet_ckpt` provides a more general ckpt for random mask shape.
+
+`segmentation_mask_brushnet_ckpt` and `random_mask_brushnet_ckpt` contains BrushNet for SD 1.5 models while
+`segmentation_mask_brushnet_ckpt_sdxl_v0` and `random_mask_brushnet_ckpt_sdxl_v0` for SDXL.
+
+You should place `diffusion_pytorch_model.safetensors` files to your `models/inpaint` folder. You can also specify `inpaint` folder in your `extra_model_paths.yaml`.
+
+For PowerPaint you should download three files. Both `diffusion_pytorch_model.safetensors` and `pytorch_model.bin` from [here](https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/tree/main/PowerPaint_Brushnet) should be placed in your `models/inpaint` folder.
+
+Also you need SD1.5 text encoder model `model.safetensors`. You can take it from [here](https://huggingface.co/ashllay/stable-diffusion-v1-5-archive/tree/main/text_encoder) or from another place. You can also use fp16 [version](https://huggingface.co/nmkd/stable-diffusion-1.5-fp16/tree/main/text_encoder). It should be placed in your `models/clip` folder.
+
+This is a structure of my `models/inpaint` folder:
+
+![inpaint folder](example/inpaint_folder.png?raw=true)
+
+Yours can be different.
+
+## Usage
+
+Below is an example for the intended workflow. The [workflow](example/BrushNet_basic.json) for the example can be found inside the 'example' directory.
+
+![example workflow](example/BrushNet_basic.png?raw=true)
+
+
+ SDXL
+
+![example workflow](example/BrushNet_SDXL_basic.png?raw=true)
+
+[workflow](example/BrushNet_SDXL_basic.json)
+
+
+
+
+ IPAdapter plus
+
+![example workflow](example/BrushNet_with_IPA.png?raw=true)
+
+[workflow](example/BrushNet_with_IPA.json)
+
+
+
+
+ LoRA
+
+![example workflow](example/BrushNet_with_LoRA.png?raw=true)
+
+[workflow](example/BrushNet_with_LoRA.json)
+
+
+
+
+ Blending inpaint
+
+![example workflow](example/BrushNet_inpaint.png?raw=true)
+
+Sometimes inference and VAE broke image, so you need to blend inpaint image with the original: [workflow](example/BrushNet_inpaint.json). You can see blurred and broken text after inpainting in the first image and how I suppose to repair it.
+
+
+
+
+ ControlNet
+
+![example workflow](example/BrushNet_with_CN.png?raw=true)
+
+[workflow](example/BrushNet_with_CN.json)
+
+[ControlNet canny edge](CN.md)
+
+
+
+
+ ELLA outpaint
+
+![example workflow](example/BrushNet_with_ELLA.png?raw=true)
+
+[workflow](example/BrushNet_with_ELLA.json)
+
+
+
+
+ Upscale
+
+![example workflow](example/BrushNet_SDXL_upscale.png?raw=true)
+
+[workflow](example/BrushNet_SDXL_upscale.json)
+
+To upscale you should use base model, not BrushNet. The same is true for conditioning. Latent upscaling between BrushNet and KSampler will not work or will give you wierd results. These limitations are due to structure of BrushNet and its influence on UNet calculations.
+
+
+
+
+ Image batch
+
+![example workflow](example/BrushNet_image_batch.png?raw=true)
+
+[workflow](example/BrushNet_image_batch.json)
+
+If you have OOM problems, you can use Evolved Sampling from [AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved):
+
+![example workflow](example/BrushNet_image_big_batch.png?raw=true)
+
+[workflow](example/BrushNet_image_big_batch.json)
+
+In Context Options set context_length to number of images which can be loaded into VRAM. Images will be processed in chunks of this size.
+
+
+
+
+
+ Big image inpaint
+
+![example workflow](example/BrushNet_cut_for_inpaint.png?raw=true)
+
+[workflow](example/BrushNet_cut_for_inpaint.json)
+
+When you work with big image and your inpaint mask is small it is better to cut part of the image, work with it and then blend it back.
+I created a node for such workflow, see example.
+
+
+
+
+
+ PowerPaint outpaint
+
+![example workflow](example/PowerPaint_outpaint.png?raw=true)
+
+[workflow](example/PowerPaint_outpaint.json)
+
+
+
+
+ PowerPaint object removal
+
+![example workflow](example/PowerPaint_object_removal.png?raw=true)
+
+[workflow](example/PowerPaint_object_removal.json)
+
+It is often hard to completely remove the object, especially if it is at the front:
+
+![object removal example](example/object_removal_fail.png?raw=true)
+
+You should try to add object description to negative prompt and describe empty scene, like here:
+
+![object removal example](example/object_removal.png?raw=true)
+
+
+
+### Parameters
+
+#### Brushnet Loader
+
+- `dtype`, defaults to `torch.float16`. The torch.dtype of BrushNet. If you have old GPU or NVIDIA 16 series card try to switch to `torch.float32`.
+
+#### Brushnet
+
+- `scale`, defaults to 1.0: The "strength" of BrushNet. The outputs of the BrushNet are multiplied by `scale` before they are added to the residual in the original unet.
+- `start_at`, defaults to 0: step at which the BrushNet starts applying.
+- `end_at`, defaults to 10000: step at which the BrushNet stops applying.
+
+[Here](PARAMS.md) are examples of use these two last parameters.
+
+#### PowerPaint
+
+- `CLIP`: PowerPaint CLIP that should be passed from PowerPaintCLIPLoader node.
+- `fitting`: PowerPaint fitting degree.
+- `function`: PowerPaint function, see its [page](https://github.com/open-mmlab/PowerPaint) for details.
+- `save_memory`: If this option is set, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a decrease in speed. If you run out of VRAM or get `Error: total bytes of NDArray > 2**32` on Mac try to set this option to `max`.
+
+When using certain network functions, the authors of PowerPaint recommend adding phrases to the prompt:
+
+- object removal: `empty scene blur`
+- context aware: `empty scene`
+- outpainting: `empty scene`
+
+Many of ComfyUI users use custom text generation nodes, CLIP nodes and a lot of other conditioning. I don't want to break all of these nodes, so I didn't add prompt updating and instead rely on users. Also my own experiments show that these additions to prompt are not strictly necessary.
+
+The latent image can be from BrushNet node or not, but it should be the same size as original image (divided by 8 in latent space).
+
+The both conditioning `positive` and `negative` in BrushNet and PowerPaint nodes are used for calculation inside, but then simply copied to output.
+
+Be advised, not all workflows and nodes will work with BrushNet due to its structure. Also put model changes before BrushNet nodes, not after. If you need model to work with image after BrushNet inference use base one (see Upscale example below).
+
+#### RAUNet
+
+- `du_start`, defaults to 0: step at which the Downsample/Upsample resize starts applying.
+- `du_end`, defaults to 4: step at which the Downsample/Upsample resize stops applying.
+- `xa_start`, defaults to 4: step at which the CrossAttention resize starts applying.
+- `xa_end`, defaults to 10: step at which the CrossAttention resize stops applying.
+
+For an examples and explanation, please look [here](RAUNET.md).
+
+## Limitations
+
+BrushNet has some limitations (from the [paper](https://arxiv.org/abs/2403.06976)):
+
+- The quality and content generated by the model are heavily dependent on the chosen base model.
+The results can exhibit incoherence if, for example, the given image is a natural image while the base model primarily focuses on anime.
+- Even with BrushNet, we still observe poor generation results in cases where the given mask has an unusually shaped
+or irregular form, or when the given text does not align well with the masked image.
+
+## Notes
+
+Unfortunately, due to the nature of BrushNet code some nodes are not compatible with these, since we are trying to patch the same ComfyUI's functions.
+
+List of known uncompartible nodes.
+
+- [WASasquatch's FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced/tree/main)
+- [blepping's jank HiDiffusion](https://github.com/blepping/comfyui_jankhidiffusion)
+
+## Credits
+
+The code is based on
+
+- [BrushNet](https://github.com/TencentARC/BrushNet)
+- [PowerPaint](https://github.com/zhuang2002/PowerPaint)
+- [HiDiffusion](https://github.com/megvii-research/HiDiffusion)
+- [diffusers](https://github.com/huggingface/diffusers)
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab78039654dd3fb2052fc0a024241cd3c92bbe26
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,62 @@
+from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, PowerPaintCLIPLoader, PowerPaint, CutForInpaint
+from .raunet_nodes import RAUNet
+import torch
+from subprocess import getoutput
+
+"""
+@author: nullquant
+@title: BrushNet
+@nickname: BrushName nodes
+@description: These are custom nodes for ComfyUI native implementation of BrushNet, PowerPaint and RAUNet models
+"""
+
+class Terminal:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return { "required": {
+ "text": ("STRING", {"multiline": True})
+ }
+ }
+
+ CATEGORY = "utils"
+ RETURN_TYPES = ("IMAGE", )
+ RETURN_NAMES = ("image", )
+ OUTPUT_NODE = True
+
+ FUNCTION = "execute"
+
+ def execute(self, text):
+ if text[0] == '"' and text[-1] == '"':
+ out = getoutput(f"{text[1:-1]}")
+ print(out)
+ else:
+ exec(f"{text}")
+ return (torch.zeros(1, 128, 128, 4), )
+
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "BrushNetLoader": BrushNetLoader,
+ "BrushNet": BrushNet,
+ "BlendInpaint": BlendInpaint,
+ "PowerPaintCLIPLoader": PowerPaintCLIPLoader,
+ "PowerPaint": PowerPaint,
+ "CutForInpaint": CutForInpaint,
+ "RAUNet": RAUNet,
+ "Terminal": Terminal,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "BrushNetLoader": "BrushNet Loader",
+ "BrushNet": "BrushNet",
+ "BlendInpaint": "Blend Inpaint",
+ "PowerPaintCLIPLoader": "PowerPaint CLIP Loader",
+ "PowerPaint": "PowerPaint",
+ "CutForInpaint": "Cut For Inpaint",
+ "RAUNet": "RAUNet",
+ "Terminal": "Terminal",
+}
diff --git a/brushnet/brushnet.json b/brushnet/brushnet.json
new file mode 100644
index 0000000000000000000000000000000000000000..65713bfcd0113271496bd06fe6b57299822e0f76
--- /dev/null
+++ b/brushnet/brushnet.json
@@ -0,0 +1,58 @@
+{
+ "_class_name": "BrushNetModel",
+ "_diffusers_version": "0.27.0.dev0",
+ "_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": 8,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "brushnet_conditioning_channel_order": "rgb",
+ "class_embed_type": null,
+ "conditioning_channels": 5,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "cross_attention_dim": 768,
+ "down_block_types": [
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "global_pool_conditions": false,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "MidBlock2D",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_time_scale_shift": "default",
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D"
+ ],
+ "upcast_attention": false,
+ "use_linear_projection": false
+}
diff --git a/brushnet/brushnet.py b/brushnet/brushnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..aed1cfde30b1ab27286066746058b7b1afcd8a84
--- /dev/null
+++ b/brushnet/brushnet.py
@@ -0,0 +1,949 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+ MidBlock2D
+)
+
+from .unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class BrushNetOutput(BaseOutput):
+ """
+ The output of [`BrushNetModel`].
+
+ Args:
+ up_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's upsampling activations.
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ up_block_res_samples: Tuple[torch.Tensor]
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class BrushNetModel(ModelMixin, ConfigMixin):
+ """
+ A BrushNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 5,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2D",
+ up_block_types: Tuple[str, ...] = (
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ brushnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in_condition = nn.Conv2d(
+ in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ self.down_blocks = nn.ModuleList([])
+ self.brushnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ if not is_final_block:
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_mid_block = brushnet_block
+
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+
+ self.up_blocks = nn.ModuleList([])
+ self.brushnet_up_blocks = nn.ModuleList([])
+
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block+1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=reversed_num_attention_heads[i],
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ for _ in range(layers_per_block+1):
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_up_blocks.append(brushnet_block)
+
+ if not is_final_block:
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_up_blocks.append(brushnet_block)
+
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ brushnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 5,
+ ):
+ r"""
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ brushnet = cls(
+ in_channels=unet.config.in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=["DownBlock2D" for block_name in unet.config.down_block_types],
+ mid_block_type='MidBlock2D',
+ up_block_types=["UpBlock2D" for block_name in unet.config.down_block_types],
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
+ conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
+ conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
+ conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
+ brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
+ brushnet.conv_in_condition.bias=unet.conv_in.bias
+
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if brushnet.class_embedding:
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)
+
+ return brushnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ encoder_hidden_states: torch.Tensor,
+ brushnet_cond: torch.FloatTensor,
+ timestep = None,
+ time_emb = None,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ debug = False,
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+ """
+ The [`BrushNetModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ brushnet_cond (`torch.FloatTensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for BrushNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+
+ # check channel order
+ channel_order = self.config.brushnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ if timestep is None and time_emb is None:
+ raise ValueError(f"`timestep` and `emb` are both None")
+
+ #print("BN: sample.device", sample.device)
+ #print("BN: TE.device", self.time_embedding.linear_1.weight.device)
+
+ if timestep is not None:
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ #print("t_emb.device =",t_emb.device)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ #print('emb.shape', emb.shape)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ #print('text_embeds', text_embeds.shape, 'time_ids', time_ids.shape, 'time_embeds', time_embeds.shape, 'add__embeds', add_embeds.shape, 'aug_emb', aug_emb.shape)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+ else:
+ emb = time_emb
+
+ # 2. pre-process
+
+ brushnet_cond=torch.concat([sample,brushnet_cond],1)
+ sample = self.conv_in_condition(brushnet_cond)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. PaintingNet down blocks
+ brushnet_down_block_res_samples = ()
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+
+
+ # 5. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 6. BrushNet mid blocks
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
+
+ # 7. up
+ up_block_res_samples = ()
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample, up_res_samples = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ return_res_samples=True
+ )
+ else:
+ sample, up_res_samples = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ return_res_samples=True
+ )
+
+ up_block_res_samples += up_res_samples
+
+ # 8. BrushNet up blocks
+ brushnet_up_block_res_samples = ()
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+
+ brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
+ brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
+ else:
+ brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+
+
+ if self.config.global_pool_conditions:
+ brushnet_down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+ ]
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+ brushnet_up_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+ ]
+
+ if not return_dict:
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+
+ return BrushNetOutput(
+ down_block_res_samples=brushnet_down_block_res_samples,
+ mid_block_res_sample=brushnet_mid_block_res_sample,
+ up_block_res_samples=brushnet_up_block_res_samples
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/brushnet/brushnet_ca.py b/brushnet/brushnet_ca.py
new file mode 100644
index 0000000000000000000000000000000000000000..780a87b23f30e2192a19469c506a22056ea52ba7
--- /dev/null
+++ b/brushnet/brushnet_ca.py
@@ -0,0 +1,983 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+ MidBlock2D
+)
+
+from .unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class BrushNetOutput(BaseOutput):
+ """
+ The output of [`BrushNetModel`].
+
+ Args:
+ up_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's upsampling activations.
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ up_block_res_samples: Tuple[torch.Tensor]
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class BrushNetModel(ModelMixin, ConfigMixin):
+ """
+ A BrushNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 5,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str, ...] = (
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ brushnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in_condition = nn.Conv2d(
+ in_channels + conditioning_channels,
+ block_out_channels[0],
+ kernel_size=conv_in_kernel,
+ padding=conv_in_padding,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ self.down_blocks = nn.ModuleList([])
+ self.brushnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ if not is_final_block:
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_mid_block = brushnet_block
+
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+
+ self.up_blocks = nn.ModuleList([])
+ self.brushnet_up_blocks = nn.ModuleList([])
+
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=reversed_num_attention_heads[i],
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ for _ in range(layers_per_block + 1):
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_up_blocks.append(brushnet_block)
+
+ if not is_final_block:
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_up_blocks.append(brushnet_block)
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ brushnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 5,
+ ):
+ r"""
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ brushnet = cls(
+ in_channels=unet.config.in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
+ down_block_types=[
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ],
+ # mid_block_type='MidBlock2D',
+ mid_block_type="UNetMidBlock2DCrossAttn",
+ # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
+ up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
+ conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
+ conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
+ conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
+ brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
+ brushnet.conv_in_condition.bias = unet.conv_in.bias
+
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if brushnet.class_embedding:
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
+
+ return brushnet.to(unet.dtype)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ brushnet_cond: torch.FloatTensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ debug=False,
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+ """
+ The [`BrushNetModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ brushnet_cond (`torch.FloatTensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for BrushNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ # check channel order
+ channel_order = self.config.brushnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+
+ if debug: print('BrushNet CA: attn mask')
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ if debug: print('BrushNet CA: time')
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if debug: print('BrushNet CA: pre-process')
+
+
+ # 2. pre-process
+ brushnet_cond = torch.concat([sample, brushnet_cond], 1)
+ sample = self.conv_in_condition(brushnet_cond)
+
+ if debug: print('BrushNet CA: down')
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ if debug: print('BrushNet CA (down block with XA): ', type(downsample_block))
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ debug=debug,
+ )
+ else:
+ if debug: print('BrushNet CA (down block): ', type(downsample_block))
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, debug=debug)
+
+ down_block_res_samples += res_samples
+
+ if debug: print('BrushNet CA: PP down')
+
+ # 4. PaintingNet down blocks
+ brushnet_down_block_res_samples = ()
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+
+ if debug: print('BrushNet CA: PP mid')
+
+ # 5. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ if debug: print('BrushNet CA: mid')
+
+ # 6. BrushNet mid blocks
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
+
+ if debug: print('BrushNet CA: PP up')
+
+ # 7. up
+ up_block_res_samples = ()
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample, up_res_samples = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ return_res_samples=True,
+ )
+ else:
+ sample, up_res_samples = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ return_res_samples=True,
+ )
+
+ up_block_res_samples += up_res_samples
+
+ if debug: print('BrushNet CA: up')
+
+ # 8. BrushNet up blocks
+ brushnet_up_block_res_samples = ()
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+
+ if debug: print('BrushNet CA: scaling')
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(
+ -1,
+ 0,
+ len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
+ device=sample.device,
+ ) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+
+ brushnet_down_block_res_samples = [
+ sample * scale
+ for sample, scale in zip(
+ brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
+ )
+ ]
+ brushnet_mid_block_res_sample = (
+ brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
+ )
+ brushnet_up_block_res_samples = [
+ sample * scale
+ for sample, scale in zip(
+ brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
+ )
+ ]
+ else:
+ brushnet_down_block_res_samples = [
+ sample * conditioning_scale for sample in brushnet_down_block_res_samples
+ ]
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+
+ if self.config.global_pool_conditions:
+ brushnet_down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+ ]
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+ brushnet_up_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+ ]
+
+ if debug: print('BrushNet CA: finish')
+
+ if not return_dict:
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+
+ return BrushNetOutput(
+ down_block_res_samples=brushnet_down_block_res_samples,
+ mid_block_res_sample=brushnet_mid_block_res_sample,
+ up_block_res_samples=brushnet_up_block_res_samples,
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/brushnet/brushnet_xl.json b/brushnet/brushnet_xl.json
new file mode 100644
index 0000000000000000000000000000000000000000..c1a3c655549879fb2e9d7441ec71eef5167eac12
--- /dev/null
+++ b/brushnet/brushnet_xl.json
@@ -0,0 +1,63 @@
+{
+ "_class_name": "BrushNetModel",
+ "_diffusers_version": "0.27.0.dev0",
+ "_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
+ "act_fn": "silu",
+ "addition_embed_type": "text_time",
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": 256,
+ "attention_head_dim": [
+ 5,
+ 10,
+ 20
+ ],
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280
+ ],
+ "brushnet_conditioning_channel_order": "rgb",
+ "class_embed_type": null,
+ "conditioning_channels": 5,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "cross_attention_dim": 2048,
+ "down_block_types": [
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "global_pool_conditions": false,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "MidBlock2D",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "projection_class_embeddings_input_dim": 2816,
+ "resnet_time_scale_shift": "default",
+ "transformer_layers_per_block": [
+ 1,
+ 2,
+ 10
+ ],
+ "up_block_types": [
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D"
+ ],
+ "upcast_attention": null,
+ "use_linear_projection": true
+}
diff --git a/brushnet/powerpaint.json b/brushnet/powerpaint.json
new file mode 100644
index 0000000000000000000000000000000000000000..4d7c73e9f5654cd775db99a0d77234765f808e6c
--- /dev/null
+++ b/brushnet/powerpaint.json
@@ -0,0 +1,57 @@
+{
+ "_class_name": "BrushNetModel",
+ "_diffusers_version": "0.27.2",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": 8,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "brushnet_conditioning_channel_order": "rgb",
+ "class_embed_type": null,
+ "conditioning_channels": 5,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "cross_attention_dim": 768,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "global_pool_conditions": false,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_time_scale_shift": "default",
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D"
+ ],
+ "upcast_attention": false,
+ "use_linear_projection": false
+}
diff --git a/brushnet/powerpaint_utils.py b/brushnet/powerpaint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..244089ad9dbc00f9a6bae63dfdaabddd969739d1
--- /dev/null
+++ b/brushnet/powerpaint_utils.py
@@ -0,0 +1,497 @@
+import copy
+import random
+
+import torch
+import torch.nn as nn
+from transformers import CLIPTokenizer
+from typing import Any, List, Optional, Union
+
+class TokenizerWrapper:
+ """Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
+ currently. This wrapper is modified from https://github.com/huggingface/dif
+ fusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.
+ py#L358 # noqa.
+
+ Args:
+ from_pretrained (Union[str, os.PathLike], optional): The *model id*
+ of a pretrained model or a path to a *directory* containing
+ model weights and config. Defaults to None.
+ from_config (Union[str, os.PathLike], optional): The *model id*
+ of a pretrained model or a path to a *directory* containing
+ model weights and config. Defaults to None.
+
+ *args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
+ will be passed to `from_pretrained` function. Otherwise, *args
+ and **kwargs will be used to initialize the model by
+ `self._module_cls(*args, **kwargs)`.
+ """
+
+ def __init__(self, tokenizer: CLIPTokenizer):
+ self.wrapped = tokenizer
+ self.token_map = {}
+
+ def __getattr__(self, name: str) -> Any:
+ if name in self.__dict__:
+ return getattr(self, name)
+ #if name == "wrapped":
+ # return getattr(self, 'wrapped')#super().__getattr__("wrapped")
+
+ try:
+ return getattr(self.wrapped, name)
+ except AttributeError:
+ raise AttributeError(
+ "'name' cannot be found in both "
+ f"'{self.__class__.__name__}' and "
+ f"'{self.__class__.__name__}.tokenizer'."
+ )
+
+ def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
+ """Attempt to add tokens to the tokenizer.
+
+ Args:
+ tokens (Union[str, List[str]]): The tokens to be added.
+ """
+ num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
+ assert num_added_tokens != 0, (
+ f"The tokenizer already contains the token {tokens}. Please pass "
+ "a different `placeholder_token` that is not already in the "
+ "tokenizer."
+ )
+
+ def get_token_info(self, token: str) -> dict:
+ """Get the information of a token, including its start and end index in
+ the current tokenizer.
+
+ Args:
+ token (str): The token to be queried.
+
+ Returns:
+ dict: The information of the token, including its start and end
+ index in current tokenizer.
+ """
+ token_ids = self.__call__(token).input_ids
+ start, end = token_ids[1], token_ids[-2] + 1
+ return {"name": token, "start": start, "end": end}
+
+ def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
+ """Add placeholder tokens to the tokenizer.
+
+ Args:
+ placeholder_token (str): The placeholder token to be added.
+ num_vec_per_token (int, optional): The number of vectors of
+ the added placeholder token.
+ *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
+ """
+ output = []
+ if num_vec_per_token == 1:
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
+ output.append(placeholder_token)
+ else:
+ output = []
+ for i in range(num_vec_per_token):
+ ith_token = placeholder_token + f"_{i}"
+ self.try_adding_tokens(ith_token, *args, **kwargs)
+ output.append(ith_token)
+
+ for token in self.token_map:
+ if token in placeholder_token:
+ raise ValueError(
+ f"The tokenizer already has placeholder token {token} "
+ f"that can get confused with {placeholder_token} "
+ "keep placeholder tokens independent"
+ )
+ self.token_map[placeholder_token] = output
+
+ def replace_placeholder_tokens_in_text(
+ self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
+ ) -> Union[str, List[str]]:
+ """Replace the keywords in text with placeholder tokens. This function
+ will be called in `self.__call__` and `self.encode`.
+
+ Args:
+ text (Union[str, List[str]]): The text to be processed.
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
+ Defaults to False.
+ prop_tokens_to_load (float, optional): The proportion of tokens to
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
+
+ Returns:
+ Union[str, List[str]]: The processed text.
+ """
+ if isinstance(text, list):
+ output = []
+ for i in range(len(text)):
+ output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
+ return output
+
+ for placeholder_token in self.token_map:
+ if placeholder_token in text:
+ tokens = self.token_map[placeholder_token]
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
+ if vector_shuffle:
+ tokens = copy.copy(tokens)
+ random.shuffle(tokens)
+ text = text.replace(placeholder_token, " ".join(tokens))
+ return text
+
+ def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
+ """Replace the placeholder tokens in text with the original keywords.
+ This function will be called in `self.decode`.
+
+ Args:
+ text (Union[str, List[str]]): The text to be processed.
+
+ Returns:
+ Union[str, List[str]]: The processed text.
+ """
+ if isinstance(text, list):
+ output = []
+ for i in range(len(text)):
+ output.append(self.replace_text_with_placeholder_tokens(text[i]))
+ return output
+
+ for placeholder_token, tokens in self.token_map.items():
+ merged_tokens = " ".join(tokens)
+ if merged_tokens in text:
+ text = text.replace(merged_tokens, placeholder_token)
+ return text
+
+ def __call__(
+ self,
+ text: Union[str, List[str]],
+ *args,
+ vector_shuffle: bool = False,
+ prop_tokens_to_load: float = 1.0,
+ **kwargs,
+ ):
+ """The call function of the wrapper.
+
+ Args:
+ text (Union[str, List[str]]): The text to be tokenized.
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
+ Defaults to False.
+ prop_tokens_to_load (float, optional): The proportion of tokens to
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
+ """
+ replaced_text = self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ )
+
+ return self.wrapped.__call__(replaced_text, *args, **kwargs)
+
+ def encode(self, text: Union[str, List[str]], *args, **kwargs):
+ """Encode the passed text to token index.
+
+ Args:
+ text (Union[str, List[str]]): The text to be encode.
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
+ """
+ replaced_text = self.replace_placeholder_tokens_in_text(text)
+ return self.wrapped(replaced_text, *args, **kwargs)
+
+ def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
+ """Decode the token index to text.
+
+ Args:
+ token_ids: The token index to be decoded.
+ return_raw: Whether keep the placeholder token in the text.
+ Defaults to False.
+ *args, **kwargs: The arguments for `self.wrapped.decode`.
+
+ Returns:
+ Union[str, List[str]]: The decoded text.
+ """
+ text = self.wrapped.decode(token_ids, *args, **kwargs)
+ if return_raw:
+ return text
+ replaced_text = self.replace_text_with_placeholder_tokens(text)
+ return replaced_text
+
+ def __repr__(self):
+ """The representation of the wrapper."""
+ s = super().__repr__()
+ prefix = f"Wrapped Module Class: {self._module_cls}\n"
+ prefix += f"Wrapped Module Name: {self._module_name}\n"
+ if self._from_pretrained:
+ prefix += f"From Pretrained: {self._from_pretrained}\n"
+ s = prefix + s
+ return s
+
+
+class EmbeddingLayerWithFixes(nn.Module):
+ """The revised embedding layer to support external embeddings. This design
+ of this class is inspired by https://github.com/AUTOMATIC1111/stable-
+ diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
+ jack.py#L224 # noqa.
+
+ Args:
+ wrapped (nn.Emebdding): The embedding layer to be wrapped.
+ external_embeddings (Union[dict, List[dict]], optional): The external
+ embeddings added to this layer. Defaults to None.
+ """
+
+ def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
+ super().__init__()
+ self.wrapped = wrapped
+ self.num_embeddings = wrapped.weight.shape[0]
+
+ self.external_embeddings = []
+ if external_embeddings:
+ self.add_embeddings(external_embeddings)
+
+ self.trainable_embeddings = nn.ParameterDict()
+
+ @property
+ def weight(self):
+ """Get the weight of wrapped embedding layer."""
+ return self.wrapped.weight
+
+ def check_duplicate_names(self, embeddings: List[dict]):
+ """Check whether duplicate names exist in list of 'external
+ embeddings'.
+
+ Args:
+ embeddings (List[dict]): A list of embedding to be check.
+ """
+ names = [emb["name"] for emb in embeddings]
+ assert len(names) == len(set(names)), (
+ "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
+ )
+
+ def check_ids_overlap(self, embeddings):
+ """Check whether overlap exist in token ids of 'external_embeddings'.
+
+ Args:
+ embeddings (List[dict]): A list of embedding to be check.
+ """
+ ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
+ ids_range.sort() # sort by 'start'
+ # check if 'end' has overlapping
+ for idx in range(len(ids_range) - 1):
+ name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
+ assert ids_range[idx][1] <= ids_range[idx + 1][0], (
+ f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
+ )
+
+ def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
+ """Add external embeddings to this layer.
+
+ Use case:
+
+ >>> 1. Add token to tokenizer and get the token id.
+ >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
+ >>> # 'how much' in kiswahili
+ >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
+ >>>
+ >>> 2. Add external embeddings to the model.
+ >>> new_embedding = {
+ >>> 'name': 'ngapi', # 'how much' in kiswahili
+ >>> 'embedding': torch.ones(1, 15) * 4,
+ >>> 'start': tokenizer.get_token_info('kwaheri')['start'],
+ >>> 'end': tokenizer.get_token_info('kwaheri')['end'],
+ >>> 'trainable': False # if True, will registry as a parameter
+ >>> }
+ >>> embedding_layer = nn.Embedding(10, 15)
+ >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
+ >>> embedding_layer_wrapper.add_embeddings(new_embedding)
+ >>>
+ >>> 3. Forward tokenizer and embedding layer!
+ >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
+ >>> input_ids = tokenizer(
+ >>> input_text, padding='max_length', truncation=True,
+ >>> return_tensors='pt')['input_ids']
+ >>> out_feat = embedding_layer_wrapper(input_ids)
+ >>>
+ >>> 4. Let's validate the result!
+ >>> assert (out_feat[0, 3: 7] == 2.3).all()
+ >>> assert (out_feat[2, 5: 9] == 2.3).all()
+
+ Args:
+ embeddings (Union[dict, list[dict]]): The external embeddings to
+ be added. Each dict must contain the following 4 fields: 'name'
+ (the name of this embedding), 'embedding' (the embedding
+ tensor), 'start' (the start token id of this embedding), 'end'
+ (the end token id of this embedding). For example:
+ `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
+ """
+ if isinstance(embeddings, dict):
+ embeddings = [embeddings]
+
+ self.external_embeddings += embeddings
+ self.check_duplicate_names(self.external_embeddings)
+ self.check_ids_overlap(self.external_embeddings)
+
+ # set for trainable
+ added_trainable_emb_info = []
+ for embedding in embeddings:
+ trainable = embedding.get("trainable", False)
+ if trainable:
+ name = embedding["name"]
+ embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
+ self.trainable_embeddings[name] = embedding["embedding"]
+ added_trainable_emb_info.append(name)
+
+ added_emb_info = [emb["name"] for emb in embeddings]
+ added_emb_info = ", ".join(added_emb_info)
+ print(f"Successfully add external embeddings: {added_emb_info}.", "current")
+
+ if added_trainable_emb_info:
+ added_trainable_emb_info = ", ".join(added_trainable_emb_info)
+ print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
+
+ def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ """Replace external input ids to 0.
+
+ Args:
+ input_ids (torch.Tensor): The input ids to be replaced.
+
+ Returns:
+ torch.Tensor: The replaced input ids.
+ """
+ input_ids_fwd = input_ids.clone()
+ input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
+ return input_ids_fwd
+
+ def replace_embeddings(
+ self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
+ ) -> torch.Tensor:
+ """Replace external embedding to the embedding layer. Noted that, in
+ this function we use `torch.cat` to avoid inplace modification.
+
+ Args:
+ input_ids (torch.Tensor): The original token ids. Shape like
+ [LENGTH, ].
+ embedding (torch.Tensor): The embedding of token ids after
+ `replace_input_ids` function.
+ external_embedding (dict): The external embedding to be replaced.
+
+ Returns:
+ torch.Tensor: The replaced embedding.
+ """
+ new_embedding = []
+
+ name = external_embedding["name"]
+ start = external_embedding["start"]
+ end = external_embedding["end"]
+ target_ids_to_replace = [i for i in range(start, end)]
+ ext_emb = external_embedding["embedding"].to(embedding.device)
+
+ # do not need to replace
+ if not (input_ids == start).any():
+ return embedding
+
+ # start replace
+ s_idx, e_idx = 0, 0
+ while e_idx < len(input_ids):
+ if input_ids[e_idx] == start:
+ if e_idx != 0:
+ # add embedding do not need to replace
+ new_embedding.append(embedding[s_idx:e_idx])
+
+ # check if the next embedding need to replace is valid
+ actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + end - start]]
+ assert actually_ids_to_replace == target_ids_to_replace, (
+ f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
+ f"Expect '{target_ids_to_replace}' for embedding "
+ f"'{name}' but found '{actually_ids_to_replace}'."
+ )
+
+ new_embedding.append(ext_emb)
+
+ s_idx = e_idx + end - start
+ e_idx = s_idx + 1
+ else:
+ e_idx += 1
+
+ if e_idx == len(input_ids):
+ new_embedding.append(embedding[s_idx:e_idx])
+
+ return torch.cat(new_embedding, dim=0)
+
+ def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype = None):
+ """The forward function.
+
+ Args:
+ input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
+ [LENGTH, ].
+ external_embeddings (Optional[List[dict]]): The external
+ embeddings. If not passed, only `self.external_embeddings`
+ will be used. Defaults to None.
+
+ input_ids: shape like [bz, LENGTH] or [LENGTH].
+ """
+
+ assert input_ids.ndim in [1, 2]
+ if input_ids.ndim == 1:
+ input_ids = input_ids.unsqueeze(0)
+
+ if external_embeddings is None and not self.external_embeddings:
+ return self.wrapped(input_ids, out_dtype=out_dtype)
+
+ input_ids_fwd = self.replace_input_ids(input_ids)
+ inputs_embeds = self.wrapped(input_ids_fwd)
+
+ vecs = []
+
+ if external_embeddings is None:
+ external_embeddings = []
+ elif isinstance(external_embeddings, dict):
+ external_embeddings = [external_embeddings]
+ embeddings = self.external_embeddings + external_embeddings
+
+ for input_id, embedding in zip(input_ids, inputs_embeds):
+ new_embedding = embedding
+ for external_embedding in embeddings:
+ new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
+ vecs.append(new_embedding)
+
+ return torch.stack(vecs).to(out_dtype)
+
+
+
+def add_tokens(
+ tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None, num_vectors_per_token: int = 1
+):
+ """Add token for training.
+
+ # TODO: support add tokens as dict, then we can load pretrained tokens.
+ """
+ if initialize_tokens is not None:
+ assert len(initialize_tokens) == len(
+ placeholder_tokens
+ ), "placeholder_token should be the same length as initialize_token"
+ for ii in range(len(placeholder_tokens)):
+ tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
+
+ # text_encoder.set_embedding_layer()
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
+ text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
+
+ assert embedding_layer is not None, (
+ "Do not support get embedding layer for current text encoder. " "Please check your configuration."
+ )
+ initialize_embedding = []
+ if initialize_tokens is not None:
+ for ii in range(len(placeholder_tokens)):
+ init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
+ temp_embedding = embedding_layer.weight[init_id]
+ initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
+ else:
+ for ii in range(len(placeholder_tokens)):
+ init_id = tokenizer("a").input_ids[1]
+ temp_embedding = embedding_layer.weight[init_id]
+ len_emb = temp_embedding.shape[0]
+ init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
+ initialize_embedding.append(init_weight)
+
+ # initialize_embedding = torch.cat(initialize_embedding,dim=0)
+
+ token_info_all = []
+ for ii in range(len(placeholder_tokens)):
+ token_info = tokenizer.get_token_info(placeholder_tokens[ii])
+ token_info["embedding"] = initialize_embedding[ii]
+ token_info["trainable"] = True
+ token_info_all.append(token_info)
+ embedding_layer.add_embeddings(token_info_all)
diff --git a/brushnet/unet_2d_blocks.py b/brushnet/unet_2d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a083673867f2568d499480f7dcec1480b20ead0
--- /dev/null
+++ b/brushnet/unet_2d_blocks.py
@@ -0,0 +1,3907 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate, is_torch_version, logging
+from diffusers.utils.torch_utils import apply_freeu
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+from diffusers.models.normalization import AdaGroupNorm
+from diffusers.models.resnet import (
+ Downsample2D,
+ FirDownsample2D,
+ FirUpsample2D,
+ KDownsample2D,
+ KUpsample2D,
+ ResnetBlock2D,
+ ResnetBlockCondNorm2D,
+ Upsample2D,
+)
+from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
+from diffusers.models.transformers.transformer_2d import Transformer2DModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_down_block(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ add_downsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ downsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "ResnetDownsampleBlock2D":
+ return ResnetDownsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ if add_downsample is False:
+ downsample_type = None
+ else:
+ downsample_type = downsample_type or "conv" # default to 'conv'
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ downsample_type=downsample_type,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ )
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
+ return SimpleCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnDownEncoderBlock2D":
+ return AttnDownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "KDownBlock2D":
+ return KDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif down_block_type == "KCrossAttnDownBlock2D":
+ return KCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ add_self_attention=True if not add_downsample else False,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_mid_block(
+ mid_block_type: str,
+ temb_channels: int,
+ in_channels: int,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resnet_groups: int,
+ output_scale_factor: float = 1.0,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ mid_block_only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = 1,
+ dropout: float = 0.0,
+):
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ return UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resnet_groups=resnet_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ return UNetMidBlock2DSimpleCrossAttn(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ return UNetMidBlock2D(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type == "MidBlock2D":
+ return MidBlock2D(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ )
+ elif mid_block_type is None:
+ return None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+
+def get_up_block(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ add_upsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resolution_idx: Optional[int] = None,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ upsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+) -> nn.Module:
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "ResnetUpsampleBlock2D":
+ return ResnetUpsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ )
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
+ return SimpleCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ if add_upsample is False:
+ upsample_type = None
+ else:
+ upsample_type = upsample_type or "conv" # default to 'conv'
+
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ upsample_type=upsample_type,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "AttnUpDecoderBlock2D":
+ return AttnUpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "KUpBlock2D":
+ return KUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "KCrossAttnUpBlock2D":
+ return KCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ )
+
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class AutoencoderTinyBlock(nn.Module):
+ """
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
+ blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ out_channels (`int`): The number of output channels.
+ act_fn (`str`):
+ ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
+
+ Returns:
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
+ `out_channels`.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+ super().__init__()
+ act_fn = get_activation(act_fn)
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+ act_fn,
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ act_fn,
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ )
+ self.skip = (
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ if in_channels != out_channels
+ else nn.Identity()
+ )
+ self.fuse = nn.ReLU()
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+class UNetMidBlock2D(nn.Module):
+ """
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ temb_channels (`int`): The number of temporal embedding channels.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
+ model on tasks with long-range temporal dependencies.
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
+ resnet_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use in the group normalization layers of the resnet blocks.
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
+ Whether to use pre-normalization for the resnet blocks.
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
+ attention_head_dim (`int`, *optional*, defaults to 1):
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
+ the number of input channels.
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
+
+ Returns:
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+ in_channels, height, width)`.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ attn_groups: Optional[int] = None,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ if attn_groups is None:
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+ # there is always at least one resnet
+ if resnet_time_scale_shift == "spatial":
+ resnets = [
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ ]
+ else:
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+ )
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=attn_groups,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # support for variable transformer layers per block
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for i in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DSimpleCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ self.attention_head_dim = attention_head_dim
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ self.num_heads = in_channels // self.attention_head_dim
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ cross_attention_dim=in_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ # attn
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ # resnet
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class MidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ use_linear_projection: bool = False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = False
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+
+ for i in range(num_layers):
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ lora_scale = 1.0
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ for resnet in self.resnets[1:]:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ downsample_type: str = "conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ self.downsample_type = downsample_type
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_type == "conv":
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ elif downsample_type == "resnet":
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ if self.downsample_type == "resnet":
+ hidden_states = downsampler(hidden_states, temb=temb)
+ else:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ add_downsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals: Optional[torch.FloatTensor] = None,
+ down_block_add_samples: Optional[torch.FloatTensor] = None,
+ debug = False,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+
+ if debug: print(' XAD2: forward')
+
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ if debug: print(' XAD2: resnet hs #', i, hidden_states.shape)
+ if debug and temb is not None: print(' XAD2: resnet temb #', i, temb.shape)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if debug: print(' XAD2: attn hs #', i, hidden_states.shape)
+ if debug and encoder_hidden_states is not None: print(' XAD2: attn ehs #', i, encoder_hidden_states.shape)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+
+ if debug: print(' XAD2: add res', additional_residuals.shape)
+
+ hidden_states = hidden_states + additional_residuals
+
+ if down_block_add_samples is not None:
+
+ if debug: print(' XAD2: add samples', down_block_add_samples.shape)
+
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ if debug: print(' XAD2: output', hidden_states.shape)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
+
+ output_states = output_states + (hidden_states,)
+
+ if debug:
+ print(' XAD2: finish')
+ for st in output_states:
+ print(' XAD2: ',st.shape)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
+ down_block_add_samples: Optional[torch.FloatTensor] = None, *args, **kwargs
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ if kwargs.get("debug", False): print(' D2: forward', hidden_states.shape)
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+
+ if kwargs.get("debug", False): print(' D2: resnet', hidden_states.shape)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
+
+ output_states = output_states + (hidden_states,)
+
+ if kwargs.get("debug", False): print(' D2: finish', hidden_states.shape)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class ResnetDownsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ skip_time_act: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class SimpleCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ resnets = []
+ attentions = []
+
+ self.attention_head_dim = attention_head_dim
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class KDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ add_downsample: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ # YiYi's comments- might be able to use FirDownsample2D, look into details later
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class KCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ cross_attention_dim: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_group_size: int = 32,
+ add_downsample: bool = True,
+ attention_head_dim: int = 64,
+ add_self_attention: bool = False,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ out_channels,
+ out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ group_size=resnet_group_size,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if self.downsamplers is None:
+ output_states += (None,)
+ else:
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ upsample_type: str = "conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.upsample_type = upsample_type
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if upsample_type == "conv":
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ elif upsample_type == "resnet":
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ if self.upsample_type == "resnet":
+ hidden_states = upsampler(hidden_states, temb=temb)
+ else:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ return_res_samples: Optional[bool]=False,
+ up_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ if return_res_samples:
+ output_states=()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ if return_res_samples:
+ return hidden_states, output_states
+ else:
+ return hidden_states
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ return_res_samples: Optional[bool]=False,
+ up_block_add_samples: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ if return_res_samples:
+ output_states = ()
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
+
+ if return_res_samples:
+ return hidden_states, output_states
+ else:
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=temb)
+ hidden_states = attn(hidden_states, temb=temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(resnet_in_channels + res_skip_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample=None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_upsample: bool = True,
+ upsample_padding: int = 1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample=None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class ResnetUpsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ skip_time_act: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class SimpleCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # resnet
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class KUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int,
+ dropout: float = 0.0,
+ num_layers: int = 5,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: Optional[int] = 32,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ k_in_channels = 2 * out_channels
+ k_out_channels = in_channels
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class KCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ attention_head_dim: int = 1, # attention dim_head
+ cross_attention_dim: int = 768,
+ add_upsample: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ is_first_block = in_channels == out_channels == temb_channels
+ is_middle_block = in_channels != out_channels
+ add_self_attention = True if is_first_block else False
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ # in_channels, and out_channels for the block (k-unet)
+ k_in_channels = out_channels if is_first_block else 2 * out_channels
+ k_out_channels = in_channels
+
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ if is_middle_block and (i == num_layers - 1):
+ conv_2d_out_channels = k_out_channels
+ else:
+ conv_2d_out_channels = None
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ conv_2d_out_channels=conv_2d_out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ k_out_channels if (i == num_layers - 1) else out_channels,
+ k_out_channels // attention_head_dim
+ if (i == num_layers - 1)
+ else out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+# can potentially later be renamed to `No-feed-forward` attention
+class KAttentionBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attention layers should contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to upcast the attention computation to `float32`.
+ temb_channels (`int`, *optional*, defaults to 768):
+ The number of channels in the token embedding.
+ add_self_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add self-attention to the block.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ group_size (`int`, *optional*, defaults to 32):
+ The number of groups to separate the channels into for group normalization.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout: float = 0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ upcast_attention: bool = False,
+ temb_channels: int = 768, # for ada_group_norm
+ add_self_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ group_size: int = 32,
+ ):
+ super().__init__()
+ self.add_self_attention = add_self_attention
+
+ # 1. Self-Attn
+ if add_self_attention:
+ self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ cross_attention_norm=None,
+ )
+
+ # 2. Cross-Attn
+ self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+
+ def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
+
+ def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ # TODO: mark emb as non-optional (self.norm2 requires it).
+ # requires assessing impact of change to positional param interface.
+ emb: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ # 1. Self-Attention
+ if self.add_self_attention:
+ norm_hidden_states = self.norm1(hidden_states, emb)
+
+ height, weight = norm_hidden_states.shape[2:]
+ norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = self._to_4d(attn_output, height, weight)
+
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention/None
+ norm_hidden_states = self.norm2(hidden_states, emb)
+
+ height, weight = norm_hidden_states.shape[2:]
+ norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = self._to_4d(attn_output, height, weight)
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
diff --git a/brushnet/unet_2d_condition.py b/brushnet/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..088e0efdba9f481c57137e5413e795fcca74c6a5
--- /dev/null
+++ b/brushnet/unet_2d_condition.py
@@ -0,0 +1,1355 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ GLIGENTextBoundingboxProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ self._check_config(
+ down_block_types=down_block_types,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
+ time_embedding_type,
+ block_out_channels=block_out_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ time_embedding_dim=time_embedding_dim,
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ self._set_encoder_hid_proj(
+ encoder_hid_dim_type,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ )
+
+ # class embedding
+ self._set_class_embedding(
+ class_embed_type,
+ act_fn=act_fn,
+ num_class_embeds=num_class_embeds,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ timestep_input_dim=timestep_input_dim,
+ )
+
+ self._set_add_embedding(
+ addition_embed_type,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
+ addition_time_embed_dim=addition_time_embed_dim,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ )
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ temb_channels=blocks_time_embed_dim,
+ in_channels=block_out_channels[-1],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ output_scale_factor=mid_block_scale_factor,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[-1],
+ dropout=dropout,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
+
+ def _check_config(
+ self,
+ down_block_types: Tuple[str],
+ up_block_types: Tuple[str],
+ only_cross_attention: Union[bool, Tuple[bool]],
+ block_out_channels: Tuple[int],
+ layers_per_block: Union[int, Tuple[int]],
+ cross_attention_dim: Union[int, Tuple[int]],
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+ reverse_transformer_layers_per_block: bool,
+ attention_head_dim: int,
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
+ ):
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ def _set_time_proj(
+ self,
+ time_embedding_type: str,
+ block_out_channels: int,
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ time_embedding_dim: int,
+ ) -> Tuple[int, int]:
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ return time_embed_dim, timestep_input_dim
+
+ def _set_encoder_hid_proj(
+ self,
+ encoder_hid_dim_type: Optional[str],
+ cross_attention_dim: Union[int, Tuple[int]],
+ encoder_hid_dim: Optional[int],
+ ):
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ def _set_class_embedding(
+ self,
+ class_embed_type: Optional[str],
+ act_fn: str,
+ num_class_embeds: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ timestep_input_dim: int,
+ ):
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ def _set_add_embedding(
+ self,
+ addition_embed_type: str,
+ addition_embed_type_num_heads: int,
+ addition_time_embed_dim: Optional[int],
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ cross_attention_dim: Optional[int],
+ encoder_hid_dim: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ ):
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def unload_lora(self):
+ """Unloads LoRA weights."""
+ deprecate(
+ "unload_lora",
+ "0.28.0",
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+ )
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ def get_time_embed(
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+ ) -> Optional[torch.Tensor]:
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ return t_emb
+
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ class_emb = None
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+ return class_emb
+
+ def get_aug_embed(
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> Optional[torch.Tensor]:
+ aug_emb = None
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb = self.add_embedding(image_embs, hint)
+ return aug_emb
+
+ def process_encoder_hidden_states(
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> torch.Tensor:
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
+ return encoder_hidden_states
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+ if cross_attention_kwargs is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+
+ if is_brushnet:
+ sample = sample + down_block_add_samples.pop(0)
+
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ i = len(down_block_add_samples)
+
+ if is_brushnet and len(down_block_add_samples)>0:
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
+ for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ additional_residuals = {}
+
+ i = len(down_block_add_samples)
+
+ if is_brushnet and len(down_block_add_samples)>0:
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
+ for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
+
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, **additional_residuals)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ if is_brushnet:
+ sample = sample + mid_block_add_sample
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ additional_residuals = {}
+
+ i = len(up_block_add_samples)
+
+ if is_brushnet and len(up_block_add_samples)>0:
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
+ for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
+
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ additional_residuals = {}
+
+ i = len(up_block_add_samples)
+
+ if is_brushnet and len(up_block_add_samples)>0:
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
+ for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
+
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ **additional_residuals,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/brushnet_nodes.py b/brushnet_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eaa3669cd6438bb93577c4f8bb6f9e81a3e2497
--- /dev/null
+++ b/brushnet_nodes.py
@@ -0,0 +1,1085 @@
+import os
+import types
+from typing import Tuple
+
+import torch
+import torchvision.transforms as T
+import torch.nn.functional as F
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+import comfy
+import folder_paths
+
+from .model_patch import add_model_patch_option, patch_model_function_wrapper
+
+from .brushnet.brushnet import BrushNetModel
+from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
+
+from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
+
+current_directory = os.path.dirname(os.path.abspath(__file__))
+brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
+brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
+powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
+
+sd15_scaling_factor = 0.18215
+sdxl_scaling_factor = 0.13025
+
+ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
+ comfy.ldm.models.autoencoder.AutoencoderKL
+ ]
+
+
+class BrushNetLoader:
+
+ @classmethod
+ def INPUT_TYPES(self):
+ self.inpaint_files = get_files_with_extension('inpaint')
+ return {"required":
+ {
+ "brushnet": ([file for file in self.inpaint_files], ),
+ "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("BRMODEL",)
+ RETURN_NAMES = ("brushnet",)
+
+ FUNCTION = "brushnet_loading"
+
+ def brushnet_loading(self, brushnet, dtype):
+ brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)
+ is_SDXL = False
+ is_PP = False
+ sd = comfy.utils.load_torch_file(brushnet_file)
+ brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
+ del sd
+ if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
+ is_SDXL = False
+ if keys == 322:
+ is_PP = False
+ print('BrushNet model type: SD1.5')
+ else:
+ is_PP = True
+ print('PowerPaint model type: SD1.5')
+ elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
+ print('BrushNet model type: Loading SDXL')
+ is_SDXL = True
+ is_PP = False
+ else:
+ raise Exception("Unknown BrushNet model")
+
+ with init_empty_weights():
+ if is_SDXL:
+ brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
+ elif is_PP:
+ brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
+ brushnet_model = PowerPaintModel.from_config(brushnet_config)
+ else:
+ brushnet_config = BrushNetModel.load_config(brushnet_config_file)
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
+
+ if is_PP:
+ print("PowerPaint model file:", brushnet_file)
+ else:
+ print("BrushNet model file:", brushnet_file)
+
+ if dtype == 'float16':
+ torch_dtype = torch.float16
+ elif dtype == 'bfloat16':
+ torch_dtype = torch.bfloat16
+ elif dtype == 'float32':
+ torch_dtype = torch.float32
+ else:
+ torch_dtype = torch.float64
+
+ brushnet_model = load_checkpoint_and_dispatch(
+ brushnet_model,
+ brushnet_file,
+ device_map="sequential",
+ max_memory=None,
+ offload_folder=None,
+ offload_state_dict=False,
+ dtype=torch_dtype,
+ force_hooks=False,
+ )
+
+ if is_PP:
+ print("PowerPaint model is loaded")
+ elif is_SDXL:
+ print("BrushNet SDXL model is loaded")
+ else:
+ print("BrushNet SD1.5 model is loaded")
+
+ return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
+
+
+class PowerPaintCLIPLoader:
+
+ @classmethod
+ def INPUT_TYPES(self):
+ self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
+ self.clip_files = get_files_with_extension('clip')
+ return {"required":
+ {
+ "base": ([file for file in self.clip_files], ),
+ "powerpaint": ([file for file in self.inpaint_files], ),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("CLIP",)
+ RETURN_NAMES = ("clip",)
+
+ FUNCTION = "ppclip_loading"
+
+ def ppclip_loading(self, base, powerpaint):
+ base_CLIP_file = os.path.join(self.clip_files[base], base)
+ pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)
+
+ pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
+
+ print('PowerPaint base CLIP file: ', base_CLIP_file)
+
+ pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
+ pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
+
+ add_tokens(
+ tokenizer = pp_tokenizer,
+ text_encoder = pp_text_encoder,
+ placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
+ initialize_tokens = ["a", "a", "a"],
+ num_vectors_per_token = 10,
+ )
+
+ pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
+
+ print('PowerPaint CLIP file: ', pp_CLIP_file)
+
+ pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
+ pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
+
+ return (pp_clip,)
+
+
+class PowerPaint:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "model": ("MODEL",),
+ "vae": ("VAE", ),
+ "image": ("IMAGE",),
+ "mask": ("MASK",),
+ "powerpaint": ("BRMODEL", ),
+ "clip": ("CLIP", ),
+ "positive": ("CONDITIONING", ),
+ "negative": ("CONDITIONING", ),
+ "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
+ "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
+ "save_memory": (['none', 'auto', 'max'], ),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
+ RETURN_NAMES = ("model","positive","negative","latent",)
+
+ FUNCTION = "model_update"
+
+ def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
+
+ is_SDXL, is_PP = check_compatibilty(model, powerpaint)
+ if not is_PP:
+ raise Exception("BrushNet model was loaded, please use BrushNet node")
+
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
+ model = model.clone()
+
+ # prepare image and mask
+ # no batches for original image and mask
+ masked_image, mask = prepare_image(image, mask)
+
+ batch = masked_image.shape[0]
+ #width = masked_image.shape[2]
+ #height = masked_image.shape[1]
+
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
+ scaling_factor = model.model.model_config.latent_format.scale_factor
+ else:
+ scaling_factor = sd15_scaling_factor
+
+ torch_dtype = powerpaint['dtype']
+
+ # prepare conditioning latents
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
+
+ # prepare embeddings
+
+ if function == "object removal":
+ promptA = "P_ctxt"
+ promptB = "P_ctxt"
+ negative_promptA = "P_obj"
+ negative_promptB = "P_obj"
+ print('You should add to positive prompt: "empty scene blur"')
+ #positive = positive + " empty scene blur"
+ elif function == "context aware":
+ promptA = "P_ctxt"
+ promptB = "P_ctxt"
+ negative_promptA = ""
+ negative_promptB = ""
+ #positive = positive + " empty scene"
+ print('You should add to positive prompt: "empty scene"')
+ elif function == "shape guided":
+ promptA = "P_shape"
+ promptB = "P_ctxt"
+ negative_promptA = "P_shape"
+ negative_promptB = "P_ctxt"
+ elif function == "image outpainting":
+ promptA = "P_ctxt"
+ promptB = "P_ctxt"
+ negative_promptA = "P_obj"
+ negative_promptB = "P_obj"
+ #positive = positive + " empty scene"
+ print('You should add to positive prompt: "empty scene"')
+ else:
+ promptA = "P_obj"
+ promptB = "P_obj"
+ negative_promptA = "P_obj"
+ negative_promptB = "P_obj"
+
+ tokens = clip.tokenize(promptA)
+ prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
+
+ tokens = clip.tokenize(negative_promptA)
+ negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
+
+ tokens = clip.tokenize(promptB)
+ prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
+
+ tokens = clip.tokenize(negative_promptB)
+ negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
+
+ prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
+ negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
+
+ # unload vae and CLIPs
+ del vae
+ del clip
+ for loaded_model in comfy.model_management.current_loaded_models:
+ if type(loaded_model.model.model) in ModelsToUnload:
+ comfy.model_management.current_loaded_models.remove(loaded_model)
+ loaded_model.model_unload()
+ del loaded_model
+
+ # apply patch to model
+
+ brushnet_conditioning_scale = scale
+ control_guidance_start = start_at
+ control_guidance_end = end_at
+
+ if save_memory != 'none':
+ powerpaint['brushnet'].set_attention_slice(save_memory)
+
+ add_brushnet_patch(model,
+ powerpaint['brushnet'],
+ torch_dtype,
+ conditioning_latents,
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
+ negative_prompt_embeds_pp, prompt_embeds_pp,
+ None, None, None,
+ False)
+
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
+
+ return (model, positive, negative, {"samples":latent},)
+
+
+class BrushNet:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "model": ("MODEL",),
+ "vae": ("VAE", ),
+ "image": ("IMAGE",),
+ "mask": ("MASK",),
+ "brushnet": ("BRMODEL", ),
+ "positive": ("CONDITIONING", ),
+ "negative": ("CONDITIONING", ),
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
+ RETURN_NAMES = ("model","positive","negative","latent",)
+
+ FUNCTION = "model_update"
+
+ def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
+
+ is_SDXL, is_PP = check_compatibilty(model, brushnet)
+
+ if is_PP:
+ raise Exception("PowerPaint model was loaded, please use PowerPaint node")
+
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
+ model = model.clone()
+
+ # prepare image and mask
+ # no batches for original image and mask
+ masked_image, mask = prepare_image(image, mask)
+
+ batch = masked_image.shape[0]
+ width = masked_image.shape[2]
+ height = masked_image.shape[1]
+
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
+ scaling_factor = model.model.model_config.latent_format.scale_factor
+ elif is_SDXL:
+ scaling_factor = sdxl_scaling_factor
+ else:
+ scaling_factor = sd15_scaling_factor
+
+ torch_dtype = brushnet['dtype']
+
+ # prepare conditioning latents
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+
+ # unload vae
+ del vae
+ for loaded_model in comfy.model_management.current_loaded_models:
+ if type(loaded_model.model.model) in ModelsToUnload:
+ comfy.model_management.current_loaded_models.remove(loaded_model)
+ loaded_model.model_unload()
+ del loaded_model
+
+ # prepare embeddings
+
+ prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+ negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+
+ max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
+ if prompt_embeds.shape[1] < max_tokens:
+ multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
+ prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
+ print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
+ if negative_prompt_embeds.shape[1] < max_tokens:
+ multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
+ negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
+ print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')
+
+ if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
+ pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+ else:
+ print('BrushNet: positive conditioning has not pooled_output')
+ if is_SDXL:
+ print('BrushNet will not produce correct results')
+ pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
+
+ if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
+ negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+ else:
+ print('BrushNet: negative conditioning has not pooled_output')
+ if is_SDXL:
+ print('BrushNet will not produce correct results')
+ negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
+
+ time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
+
+ if not is_SDXL:
+ pooled_prompt_embeds = None
+ negative_pooled_prompt_embeds = None
+ time_ids = None
+
+ # apply patch to model
+
+ brushnet_conditioning_scale = scale
+ control_guidance_start = start_at
+ control_guidance_end = end_at
+
+ add_brushnet_patch(model,
+ brushnet['brushnet'],
+ torch_dtype,
+ conditioning_latents,
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
+ prompt_embeds, negative_prompt_embeds,
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
+ False)
+
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
+
+ return (model, positive, negative, {"samples":latent},)
+
+
+class BlendInpaint:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "inpaint": ("IMAGE",),
+ "original": ("IMAGE",),
+ "mask": ("MASK",),
+ "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
+ "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
+ },
+ "optional":
+ {
+ "origin": ("VECTOR",),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("IMAGE","MASK",)
+ RETURN_NAMES = ("image","MASK",)
+
+ FUNCTION = "blend_inpaint"
+
+ def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]:
+
+ original, mask = check_image_mask(original, mask, 'Blend Inpaint')
+
+ if len(inpaint.shape) < 4:
+ # image tensor shape should be [B, H, W, C], but batch somehow is missing
+ inpaint = inpaint[None,:,:,:]
+
+ if inpaint.shape[0] < original.shape[0]:
+ print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
+ original= original[:inpaint.shape[0],:,:]
+ mask = mask[:inpaint.shape[0],:,:]
+
+ if inpaint.shape[0] > original.shape[0]:
+ # batch over inpaint
+ count = 0
+ original_list = []
+ mask_list = []
+ origin_list = []
+ while (count < inpaint.shape[0]):
+ for i in range(original.shape[0]):
+ original_list.append(original[i][None,:,:,:])
+ mask_list.append(mask[i][None,:,:])
+ if origin is not None:
+ origin_list.append(origin[i][None,:])
+ count += 1
+ if count >= inpaint.shape[0]:
+ break
+ original = torch.concat(original_list, dim=0)
+ mask = torch.concat(mask_list, dim=0)
+ if origin is not None:
+ origin = torch.concat(origin_list, dim=0)
+
+ if kernel % 2 == 0:
+ kernel += 1
+ transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
+
+ ret = []
+ blurred = []
+ for i in range(inpaint.shape[0]):
+ if origin is None:
+ blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
+ blurred.append(blurred_mask[0])
+
+ result = torch.nn.functional.interpolate(
+ inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
+ size=(
+ original[i].shape[0],
+ original[i].shape[1],
+ )
+ ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
+ else:
+ # got mask from CutForInpaint
+ height, width, _ = original[i].shape
+ x0 = origin[i][0].item()
+ y0 = origin[i][1].item()
+
+ if mask[i].shape[0] < height or mask[i].shape[1] < width:
+ padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
+ y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
+ else:
+ padded_mask = mask[i]
+ blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
+ blurred.append(blurred_mask[0][0])
+
+ result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
+ y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
+ result = result[None,:,:,:].to(original.device).to(original.dtype)
+
+ ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
+
+ return (torch.stack(ret), torch.stack(blurred), )
+
+
+class CutForInpaint:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "image": ("IMAGE",),
+ "mask": ("MASK",),
+ "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
+ "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
+ RETURN_NAMES = ("image","mask","origin",)
+
+ FUNCTION = "cut_for_inpaint"
+
+ def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
+
+ image, mask = check_image_mask(image, mask, 'BrushNet')
+
+ ret = []
+ msk = []
+ org = []
+ for i in range(image.shape[0]):
+ x0, y0, w, h = cut_with_mask(mask[i], width, height)
+ ret.append((image[i][y0:y0+h,x0:x0+w,:]))
+ msk.append((mask[i][y0:y0+h,x0:x0+w]))
+ org.append(torch.IntTensor([x0,y0]))
+
+ return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
+
+
+#### Utility function
+
+def get_files_with_extension(folder_name, extension=['.safetensors']):
+
+ try:
+ folders = folder_paths.get_folder_paths(folder_name)
+ except:
+ folders = []
+
+ if not folders:
+ folders = [os.path.join(folder_paths.models_dir, folder_name)]
+ if not os.path.isdir(folders[0]):
+ folders = [os.path.join(folder_paths.base_path, folder_name)]
+ if not os.path.isdir(folders[0]):
+ return {}
+
+ filtered_folders = []
+ for x in folders:
+ if not os.path.isdir(x):
+ continue
+ the_same = False
+ for y in filtered_folders:
+ if os.path.samefile(x, y):
+ the_same = True
+ break
+ if not the_same:
+ filtered_folders.append(x)
+
+ if not filtered_folders:
+ return {}
+
+ output = {}
+ for x in filtered_folders:
+ files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
+ filtered_files = folder_paths.filter_files_extensions(files, extension)
+
+ for f in filtered_files:
+ output[f] = x
+
+ return output
+
+
+# get blocks from state_dict so we could know which model it is
+def brushnet_blocks(sd):
+ brushnet_down_block = 0
+ brushnet_mid_block = 0
+ brushnet_up_block = 0
+ for key in sd:
+ if 'brushnet_down_block' in key:
+ brushnet_down_block += 1
+ if 'brushnet_mid_block' in key:
+ brushnet_mid_block += 1
+ if 'brushnet_up_block' in key:
+ brushnet_up_block += 1
+ return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
+
+
+# Check models compatibility
+def check_compatibilty(model, brushnet):
+ is_SDXL = False
+ is_PP = False
+ if isinstance(model.model.model_config, comfy.supported_models.SD15):
+ print('Base model type: SD1.5')
+ is_SDXL = False
+ if brushnet["SDXL"]:
+ raise Exception("Base model is SD15, but BrushNet is SDXL type")
+ if brushnet["PP"]:
+ is_PP = True
+ elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
+ print('Base model type: SDXL')
+ is_SDXL = True
+ if not brushnet["SDXL"]:
+ raise Exception("Base model is SDXL, but BrushNet is SD15 type")
+ else:
+ print('Base model type: ', type(model.model.model_config))
+ raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
+
+ return (is_SDXL, is_PP)
+
+
+def check_image_mask(image, mask, name):
+ if len(image.shape) < 4:
+ # image tensor shape should be [B, H, W, C], but batch somehow is missing
+ image = image[None,:,:,:]
+
+ if len(mask.shape) > 3:
+ # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
+ # take first mask, red channel
+ mask = (mask[:,:,:,0])[:,:,:]
+ elif len(mask.shape) < 3:
+ # mask tensor shape should be [B, H, W] but batch somehow is missing
+ mask = mask[None,:,:]
+
+ if image.shape[0] > mask.shape[0]:
+ print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
+ if mask.shape[0] == 1:
+ print(name, "will copy the mask to fill batch")
+ mask = torch.cat([mask] * image.shape[0], dim=0)
+ else:
+ print(name, "will add empty masks to fill batch")
+ empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
+ mask = torch.cat([mask, empty_mask], dim=0)
+ elif image.shape[0] < mask.shape[0]:
+ print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
+ mask = mask[:image.shape[0],:,:]
+
+ return (image, mask)
+
+
+# Prepare image and mask
+def prepare_image(image, mask):
+
+ image, mask = check_image_mask(image, mask, 'BrushNet')
+
+ print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
+
+ if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
+ raise Exception("Image and mask should be the same size")
+
+ # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
+ mask = mask.round()
+
+ masked_image = image * (1.0 - mask[:,:,:,None])
+
+ return (masked_image, mask)
+
+
+# Get origin of the mask
+def cut_with_mask(mask, width, height):
+ iy, ix = (mask == 1).nonzero(as_tuple=True)
+
+ h0, w0 = mask.shape
+
+ if iy.numel() == 0:
+ x_c = w0 / 2.0
+ y_c = h0 / 2.0
+ else:
+ x_min = ix.min().item()
+ x_max = ix.max().item()
+ y_min = iy.min().item()
+ y_max = iy.max().item()
+
+ if x_max - x_min > width or y_max - y_min > height:
+ raise Exception("Masked area is bigger than provided dimensions")
+
+ x_c = (x_min + x_max) / 2.0
+ y_c = (y_min + y_max) / 2.0
+
+ width2 = width / 2.0
+ height2 = height / 2.0
+
+ if w0 <= width:
+ x0 = 0
+ w = w0
+ else:
+ x0 = max(0, x_c - width2)
+ w = width
+ if x0 + width > w0:
+ x0 = w0 - width
+
+ if h0 <= height:
+ y0 = 0
+ h = h0
+ else:
+ y0 = max(0, y_c - height2)
+ h = height
+ if y0 + height > h0:
+ y0 = h0 - height
+
+ return (int(x0), int(y0), int(w), int(h))
+
+
+# Prepare conditioning_latents
+@torch.inference_mode()
+def get_image_latents(masked_image, mask, vae, scaling_factor):
+ processed_image = masked_image.to(vae.device)
+ image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
+ processed_mask = 1. - mask[:,None,:,:]
+ interpolated_mask = torch.nn.functional.interpolate(
+ processed_mask,
+ size=(
+ image_latents.shape[-2],
+ image_latents.shape[-1]
+ )
+ )
+ interpolated_mask = interpolated_mask.to(image_latents.device)
+
+ conditioning_latents = [image_latents, interpolated_mask]
+
+ print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
+
+ return conditioning_latents
+
+
+# Main function where magic happens
+@torch.inference_mode()
+def brushnet_inference(x, timesteps, transformer_options, debug):
+ if 'model_patch' not in transformer_options:
+ print('BrushNet inference: there is no model_patch key in transformer_options')
+ return ([], 0, [])
+ mp = transformer_options['model_patch']
+ if 'brushnet' not in mp:
+ print('BrushNet inference: there is no brushnet key in mdel_patch')
+ return ([], 0, [])
+ bo = mp['brushnet']
+ if 'model' not in bo:
+ print('BrushNet inference: there is no model key in brushnet')
+ return ([], 0, [])
+ brushnet = bo['model']
+ if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
+ print('BrushNet model is not a BrushNetModel class')
+ return ([], 0, [])
+
+ torch_dtype = bo['dtype']
+ cl_list = bo['latents']
+ brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
+ pe = bo['prompt_embeds']
+ npe = bo['negative_prompt_embeds']
+ ppe, nppe, time_ids = bo['add_embeds']
+
+ #do_classifier_free_guidance = mp['free_guidance']
+ do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
+
+ x = x.detach().clone()
+ x = x.to(torch_dtype).to(brushnet.device)
+
+ timesteps = timesteps.detach().clone()
+ timesteps = timesteps.to(torch_dtype).to(brushnet.device)
+
+ total_steps = mp['total_steps']
+ step = mp['step']
+
+ added_cond_kwargs = {}
+
+ if do_classifier_free_guidance and step == 0:
+ print('BrushNet inference: do_classifier_free_guidance is True')
+
+ sub_idx = None
+ if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
+ sub_idx = transformer_options['ad_params']['sub_idxs']
+
+ # we have batch input images
+ batch = cl_list[0].shape[0]
+ # we have incoming latents
+ latents_incoming = x.shape[0]
+ # and we already got some
+ latents_got = bo['latent_id']
+ if step == 0 or batch > 1:
+ print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
+ % (step, batch, latents_incoming, latents_got))
+
+ image_latents = []
+ masks = []
+ prompt_embeds = []
+ negative_prompt_embeds = []
+ pooled_prompt_embeds = []
+ negative_pooled_prompt_embeds = []
+ if sub_idx:
+ # AnimateDiff indexes detected
+ if step == 0:
+ print('BrushNet inference: AnimateDiff indexes detected and applied')
+
+ batch = len(sub_idx)
+
+ if do_classifier_free_guidance:
+ for i in sub_idx:
+ image_latents.append(cl_list[0][i][None,:,:,:])
+ masks.append(cl_list[1][i][None,:,:,:])
+ prompt_embeds.append(pe)
+ negative_prompt_embeds.append(npe)
+ pooled_prompt_embeds.append(ppe)
+ negative_pooled_prompt_embeds.append(nppe)
+ for i in sub_idx:
+ image_latents.append(cl_list[0][i][None,:,:,:])
+ masks.append(cl_list[1][i][None,:,:,:])
+ else:
+ for i in sub_idx:
+ image_latents.append(cl_list[0][i][None,:,:,:])
+ masks.append(cl_list[1][i][None,:,:,:])
+ prompt_embeds.append(pe)
+ pooled_prompt_embeds.append(ppe)
+ else:
+ # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
+ continue_batch = True
+ for i in range(latents_incoming):
+ number = latents_got + i
+ if number < batch:
+ # 1st pass, cond
+ image_latents.append(cl_list[0][number][None,:,:,:])
+ masks.append(cl_list[1][number][None,:,:,:])
+ prompt_embeds.append(pe)
+ pooled_prompt_embeds.append(ppe)
+ elif do_classifier_free_guidance and number < batch * 2:
+ # 2nd pass, uncond
+ image_latents.append(cl_list[0][number-batch][None,:,:,:])
+ masks.append(cl_list[1][number-batch][None,:,:,:])
+ negative_prompt_embeds.append(npe)
+ negative_pooled_prompt_embeds.append(nppe)
+ else:
+ # latent batch
+ image_latents.append(cl_list[0][0][None,:,:,:])
+ masks.append(cl_list[1][0][None,:,:,:])
+ prompt_embeds.append(pe)
+ pooled_prompt_embeds.append(ppe)
+ latents_got = -i
+ continue_batch = False
+
+ if continue_batch:
+ # we don't have full batch yet
+ if do_classifier_free_guidance:
+ if number < batch * 2 - 1:
+ bo['latent_id'] = number + 1
+ else:
+ bo['latent_id'] = 0
+ else:
+ if number < batch - 1:
+ bo['latent_id'] = number + 1
+ else:
+ bo['latent_id'] = 0
+ else:
+ bo['latent_id'] = 0
+
+ cl = []
+ for il, m in zip(image_latents, masks):
+ cl.append(torch.concat([il, m], dim=1))
+ cl2apply = torch.concat(cl, dim=0)
+
+ conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
+
+ prompt_embeds.extend(negative_prompt_embeds)
+ prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
+
+ if ppe is not None:
+ added_cond_kwargs = {}
+ added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
+
+ pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
+ pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
+ added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
+ else:
+ added_cond_kwargs = None
+
+ if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
+ if step == 0:
+ print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
+ conditioning_latents = torch.nn.functional.interpolate(
+ conditioning_latents, size=(
+ x.shape[2],
+ x.shape[3],
+ ), mode='bicubic',
+ ).to(torch_dtype).to(brushnet.device)
+
+ if step == 0:
+ print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)
+
+ if debug: print('BrushNet: step =', step)
+
+ if step < control_guidance_start or step > control_guidance_end:
+ cond_scale = 0.0
+ else:
+ cond_scale = brushnet_conditioning_scale
+
+ return brushnet(x,
+ encoder_hidden_states=prompt_embeds,
+ brushnet_cond=conditioning_latents,
+ timestep = timesteps,
+ conditioning_scale=cond_scale,
+ guess_mode=False,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ debug=debug,
+ )
+
+
+# This is main patch function
+def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
+ controls,
+ prompt_embeds, negative_prompt_embeds,
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
+ debug):
+
+ is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
+
+ if is_SDXL:
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
+ [8, comfy.ldm.modules.attention.SpatialTransformer]]
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
+ output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
+ else:
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
+ [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
+ [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
+ output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
+ [6, comfy.ldm.modules.attention.SpatialTransformer],
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
+ [9, comfy.ldm.modules.attention.SpatialTransformer],
+ [10, comfy.ldm.modules.attention.SpatialTransformer],
+ [11, comfy.ldm.modules.attention.SpatialTransformer]]
+
+ def last_layer_index(block, tp):
+ layer_list = []
+ for layer in block:
+ layer_list.append(type(layer))
+ layer_list.reverse()
+ if tp not in layer_list:
+ return -1, layer_list.reverse()
+ return len(layer_list) - 1 - layer_list.index(tp), layer_list
+
+ def brushnet_forward(model, x, timesteps, transformer_options, control):
+ if 'brushnet' not in transformer_options['model_patch']:
+ input_samples = []
+ mid_sample = 0
+ output_samples = []
+ else:
+ # brushnet inference
+ input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
+
+ # give additional samples to blocks
+ for i, tp in input_blocks:
+ idx, layer_list = last_layer_index(model.input_blocks[i], tp)
+ if idx < 0:
+ print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
+ continue
+ model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
+
+ idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
+ if idx < 0:
+ print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
+ model.middle_block[idx].add_sample_after = mid_sample
+
+ for i, tp in output_blocks:
+ idx, layer_list = last_layer_index(model.output_blocks[i], tp)
+ if idx < 0:
+ print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
+ continue
+ model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
+
+ patch_model_function_wrapper(model, brushnet_forward)
+
+ to = add_model_patch_option(model)
+ mp = to['model_patch']
+ if 'brushnet' not in mp:
+ mp['brushnet'] = {}
+ bo = mp['brushnet']
+
+ bo['model'] = brushnet
+ bo['dtype'] = torch_dtype
+ bo['latents'] = conditioning_latents
+ bo['controls'] = controls
+ bo['prompt_embeds'] = prompt_embeds
+ bo['negative_prompt_embeds'] = negative_prompt_embeds
+ bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
+ bo['latent_id'] = 0
+
+ # patch layers `forward` so we can apply brushnet
+ def forward_patched_by_brushnet(self, x, *args, **kwargs):
+ h = self.original_forward(x, *args, **kwargs)
+ if hasattr(self, 'add_sample_after') and type(self):
+ to_add = self.add_sample_after
+ if torch.is_tensor(to_add):
+ # interpolate due to RAUNet
+ if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
+ to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
+ h += to_add.to(h.dtype).to(h.device)
+ else:
+ h += self.add_sample_after
+ self.add_sample_after = 0
+ return h
+
+ for i, block in enumerate(model.model.diffusion_model.input_blocks):
+ for j, layer in enumerate(block):
+ if not hasattr(layer, 'original_forward'):
+ layer.original_forward = layer.forward
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
+ layer.add_sample_after = 0
+
+ for j, layer in enumerate(model.model.diffusion_model.middle_block):
+ if not hasattr(layer, 'original_forward'):
+ layer.original_forward = layer.forward
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
+ layer.add_sample_after = 0
+
+ for i, block in enumerate(model.model.diffusion_model.output_blocks):
+ for j, layer in enumerate(block):
+ if not hasattr(layer, 'original_forward'):
+ layer.original_forward = layer.forward
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
+ layer.add_sample_after = 0
diff --git a/example/BrushNet_SDXL_basic.json b/example/BrushNet_SDXL_basic.json
new file mode 100644
index 0000000000000000000000000000000000000000..e25bae3084a0143298ec57fc4050ef7f287d2bc8
--- /dev/null
+++ b/example/BrushNet_SDXL_basic.json
@@ -0,0 +1 @@
+{"last_node_id": 62, "last_link_id": 128, "nodes": [{"id": 54, "type": "VAEDecode", "pos": [1921, 38], "size": {"0": 210, "1": 46}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 12, "type": "PreviewImage", "pos": [1515, 419], "size": {"0": 617.4000244140625, "1": 673.7999267578125}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 52, "type": "KSampler", "pos": [1564, 101], "size": {"0": 315, "1": 262}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 118}, {"name": "positive", "type": "CONDITIONING", "link": 119}, {"name": "negative", "type": "CONDITIONING", "link": 120}, {"name": "latent_image", "type": "LATENT", "link": 121, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [2, "fixed", 20, 5, "dpmpp_2m_sde_gpu", "karras", 1]}, {"id": 49, "type": "CLIPTextEncode", "pos": [649, 21], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [123], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["a vase"], "color": "#232", "bgcolor": "#353"}, {"id": 50, "type": "CLIPTextEncode", "pos": [651, 168], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [124], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 45, "type": "BrushNetLoader", "pos": [8, 251], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [125], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet_xl/diffusion_pytorch_model.safetensors"]}, {"id": 58, "type": "LoadImage", "pos": [10, 404], "size": {"0": 646.0000610351562, "1": 703.5999755859375}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [126], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image3 (2).png", "image"]}, {"id": 59, "type": "LoadImageMask", "pos": [689, 601], "size": {"0": 315, "1": 318}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [127], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask3 (1).png", "red", "image"]}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [3, 44], "size": {"0": 481, "1": 158}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [122], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 128], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors"]}, {"id": 62, "type": "BrushNet", "pos": [1130, 102], "size": {"0": 315, "1": 226}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 122}, {"name": "vae", "type": "VAE", "link": 128}, {"name": "image", "type": "IMAGE", "link": 126}, {"name": "mask", "type": "MASK", "link": 127}, {"name": "brushnet", "type": "BRMODEL", "link": 125}, {"name": "positive", "type": "CONDITIONING", "link": 123}, {"name": "negative", "type": "CONDITIONING", "link": 124}], "outputs": [{"name": "model", "type": "MODEL", "links": [118], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [119], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [120], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [121], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [118, 62, 0, 52, 0, "MODEL"], [119, 62, 1, 52, 1, "CONDITIONING"], [120, 62, 2, 52, 2, "CONDITIONING"], [121, 62, 3, 52, 3, "LATENT"], [122, 47, 0, 62, 0, "MODEL"], [123, 49, 0, 62, 5, "CONDITIONING"], [124, 50, 0, 62, 6, "CONDITIONING"], [125, 45, 0, 62, 4, "BRMODEL"], [126, 58, 0, 62, 2, "IMAGE"], [127, 59, 0, 62, 3, "MASK"], [128, 47, 2, 62, 1, "VAE"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_SDXL_basic.png b/example/BrushNet_SDXL_basic.png
new file mode 100644
index 0000000000000000000000000000000000000000..249e1a27451bc1982747609d86ef53aae488d4cd
--- /dev/null
+++ b/example/BrushNet_SDXL_basic.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4d8327182c11d7b4f87fdd1fabc9563baa8c22481142e17a4eca8232279f658
+size 2737221
diff --git a/example/BrushNet_SDXL_upscale.json b/example/BrushNet_SDXL_upscale.json
new file mode 100644
index 0000000000000000000000000000000000000000..f20b80c72ade227d729b56308ad0cb07624b90f4
--- /dev/null
+++ b/example/BrushNet_SDXL_upscale.json
@@ -0,0 +1 @@
+{"last_node_id": 76, "last_link_id": 145, "nodes": [{"id": 50, "type": "CLIPTextEncode", "pos": [651, 168], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [111], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 58, "type": "LoadImage", "pos": [10, 404], "size": {"0": 646.0000610351562, "1": 703.5999755859375}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [113], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image3 (2).png", "image"]}, {"id": 69, "type": "CLIPTextEncode", "pos": [1896, -243], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 127}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [128], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 70, "type": "CLIPTextEncode", "pos": [1895, -100], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 129}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [130], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [3, 44], "size": {"0": 481, "1": 158}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [114, 134], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80, 127, 129], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 115, 126], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors"]}, {"id": 72, "type": "UpscaleModelLoader", "pos": [1904, 43], "size": {"0": 315, "1": 58}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [131], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "UpscaleModelLoader"}, "widgets_values": ["4x-UltraSharp.pth"]}, {"id": 59, "type": "LoadImageMask", "pos": [689, 601], "size": {"0": 315, "1": 318}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [139], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask3 (1).png", "red", "image"]}, {"id": 12, "type": "PreviewImage", "pos": [1516, 419], "size": {"0": 617.4000244140625, "1": 673.7999267578125}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 73, "type": "PreviewImage", "pos": [2667, 419], "size": {"0": 639.9539794921875, "1": 667.046142578125}, "flags": {}, "order": 15, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 132}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 45, "type": "BrushNetLoader", "pos": [8, 251], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [116], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet_xl/random_mask.safetensors", "float16"]}, {"id": 68, "type": "UltimateSDUpscale", "pos": [2304, -81], "size": {"0": 315, "1": 614}, "flags": {}, "order": 14, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 145}, {"name": "model", "type": "MODEL", "link": 134}, {"name": "positive", "type": "CONDITIONING", "link": 128}, {"name": "negative", "type": "CONDITIONING", "link": 130}, {"name": "vae", "type": "VAE", "link": 126}, {"name": "upscale_model", "type": "UPSCALE_MODEL", "link": 131}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [132], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "UltimateSDUpscale"}, "widgets_values": [2, 305700295020080, "randomize", 20, 8, "euler", "normal", 0.2, "Linear", 512, 512, 8, 32, "None", 1, 64, 8, 16, true, false]}, {"id": 61, "type": "BrushNet", "pos": [1111, 105], "size": {"0": 315, "1": 246}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 114}, {"name": "vae", "type": "VAE", "link": 115}, {"name": "image", "type": "IMAGE", "link": 113}, {"name": "mask", "type": "MASK", "link": 140}, {"name": "brushnet", "type": "BRMODEL", "link": 116}, {"name": "positive", "type": "CONDITIONING", "link": 110}, {"name": "negative", "type": "CONDITIONING", "link": 111}, {"name": "clip", "type": "PPCLIP", "link": null}], "outputs": [{"name": "model", "type": "MODEL", "links": [106], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [107], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [108], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [143], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [0.8, 0, 10000]}, {"id": 75, "type": "GrowMask", "pos": [1023, 478], "size": {"0": 315, "1": 82}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "mask", "type": "MASK", "link": 139}], "outputs": [{"name": "MASK", "type": "MASK", "links": [140], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GrowMask"}, "widgets_values": [4, false]}, {"id": 49, "type": "CLIPTextEncode", "pos": [649, 21], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [110], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["a vase with flowers"], "color": "#232", "bgcolor": "#353"}, {"id": 52, "type": "KSampler", "pos": [1564, 101], "size": {"0": 315, "1": 262}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 106}, {"name": "positive", "type": "CONDITIONING", "link": 107}, {"name": "negative", "type": "CONDITIONING", "link": 108}, {"name": "latent_image", "type": "LATENT", "link": 143, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [6, "fixed", 15, 8, "euler_ancestral", "karras", 1]}, {"id": 54, "type": "VAEDecode", "pos": [1958, 155], "size": {"0": 210, "1": 46}, "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93, 145], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [106, 61, 0, 52, 0, "MODEL"], [107, 61, 1, 52, 1, "CONDITIONING"], [108, 61, 2, 52, 2, "CONDITIONING"], [110, 49, 0, 61, 5, "CONDITIONING"], [111, 50, 0, 61, 6, "CONDITIONING"], [113, 58, 0, 61, 2, "IMAGE"], [114, 47, 0, 61, 0, "MODEL"], [115, 47, 2, 61, 1, "VAE"], [116, 45, 0, 61, 4, "BRMODEL"], [126, 47, 2, 68, 4, "VAE"], [127, 47, 1, 69, 0, "CLIP"], [128, 69, 0, 68, 2, "CONDITIONING"], [129, 47, 1, 70, 0, "CLIP"], [130, 70, 0, 68, 3, "CONDITIONING"], [131, 72, 0, 68, 5, "UPSCALE_MODEL"], [132, 68, 0, 73, 0, "IMAGE"], [134, 47, 0, 68, 1, "MODEL"], [139, 59, 0, 75, 0, "MASK"], [140, 75, 0, 61, 3, "MASK"], [143, 61, 3, 52, 3, "LATENT"], [145, 54, 0, 68, 0, "IMAGE"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_SDXL_upscale.png b/example/BrushNet_SDXL_upscale.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce3dad290642fa03d509f61e13fa828b0a600da8
--- /dev/null
+++ b/example/BrushNet_SDXL_upscale.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d645692a5e2e07a57ce2c8e0f3998217a5295773e2ca8fb68be411681be2170
+size 2776172
diff --git a/example/BrushNet_basic.json b/example/BrushNet_basic.json
new file mode 100644
index 0000000000000000000000000000000000000000..c8f2c0d7ccad8d930b315aa08d70181a41813dfe
--- /dev/null
+++ b/example/BrushNet_basic.json
@@ -0,0 +1 @@
+{"last_node_id": 64, "last_link_id": 136, "nodes": [{"id": 12, "type": "PreviewImage", "pos": [1549, 441], "size": {"0": 580.6002197265625, "1": 613}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [3, 44], "size": {"0": 481, "1": 158}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [125], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 126], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 49, "type": "CLIPTextEncode", "pos": [649, 21], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [127], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["a burger"], "color": "#232", "bgcolor": "#353"}, {"id": 50, "type": "CLIPTextEncode", "pos": [651, 168], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [128], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 45, "type": "BrushNetLoader", "pos": [8, 251], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [129], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask_brushnet_ckpt/diffusion_pytorch_model.safetensors"]}, {"id": 52, "type": "KSampler", "pos": [1571, 117], "size": {"0": 315, "1": 262}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 121}, {"name": "positive", "type": "CONDITIONING", "link": 122}, {"name": "negative", "type": "CONDITIONING", "link": 123}, {"name": "latent_image", "type": "LATENT", "link": 124, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 50, 7.5, "euler", "normal", 1]}, {"id": 54, "type": "VAEDecode", "pos": [1921, 38], "size": {"0": 210, "1": 46}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 58, "type": "LoadImage", "pos": [10, 404], "size": {"0": 646.0000610351562, "1": 703.5999755859375}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [130], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image (1).jpg", "image"]}, {"id": 59, "type": "LoadImageMask", "pos": [689, 601], "size": {"0": 315, "1": 318}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [131], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask (3).jpg", "red", "image"]}, {"id": 62, "type": "BrushNet", "pos": [1102, 136], "size": {"0": 315, "1": 226}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 125}, {"name": "vae", "type": "VAE", "link": 126}, {"name": "image", "type": "IMAGE", "link": 130}, {"name": "mask", "type": "MASK", "link": 131}, {"name": "brushnet", "type": "BRMODEL", "link": 129}, {"name": "positive", "type": "CONDITIONING", "link": 127}, {"name": "negative", "type": "CONDITIONING", "link": 128}], "outputs": [{"name": "model", "type": "MODEL", "links": [121], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [122], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [123], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [124], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [121, 62, 0, 52, 0, "MODEL"], [122, 62, 1, 52, 1, "CONDITIONING"], [123, 62, 2, 52, 2, "CONDITIONING"], [124, 62, 3, 52, 3, "LATENT"], [125, 47, 0, 62, 0, "MODEL"], [126, 47, 2, 62, 1, "VAE"], [127, 49, 0, 62, 5, "CONDITIONING"], [128, 50, 0, 62, 6, "CONDITIONING"], [129, 45, 0, 62, 4, "BRMODEL"], [130, 58, 0, 62, 2, "IMAGE"], [131, 59, 0, 62, 3, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_basic.png b/example/BrushNet_basic.png
new file mode 100644
index 0000000000000000000000000000000000000000..53dc87f8a80a3e108a2235d6716da9737af6699e
--- /dev/null
+++ b/example/BrushNet_basic.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c604f5564846eae5683e7876aa9821fb9e69f012df80019f97f0a6aafbfe43a6
+size 1970339
diff --git a/example/BrushNet_cut_for_inpaint.json b/example/BrushNet_cut_for_inpaint.json
new file mode 100644
index 0000000000000000000000000000000000000000..3f67f5fae9e3fe2d57b17c7493a53944086090fb
--- /dev/null
+++ b/example/BrushNet_cut_for_inpaint.json
@@ -0,0 +1 @@
+{"last_node_id": 74, "last_link_id": 147, "nodes": [{"id": 45, "type": "BrushNetLoader", "pos": [8, 251], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [129], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}, {"id": 12, "type": "PreviewImage", "pos": [1963, 148], "size": [362.71480126953156, 313.34364410400406], "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 54, "type": "VAEDecode", "pos": [1964, 23], "size": {"0": 210, "1": 46}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93, 142], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 71, "type": "CutForInpaint", "pos": [756, 333], "size": {"0": 315, "1": 122}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 138}, {"name": "mask", "type": "MASK", "link": 139}], "outputs": [{"name": "image", "type": "IMAGE", "links": [140], "shape": 3, "slot_index": 0}, {"name": "mask", "type": "MASK", "links": [141, 144], "shape": 3, "slot_index": 1}, {"name": "origin", "type": "VECTOR", "links": [145], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CutForInpaint"}, "widgets_values": [512, 512]}, {"id": 58, "type": "LoadImage", "pos": [10, 404], "size": [695.4412421875002, 781.0468775024417], "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [138, 143, 147], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [139], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["clipspace/clipspace-mask-2517487.png [input]", "image"]}, {"id": 52, "type": "KSampler", "pos": [1617, 131], "size": {"0": 315, "1": 262}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 121}, {"name": "positive", "type": "CONDITIONING", "link": 122}, {"name": "negative", "type": "CONDITIONING", "link": 123}, {"name": "latent_image", "type": "LATENT", "link": 124, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 8, "euler_ancestral", "normal", 1]}, {"id": 74, "type": "Reroute", "pos": [736.4412231445312, 336.93898856946777], "size": [75, 26], "flags": {}, "order": 3, "mode": 0, "inputs": [{"name": "", "type": "*", "link": 147}], "outputs": [{"name": "", "type": "IMAGE", "links": null}], "properties": {"showOutputText": false, "horizontal": false}}, {"id": 62, "type": "BrushNet", "pos": [1254, 134], "size": {"0": 315, "1": 226}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 125}, {"name": "vae", "type": "VAE", "link": 126}, {"name": "image", "type": "IMAGE", "link": 140}, {"name": "mask", "type": "MASK", "link": 141}, {"name": "brushnet", "type": "BRMODEL", "link": 129}, {"name": "positive", "type": "CONDITIONING", "link": 127}, {"name": "negative", "type": "CONDITIONING", "link": 128}], "outputs": [{"name": "model", "type": "MODEL", "links": [121], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [122], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [123], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [124], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [0.8, 0, 10000]}, {"id": 50, "type": "CLIPTextEncode", "pos": [651, 168], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [128], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [3, 44], "size": {"0": 481, "1": 158}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [125], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 126], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["SD15/toonyou_beta6.safetensors"]}, {"id": 49, "type": "CLIPTextEncode", "pos": [649, 21], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [127], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["a clear blue sky"], "color": "#232", "bgcolor": "#353"}, {"id": 72, "type": "BlendInpaint", "pos": [1385, 616], "size": {"0": 315, "1": 142}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "inpaint", "type": "IMAGE", "link": 142}, {"name": "original", "type": "IMAGE", "link": 143}, {"name": "mask", "type": "MASK", "link": 144}, {"name": "origin", "type": "VECTOR", "link": 145}], "outputs": [{"name": "image", "type": "IMAGE", "links": [146], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "BlendInpaint"}, "widgets_values": [10, 10]}, {"id": 73, "type": "PreviewImage", "pos": [1784, 511], "size": [578.8481262207033, 616.0670013427734], "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 146}], "properties": {"Node name for S&R": "PreviewImage"}}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [121, 62, 0, 52, 0, "MODEL"], [122, 62, 1, 52, 1, "CONDITIONING"], [123, 62, 2, 52, 2, "CONDITIONING"], [124, 62, 3, 52, 3, "LATENT"], [125, 47, 0, 62, 0, "MODEL"], [126, 47, 2, 62, 1, "VAE"], [127, 49, 0, 62, 5, "CONDITIONING"], [128, 50, 0, 62, 6, "CONDITIONING"], [129, 45, 0, 62, 4, "BRMODEL"], [138, 58, 0, 71, 0, "IMAGE"], [139, 58, 1, 71, 1, "MASK"], [140, 71, 0, 62, 2, "IMAGE"], [141, 71, 1, 62, 3, "MASK"], [142, 54, 0, 72, 0, "IMAGE"], [143, 58, 0, 72, 1, "IMAGE"], [144, 71, 1, 72, 2, "MASK"], [145, 71, 2, 72, 3, "VECTOR"], [146, 72, 0, 73, 0, "IMAGE"], [147, 58, 0, 74, 0, "*"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_cut_for_inpaint.png b/example/BrushNet_cut_for_inpaint.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b0d68bd65a811dc11c93148db2f04485ca2fc42
--- /dev/null
+++ b/example/BrushNet_cut_for_inpaint.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89010f64c2d9b1a339c17708b56b1e6bc6dd3866f4c9b01757aa071d54f772b8
+size 2369315
diff --git a/example/BrushNet_image_batch.json b/example/BrushNet_image_batch.json
new file mode 100644
index 0000000000000000000000000000000000000000..61d9ab47c958c829e1cd705296f07634d8a23890
--- /dev/null
+++ b/example/BrushNet_image_batch.json
@@ -0,0 +1 @@
+{"last_node_id": 18, "last_link_id": 24, "nodes": [{"id": 6, "type": "SAMModelLoader (segment anything)", "pos": [329, 68], "size": [347.87583007145474, 58], "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "SAM_MODEL", "type": "SAM_MODEL", "links": [2], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "SAMModelLoader (segment anything)"}, "widgets_values": ["sam_vit_h (2.56GB)"]}, {"id": 4, "type": "GroundingDinoModelLoader (segment anything)", "pos": [324, 175], "size": [361.20001220703125, 58], "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "GROUNDING_DINO_MODEL", "type": "GROUNDING_DINO_MODEL", "links": [3], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GroundingDinoModelLoader (segment anything)"}, "widgets_values": ["GroundingDINO_SwinT_OGC (694MB)"]}, {"id": 10, "type": "BrushNetLoader", "pos": [338, 744], "size": {"0": 315, "1": 82}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [7], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}, {"id": 12, "type": "CLIPTextEncode", "pos": [805, 922], "size": [393.06744384765625, 101.02725219726562], "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 9}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [11], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["burger"]}, {"id": 11, "type": "CLIPTextEncode", "pos": [810, 786], "size": [388.26751708984375, 88.82723999023438], "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 8}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [10], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""]}, {"id": 9, "type": "BrushNet", "pos": [1279, 577], "size": {"0": 315, "1": 226}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 6}, {"name": "vae", "type": "VAE", "link": 12}, {"name": "image", "type": "IMAGE", "link": 13}, {"name": "mask", "type": "MASK", "link": 14}, {"name": "brushnet", "type": "BRMODEL", "link": 7}, {"name": "positive", "type": "CONDITIONING", "link": 10}, {"name": "negative", "type": "CONDITIONING", "link": 11}], "outputs": [{"name": "model", "type": "MODEL", "links": [15], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [16], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [17], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [18], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}, {"id": 8, "type": "CheckpointLoaderSimple", "pos": [333, 574], "size": [404.79998779296875, 98], "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [6], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [8, 9], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [12, 21], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 14, "type": "VAEDecode", "pos": [2049, 464], "size": {"0": 210, "1": 46}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 19}, {"name": "vae", "type": "VAE", "link": 21}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [20], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 15, "type": "PreviewImage", "pos": [2273, 575], "size": [394.86773681640625, 360.6271057128906], "flags": {}, "order": 14, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 20}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 13, "type": "KSampler", "pos": [1709, 576], "size": {"0": 315, "1": 262}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 15}, {"name": "positive", "type": "CONDITIONING", "link": 16}, {"name": "negative", "type": "CONDITIONING", "link": 17}, {"name": "latent_image", "type": "LATENT", "link": 18}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [19], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 8, "euler", "normal", 1]}, {"id": 2, "type": "VHS_LoadImagesPath", "pos": [334, 306], "size": [226.8000030517578, 194], "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [4, 13, 22], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}, {"name": "INT", "type": "INT", "links": null, "shape": 3}], "properties": {"Node name for S&R": "VHS_LoadImagesPath"}, "widgets_values": {"directory": "./output/", "image_load_cap": 6, "skip_first_images": 0, "select_every_nth": 1, "choose folder to upload": "image", "videopreview": {"hidden": false, "paused": false, "params": {"frame_load_cap": 6, "skip_first_images": 0, "filename": "./output/", "type": "path", "format": "folder", "select_every_nth": 1}}}}, {"id": 5, "type": "GroundingDinoSAMSegment (segment anything)", "pos": [997, 71], "size": [352.79998779296875, 122], "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "sam_model", "type": "SAM_MODEL", "link": 2}, {"name": "grounding_dino_model", "type": "GROUNDING_DINO_MODEL", "link": 3}, {"name": "image", "type": "IMAGE", "link": 4}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [14, 23], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "GroundingDinoSAMSegment (segment anything)"}, "widgets_values": ["burger", 0.3]}, {"id": 17, "type": "MaskToImage", "pos": [1424, 90], "size": {"0": 210, "1": 26}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "mask", "type": "MASK", "link": 23}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [24], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "MaskToImage"}}, {"id": 16, "type": "PreviewImage", "pos": [782, 255], "size": [379.2693328857422, 273.5831527709961], "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 22}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 18, "type": "PreviewImage", "pos": [1734, 90], "size": [353.90962829589853, 245.8631393432617], "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 24}], "properties": {"Node name for S&R": "PreviewImage"}}], "links": [[2, 6, 0, 5, 0, "SAM_MODEL"], [3, 4, 0, 5, 1, "GROUNDING_DINO_MODEL"], [4, 2, 0, 5, 2, "IMAGE"], [6, 8, 0, 9, 0, "MODEL"], [7, 10, 0, 9, 4, "BRMODEL"], [8, 8, 1, 11, 0, "CLIP"], [9, 8, 1, 12, 0, "CLIP"], [10, 11, 0, 9, 5, "CONDITIONING"], [11, 12, 0, 9, 6, "CONDITIONING"], [12, 8, 2, 9, 1, "VAE"], [13, 2, 0, 9, 2, "IMAGE"], [14, 5, 1, 9, 3, "MASK"], [15, 9, 0, 13, 0, "MODEL"], [16, 9, 1, 13, 1, "CONDITIONING"], [17, 9, 2, 13, 2, "CONDITIONING"], [18, 9, 3, 13, 3, "LATENT"], [19, 13, 0, 14, 0, "LATENT"], [20, 14, 0, 15, 0, "IMAGE"], [21, 8, 2, 14, 1, "VAE"], [22, 2, 0, 16, 0, "IMAGE"], [23, 5, 1, 17, 0, "MASK"], [24, 17, 0, 18, 0, "IMAGE"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_image_batch.png b/example/BrushNet_image_batch.png
new file mode 100644
index 0000000000000000000000000000000000000000..7240b43de5ee9487fa06f412480401a5f10afd76
--- /dev/null
+++ b/example/BrushNet_image_batch.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:704fe979477f356f52c8dc5922ebdef0eb196668056f3345967ad0df0d45a3a4
+size 1155886
diff --git a/example/BrushNet_image_big_batch.json b/example/BrushNet_image_big_batch.json
new file mode 100644
index 0000000000000000000000000000000000000000..7f1404e6533d6b4218a21ebd5a9dee3dd102a7d3
--- /dev/null
+++ b/example/BrushNet_image_big_batch.json
@@ -0,0 +1 @@
+{"last_node_id": 21, "last_link_id": 29, "nodes": [{"id": 6, "type": "SAMModelLoader (segment anything)", "pos": [329, 68], "size": {"0": 347.8758239746094, "1": 58}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "SAM_MODEL", "type": "SAM_MODEL", "links": [2], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "SAMModelLoader (segment anything)"}, "widgets_values": ["sam_vit_h (2.56GB)"]}, {"id": 4, "type": "GroundingDinoModelLoader (segment anything)", "pos": [324, 175], "size": {"0": 361.20001220703125, "1": 58}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "GROUNDING_DINO_MODEL", "type": "GROUNDING_DINO_MODEL", "links": [3], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GroundingDinoModelLoader (segment anything)"}, "widgets_values": ["GroundingDINO_SwinT_OGC (694MB)"]}, {"id": 14, "type": "VAEDecode", "pos": [2049, 464], "size": {"0": 210, "1": 46}, "flags": {}, "order": 14, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 19}, {"name": "vae", "type": "VAE", "link": 21}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [20], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 15, "type": "PreviewImage", "pos": [2273, 575], "size": {"0": 394.86773681640625, "1": 360.6271057128906}, "flags": {}, "order": 15, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 20}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 13, "type": "KSampler", "pos": [1709, 576], "size": {"0": 315, "1": 262}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 15}, {"name": "positive", "type": "CONDITIONING", "link": 16}, {"name": "negative", "type": "CONDITIONING", "link": 17}, {"name": "latent_image", "type": "LATENT", "link": 18}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [19], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 8, "euler", "normal", 1]}, {"id": 17, "type": "MaskToImage", "pos": [1424, 90], "size": {"0": 210, "1": 26}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "mask", "type": "MASK", "link": 23}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [24], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "MaskToImage"}}, {"id": 18, "type": "PreviewImage", "pos": [1734, 90], "size": [353.9096374511719, 246], "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 24}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 20, "type": "ADE_UseEvolvedSampling", "pos": [829, 579], "size": {"0": 315, "1": 118}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 26, "slot_index": 0}, {"name": "m_models", "type": "M_MODELS", "link": null}, {"name": "context_options", "type": "CONTEXT_OPTIONS", "link": 27}, {"name": "sample_settings", "type": "SAMPLE_SETTINGS", "link": null}], "outputs": [{"name": "MODEL", "type": "MODEL", "links": [25], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ADE_UseEvolvedSampling"}, "widgets_values": ["autoselect"]}, {"id": 21, "type": "VHS_LoadVideoPath", "pos": [337, 275], "size": [315, 238], "flags": {}, "order": 2, "mode": 0, "inputs": [{"name": "meta_batch", "type": "VHS_BatchManager", "link": null}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [28, 29], "shape": 3, "slot_index": 0}, {"name": "frame_count", "type": "INT", "links": null, "shape": 3}, {"name": "audio", "type": "VHS_AUDIO", "links": null, "shape": 3}, {"name": "video_info", "type": "VHS_VIDEOINFO", "links": null, "shape": 3}], "properties": {"Node name for S&R": "VHS_LoadVideoPath"}, "widgets_values": {"video": "./input/AnimateDiff.mp4", "force_rate": 0, "force_size": "Disabled", "custom_width": 512, "custom_height": 512, "frame_load_cap": 0, "skip_first_frames": 0, "select_every_nth": 1, "videopreview": {"hidden": false, "paused": false, "params": {"frame_load_cap": 0, "skip_first_frames": 0, "force_rate": 0, "filename": "./input/AnimateDiff.mp4", "type": "path", "format": "video/mp4", "select_every_nth": 1}}}}, {"id": 5, "type": "GroundingDinoSAMSegment (segment anything)", "pos": [997, 71], "size": {"0": 352.79998779296875, "1": 122}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "sam_model", "type": "SAM_MODEL", "link": 2}, {"name": "grounding_dino_model", "type": "GROUNDING_DINO_MODEL", "link": 3}, {"name": "image", "type": "IMAGE", "link": 28}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [14, 23], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "GroundingDinoSAMSegment (segment anything)"}, "widgets_values": ["tree", 0.3]}, {"id": 19, "type": "ADE_StandardStaticContextOptions", "pos": [335, 557], "size": {"0": 317.4000244140625, "1": 198}, "flags": {}, "order": 3, "mode": 0, "inputs": [{"name": "prev_context", "type": "CONTEXT_OPTIONS", "link": null}, {"name": "view_opts", "type": "VIEW_OPTS", "link": null}], "outputs": [{"name": "CONTEXT_OPTS", "type": "CONTEXT_OPTIONS", "links": [27], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ADE_StandardStaticContextOptions"}, "widgets_values": [8, 4, "pyramid", false, 0, 1]}, {"id": 11, "type": "CLIPTextEncode", "pos": [838, 808], "size": {"0": 388.26751708984375, "1": 88.82723999023438}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 8}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [10], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["mountains"], "color": "#232", "bgcolor": "#353"}, {"id": 12, "type": "CLIPTextEncode", "pos": [833, 966], "size": {"0": 393.06744384765625, "1": 101.02725219726562}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 9}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [11], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 10, "type": "BrushNetLoader", "pos": [332, 1013], "size": {"0": 315, "1": 82}, "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [7], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}, {"id": 8, "type": "CheckpointLoaderSimple", "pos": [319, 822], "size": {"0": 404.79998779296875, "1": 98}, "flags": {}, "order": 5, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [26], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [8, 9], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [12, 21], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["SD15/realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 9, "type": "BrushNet", "pos": [1279, 577], "size": {"0": 315, "1": 226}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 25}, {"name": "vae", "type": "VAE", "link": 12}, {"name": "image", "type": "IMAGE", "link": 29, "slot_index": 2}, {"name": "mask", "type": "MASK", "link": 14}, {"name": "brushnet", "type": "BRMODEL", "link": 7}, {"name": "positive", "type": "CONDITIONING", "link": 10}, {"name": "negative", "type": "CONDITIONING", "link": 11}], "outputs": [{"name": "model", "type": "MODEL", "links": [15], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [16], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [17], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [18], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}], "links": [[2, 6, 0, 5, 0, "SAM_MODEL"], [3, 4, 0, 5, 1, "GROUNDING_DINO_MODEL"], [7, 10, 0, 9, 4, "BRMODEL"], [8, 8, 1, 11, 0, "CLIP"], [9, 8, 1, 12, 0, "CLIP"], [10, 11, 0, 9, 5, "CONDITIONING"], [11, 12, 0, 9, 6, "CONDITIONING"], [12, 8, 2, 9, 1, "VAE"], [14, 5, 1, 9, 3, "MASK"], [15, 9, 0, 13, 0, "MODEL"], [16, 9, 1, 13, 1, "CONDITIONING"], [17, 9, 2, 13, 2, "CONDITIONING"], [18, 9, 3, 13, 3, "LATENT"], [19, 13, 0, 14, 0, "LATENT"], [20, 14, 0, 15, 0, "IMAGE"], [21, 8, 2, 14, 1, "VAE"], [23, 5, 1, 17, 0, "MASK"], [24, 17, 0, 18, 0, "IMAGE"], [25, 20, 0, 9, 0, "MODEL"], [26, 8, 0, 20, 0, "MODEL"], [27, 19, 0, 20, 2, "CONTEXT_OPTIONS"], [28, 21, 0, 5, 2, "IMAGE"], [29, 21, 0, 9, 2, "IMAGE"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_image_big_batch.png b/example/BrushNet_image_big_batch.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0b22f03bd64d0ab0cbd005ff19e30c088cabb27
Binary files /dev/null and b/example/BrushNet_image_big_batch.png differ
diff --git a/example/BrushNet_inpaint.json b/example/BrushNet_inpaint.json
new file mode 100644
index 0000000000000000000000000000000000000000..849c726dfb9c08f468f18c970b8422c836324ea0
--- /dev/null
+++ b/example/BrushNet_inpaint.json
@@ -0,0 +1 @@
+{"last_node_id": 61, "last_link_id": 117, "nodes": [{"id": 12, "type": "PreviewImage", "pos": [2049, 50], "size": {"0": 523.5944213867188, "1": 547.4853515625}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 92}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 56, "type": "VAEDecode", "pos": [1805, 54], "size": {"0": 210, "1": 46}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 97}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [92, 95], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 57, "type": "BlendInpaint", "pos": [1532, 734], "size": {"0": 315, "1": 122}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "inpaint", "type": "IMAGE", "link": 95}, {"name": "original", "type": "IMAGE", "link": 94}, {"name": "mask", "type": "MASK", "link": 117}], "outputs": [{"name": "image", "type": "IMAGE", "links": [96], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "BlendInpaint"}, "widgets_values": [10, 10]}, {"id": 58, "type": "PreviewImage", "pos": [2052, 646], "size": {"0": 509.60009765625, "1": 539.2001953125}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 96}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 49, "type": "CLIPTextEncode", "pos": [698, 274], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [104], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["closeup photo of white goat head"], "color": "#232", "bgcolor": "#353"}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [109, 40], "size": {"0": 481, "1": 158}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [103], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [97, 109], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 1, "type": "LoadImage", "pos": [101, 386], "size": {"0": 470.19439697265625, "1": 578.6854248046875}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [94, 112], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image2 (1).png", "image"]}, {"id": 53, "type": "LoadImageMask", "pos": [612, 638], "size": {"0": 315, "1": 318}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [116, 117], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask2 (1).png", "red", "image"]}, {"id": 45, "type": "BrushNetLoader", "pos": [49, 238], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [110], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}, {"id": 59, "type": "BrushNet", "pos": [1088, 46], "size": {"0": 315, "1": 246}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 103, "slot_index": 0}, {"name": "vae", "type": "VAE", "link": 109}, {"name": "image", "type": "IMAGE", "link": 112}, {"name": "mask", "type": "MASK", "link": 116}, {"name": "brushnet", "type": "BRMODEL", "link": 110}, {"name": "positive", "type": "CONDITIONING", "link": 104}, {"name": "negative", "type": "CONDITIONING", "link": 105}, {"name": "clip", "type": "PPCLIP", "link": null}], "outputs": [{"name": "model", "type": "MODEL", "links": [102], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [108], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [107], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [106], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [0.8, 0, 10000]}, {"id": 50, "type": "CLIPTextEncode", "pos": [700, 444], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [105], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["text, grass, deformed, pink, blue, horns"], "color": "#322", "bgcolor": "#533"}, {"id": 54, "type": "KSampler", "pos": [1449, 44], "size": {"0": 315, "1": 262}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 102}, {"name": "positive", "type": "CONDITIONING", "link": 108}, {"name": "negative", "type": "CONDITIONING", "link": 107}, {"name": "latent_image", "type": "LATENT", "link": 106, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 5, "euler_ancestral", "normal", 1]}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 54, 0, 56, 0, "LATENT"], [92, 56, 0, 12, 0, "IMAGE"], [94, 1, 0, 57, 1, "IMAGE"], [95, 56, 0, 57, 0, "IMAGE"], [96, 57, 0, 58, 0, "IMAGE"], [97, 47, 2, 56, 1, "VAE"], [102, 59, 0, 54, 0, "MODEL"], [103, 47, 0, 59, 0, "MODEL"], [104, 49, 0, 59, 5, "CONDITIONING"], [105, 50, 0, 59, 6, "CONDITIONING"], [106, 59, 3, 54, 3, "LATENT"], [107, 59, 2, 54, 2, "CONDITIONING"], [108, 59, 1, 54, 1, "CONDITIONING"], [109, 47, 2, 59, 1, "VAE"], [110, 45, 0, 59, 4, "BRMODEL"], [112, 1, 0, 59, 2, "IMAGE"], [116, 53, 0, 59, 3, "MASK"], [117, 53, 0, 57, 2, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_inpaint.png b/example/BrushNet_inpaint.png
new file mode 100644
index 0000000000000000000000000000000000000000..c0f477532df106cda4e7d7c1767cde41b9f4e668
--- /dev/null
+++ b/example/BrushNet_inpaint.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae5198be444315129712a7938414215aa105895cbc711493423a9250a1368496
+size 1568064
diff --git a/example/BrushNet_with_CN.json b/example/BrushNet_with_CN.json
new file mode 100644
index 0000000000000000000000000000000000000000..c25242c1f291434a47d2bd99f485821a5ce11b2b
--- /dev/null
+++ b/example/BrushNet_with_CN.json
@@ -0,0 +1 @@
+{"last_node_id": 60, "last_link_id": 115, "nodes": [{"id": 54, "type": "VAEDecode", "pos": [1868, 82], "size": {"0": 210, "1": 46}, "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 12, "type": "PreviewImage", "pos": [1624, 422], "size": {"0": 523.5944213867188, "1": 547.4853515625}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 56, "type": "ControlNetLoader", "pos": [-87, -117], "size": {"0": 437.9234313964844, "1": 79.99897766113281}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "CONTROL_NET", "type": "CONTROL_NET", "links": [96], "shape": 3}], "properties": {"Node name for S&R": "ControlNetLoader"}, "widgets_values": ["control-scribble.safetensors"]}, {"id": 57, "type": "LoadImage", "pos": [415, -339], "size": {"0": 315, "1": 314}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [97], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_cn.png", "image"]}, {"id": 49, "type": "CLIPTextEncode", "pos": [411, 23], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [94], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["red car model on a wooden table"], "color": "#232", "bgcolor": "#353"}, {"id": 52, "type": "KSampler", "pos": [1497, 69], "size": {"0": 315, "1": 262}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 103}, {"name": "positive", "type": "CONDITIONING", "link": 104}, {"name": "negative", "type": "CONDITIONING", "link": 105}, {"name": "latent_image", "type": "LATENT", "link": 106, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 6, "euler_ancestral", "exponential", 1]}, {"id": 55, "type": "ControlNetApply", "pos": [795, -65], "size": {"0": 317.4000244140625, "1": 98}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "conditioning", "type": "CONDITIONING", "link": 94}, {"name": "control_net", "type": "CONTROL_NET", "link": 96, "slot_index": 1}, {"name": "image", "type": "IMAGE", "link": 97}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [107], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ControlNetApply"}, "widgets_values": [0.8]}, {"id": 50, "type": "CLIPTextEncode", "pos": [704, 415], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [108], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["stand, furniture, cover"], "color": "#322", "bgcolor": "#533"}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [-109, 15], "size": {"0": 481, "1": 158}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [110], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 109], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 1, "type": "LoadImage", "pos": [101, 386], "size": {"0": 470.19439697265625, "1": 578.6854248046875}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [112], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image.jpg", "image"]}, {"id": 58, "type": "LoadImageMask", "pos": [640, 604], "size": {"0": 315, "1": 318.0000305175781}, "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [114], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask (6).jpg", "red", "image"]}, {"id": 60, "type": "GrowMask", "pos": [997, 602], "size": {"0": 315, "1": 82}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "mask", "type": "MASK", "link": 114}], "outputs": [{"name": "MASK", "type": "MASK", "links": [115], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GrowMask"}, "widgets_values": [10, true]}, {"id": 59, "type": "BrushNet", "pos": [1140, 63], "size": {"0": 315, "1": 246}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 110}, {"name": "vae", "type": "VAE", "link": 109}, {"name": "image", "type": "IMAGE", "link": 112}, {"name": "mask", "type": "MASK", "link": 115}, {"name": "brushnet", "type": "BRMODEL", "link": 111}, {"name": "positive", "type": "CONDITIONING", "link": 107}, {"name": "negative", "type": "CONDITIONING", "link": 108}, {"name": "clip", "type": "PPCLIP", "link": null}], "outputs": [{"name": "model", "type": "MODEL", "links": [103], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [104], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [105], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [106], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}, {"id": 45, "type": "BrushNetLoader", "pos": [49, 238], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 5, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [111], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [94, 49, 0, 55, 0, "CONDITIONING"], [96, 56, 0, 55, 1, "CONTROL_NET"], [97, 57, 0, 55, 2, "IMAGE"], [103, 59, 0, 52, 0, "MODEL"], [104, 59, 1, 52, 1, "CONDITIONING"], [105, 59, 2, 52, 2, "CONDITIONING"], [106, 59, 3, 52, 3, "LATENT"], [107, 55, 0, 59, 5, "CONDITIONING"], [108, 50, 0, 59, 6, "CONDITIONING"], [109, 47, 2, 59, 1, "VAE"], [110, 47, 0, 59, 0, "MODEL"], [111, 45, 0, 59, 4, "BRMODEL"], [112, 1, 0, 59, 2, "IMAGE"], [114, 58, 0, 60, 0, "MASK"], [115, 60, 0, 59, 3, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_with_CN.png b/example/BrushNet_with_CN.png
new file mode 100644
index 0000000000000000000000000000000000000000..154444681a4b57e19cb3620e16d2af46e17835c9
--- /dev/null
+++ b/example/BrushNet_with_CN.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba0cff542aa34a8a4bfc9076d973a415b7ef908aee3bd7c2c9acba4f28120500
+size 1960577
diff --git a/example/BrushNet_with_ELLA.json b/example/BrushNet_with_ELLA.json
new file mode 100644
index 0000000000000000000000000000000000000000..78065ebbf46f71dd1a5e6d2b2f0bd064f96a4b2d
--- /dev/null
+++ b/example/BrushNet_with_ELLA.json
@@ -0,0 +1 @@
+{"last_node_id": 30, "last_link_id": 53, "nodes": [{"id": 8, "type": "SetEllaTimesteps", "pos": [511, 344], "size": {"0": 315, "1": 146}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 4}, {"name": "ella", "type": "ELLA", "link": 3}, {"name": "sigmas", "type": "SIGMAS", "link": null}], "outputs": [{"name": "ELLA", "type": "ELLA", "links": [14, 15], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "SetEllaTimesteps"}, "widgets_values": ["normal", 20, 1]}, {"id": 2, "type": "ELLALoader", "pos": [89, 389], "size": {"0": 315, "1": 58}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "ELLA", "type": "ELLA", "links": [3], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ELLALoader"}, "widgets_values": ["ella-sd1.5-tsc-t5xl.safetensors"]}, {"id": 1, "type": "CheckpointLoaderSimple", "pos": [8, 216], "size": {"0": 396.80010986328125, "1": 98}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [4, 21], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [8, 10], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [17, 23], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 11, "type": "EllaTextEncode", "pos": [910, 722], "size": {"0": 400, "1": 200}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "ella", "type": "ELLA", "link": 15, "slot_index": 0}, {"name": "text_encoder", "type": "T5_TEXT_ENCODER", "link": 11, "slot_index": 1}, {"name": "clip", "type": "CLIP", "link": 10}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [25], "shape": 3, "slot_index": 0}, {"name": "CLIP CONDITIONING", "type": "CONDITIONING", "links": null, "shape": 3}], "properties": {"Node name for S&R": "EllaTextEncode"}, "widgets_values": ["", ""], "color": "#322", "bgcolor": "#533"}, {"id": 13, "type": "PreviewImage", "pos": [2135, 220], "size": {"0": 389.20013427734375, "1": 413.4000549316406}, "flags": {}, "order": 18, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 53}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 3, "type": "T5TextEncoderLoader #ELLA", "pos": [83, 516], "size": {"0": 315, "1": 106}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "T5_TEXT_ENCODER", "type": "T5_TEXT_ENCODER", "links": [7, 11], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "T5TextEncoderLoader #ELLA"}, "widgets_values": ["models--google--flan-t5-xl--text_encoder", 0, "auto"]}, {"id": 16, "type": "BrushNetLoader", "pos": [970, 328], "size": {"0": 315, "1": 82}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [20], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}, {"id": 17, "type": "LoadImage", "pos": [58, -192], "size": [315, 314], "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [31], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["clipspace/clipspace-mask-1011527.png [input]", "image"]}, {"id": 23, "type": "SAMModelLoader (segment anything)", "pos": [415, -27], "size": [373.11944580078125, 58], "flags": {}, "order": 5, "mode": 0, "outputs": [{"name": "SAM_MODEL", "type": "SAM_MODEL", "links": [43], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "SAMModelLoader (segment anything)"}, "widgets_values": ["sam_vit_h (2.56GB)"]}, {"id": 26, "type": "GroundingDinoModelLoader (segment anything)", "pos": [428, 74], "size": [361.20001220703125, 63.59926813298682], "flags": {}, "order": 6, "mode": 0, "outputs": [{"name": "GROUNDING_DINO_MODEL", "type": "GROUNDING_DINO_MODEL", "links": [44], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GroundingDinoModelLoader (segment anything)"}, "widgets_values": ["GroundingDINO_SwinT_OGC (694MB)"]}, {"id": 24, "type": "PreviewImage", "pos": [1504, -188], "size": [315.91949462890625, 286.26763916015625], "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 46}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 27, "type": "GroundingDinoSAMSegment (segment anything)", "pos": [830, -2], "size": [352.79998779296875, 122], "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "sam_model", "type": "SAM_MODEL", "link": 43}, {"name": "grounding_dino_model", "type": "GROUNDING_DINO_MODEL", "link": 44}, {"name": "image", "type": "IMAGE", "link": 45, "slot_index": 2}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [46], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [47], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "GroundingDinoSAMSegment (segment anything)"}, "widgets_values": ["goblin toy", 0.3]}, {"id": 20, "type": "ImageTransformResizeAbsolute", "pos": [410, -185], "size": {"0": 315, "1": 106}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 31}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [32, 45, 49], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ImageTransformResizeAbsolute"}, "widgets_values": [512, 512, "lanczos"]}, {"id": 12, "type": "VAEDecode", "pos": [2080, 95], "size": {"0": 210, "1": 46}, "flags": {}, "order": 16, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 16}, {"name": "vae", "type": "VAE", "link": 17}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [50], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 28, "type": "InvertMask", "pos": [1201, 21], "size": {"0": 210, "1": 26}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "mask", "type": "MASK", "link": 47}], "outputs": [{"name": "MASK", "type": "MASK", "links": [48, 51], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "InvertMask"}}, {"id": 29, "type": "BlendInpaint", "pos": [1667, 626], "size": {"0": 315, "1": 122}, "flags": {}, "order": 17, "mode": 0, "inputs": [{"name": "inpaint", "type": "IMAGE", "link": 50}, {"name": "original", "type": "IMAGE", "link": 49}, {"name": "mask", "type": "MASK", "link": 51}], "outputs": [{"name": "image", "type": "IMAGE", "links": [53], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "BlendInpaint"}, "widgets_values": [10, 10]}, {"id": 10, "type": "EllaTextEncode", "pos": [911, 473], "size": {"0": 400, "1": 200}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "ella", "type": "ELLA", "link": 14}, {"name": "text_encoder", "type": "T5_TEXT_ENCODER", "link": 7}, {"name": "clip", "type": "CLIP", "link": 8}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [24], "shape": 3, "slot_index": 0}, {"name": "CLIP CONDITIONING", "type": "CONDITIONING", "links": null, "shape": 3}], "properties": {"Node name for S&R": "EllaTextEncode"}, "widgets_values": ["wargaming shop showcase with miniatures", ""], "color": "#232", "bgcolor": "#353"}, {"id": 15, "type": "BrushNet", "pos": [1434, 215], "size": {"0": 315, "1": 226}, "flags": {}, "order": 14, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 21}, {"name": "vae", "type": "VAE", "link": 23}, {"name": "image", "type": "IMAGE", "link": 32}, {"name": "mask", "type": "MASK", "link": 48}, {"name": "brushnet", "type": "BRMODEL", "link": 20}, {"name": "positive", "type": "CONDITIONING", "link": 24}, {"name": "negative", "type": "CONDITIONING", "link": 25}], "outputs": [{"name": "model", "type": "MODEL", "links": [22], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [26], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [27], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [28], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 3, 10000]}, {"id": 9, "type": "KSampler", "pos": [1797, 212], "size": {"0": 315, "1": 262}, "flags": {}, "order": 15, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 22}, {"name": "positive", "type": "CONDITIONING", "link": 26}, {"name": "negative", "type": "CONDITIONING", "link": 27}, {"name": "latent_image", "type": "LATENT", "link": 28, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [16], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 15, 8, "euler_ancestral", "normal", 1]}], "links": [[3, 2, 0, 8, 1, "ELLA"], [4, 1, 0, 8, 0, "MODEL"], [7, 3, 0, 10, 1, "T5_TEXT_ENCODER"], [8, 1, 1, 10, 2, "CLIP"], [10, 1, 1, 11, 2, "CLIP"], [11, 3, 0, 11, 1, "T5_TEXT_ENCODER"], [14, 8, 0, 10, 0, "ELLA"], [15, 8, 0, 11, 0, "ELLA"], [16, 9, 0, 12, 0, "LATENT"], [17, 1, 2, 12, 1, "VAE"], [20, 16, 0, 15, 4, "BRMODEL"], [21, 1, 0, 15, 0, "MODEL"], [22, 15, 0, 9, 0, "MODEL"], [23, 1, 2, 15, 1, "VAE"], [24, 10, 0, 15, 5, "CONDITIONING"], [25, 11, 0, 15, 6, "CONDITIONING"], [26, 15, 1, 9, 1, "CONDITIONING"], [27, 15, 2, 9, 2, "CONDITIONING"], [28, 15, 3, 9, 3, "LATENT"], [31, 17, 0, 20, 0, "IMAGE"], [32, 20, 0, 15, 2, "IMAGE"], [43, 23, 0, 27, 0, "SAM_MODEL"], [44, 26, 0, 27, 1, "GROUNDING_DINO_MODEL"], [45, 20, 0, 27, 2, "IMAGE"], [46, 27, 0, 24, 0, "IMAGE"], [47, 27, 1, 28, 0, "MASK"], [48, 28, 0, 15, 3, "MASK"], [49, 20, 0, 29, 1, "IMAGE"], [50, 12, 0, 29, 0, "IMAGE"], [51, 28, 0, 29, 2, "MASK"], [53, 29, 0, 13, 0, "IMAGE"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_with_ELLA.png b/example/BrushNet_with_ELLA.png
new file mode 100644
index 0000000000000000000000000000000000000000..15bc3c46a7e4d8e85d3006448e03029d5e69b430
--- /dev/null
+++ b/example/BrushNet_with_ELLA.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a10a30feab84b57f22e419a1a53173022edda4b4e65f8c2301bc7dfaf81b6cb7
+size 1084598
diff --git a/example/BrushNet_with_IPA.json b/example/BrushNet_with_IPA.json
new file mode 100644
index 0000000000000000000000000000000000000000..072348de3c2126eb51240d25d631296410ad2a56
--- /dev/null
+++ b/example/BrushNet_with_IPA.json
@@ -0,0 +1 @@
+{"last_node_id": 64, "last_link_id": 137, "nodes": [{"id": 57, "type": "VAEDecode", "pos": [2009.6002197265625, 135.59999084472656], "size": {"0": 210, "1": 46}, "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 105}, {"name": "vae", "type": "VAE", "link": 107}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [106], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 12, "type": "PreviewImage", "pos": [1666, 438], "size": {"0": 523.5944213867188, "1": 547.4853515625}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 106}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 61, "type": "IPAdapterUnifiedLoader", "pos": [452, -96], "size": {"0": 315, "1": 78}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 116}, {"name": "ipadapter", "type": "IPADAPTER", "link": null}], "outputs": [{"name": "model", "type": "MODEL", "links": [117], "shape": 3, "slot_index": 0}, {"name": "ipadapter", "type": "IPADAPTER", "links": [115], "shape": 3}], "properties": {"Node name for S&R": "IPAdapterUnifiedLoader"}, "widgets_values": ["STANDARD (medium strength)"]}, {"id": 60, "type": "LoadImage", "pos": [65, -355], "size": {"0": 315, "1": 314}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [112], "shape": 3}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["ComfyUI_temp_mynbi_00021_ (1).png", "image"]}, {"id": 58, "type": "IPAdapter", "pos": [807, -100], "size": {"0": 315, "1": 190}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 117}, {"name": "ipadapter", "type": "IPADAPTER", "link": 115, "slot_index": 1}, {"name": "image", "type": "IMAGE", "link": 112, "slot_index": 2}, {"name": "attn_mask", "type": "MASK", "link": null}], "outputs": [{"name": "MODEL", "type": "MODEL", "links": [124], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "IPAdapter"}, "widgets_values": [1, 0, 1, "style transfer"]}, {"id": 50, "type": "CLIPTextEncode", "pos": [740, 373], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 114}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [127], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [-71, 21], "size": {"0": 481, "1": 158}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [116], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [114, 132], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [107, 131], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 49, "type": "CLIPTextEncode", "pos": [736, 215], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 132}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [126], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["glowing bowl"], "color": "#232", "bgcolor": "#353"}, {"id": 1, "type": "LoadImage", "pos": [101, 386], "size": {"0": 470.19439697265625, "1": 578.6854248046875}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [134], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image.jpg", "image"]}, {"id": 55, "type": "KSampler", "pos": [1628.000244140625, 69.19998931884766], "size": {"0": 315, "1": 262}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 125}, {"name": "positive", "type": "CONDITIONING", "link": 128}, {"name": "negative", "type": "CONDITIONING", "link": 129}, {"name": "latent_image", "type": "LATENT", "link": 130, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [105], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [1, "fixed", 20, 7, "euler_ancestral", "normal", 1]}, {"id": 63, "type": "LoadImageMask", "pos": [601, 634], "size": {"0": 315, "1": 318}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [136], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask (4).jpg", "red", "image"]}, {"id": 64, "type": "GrowMask", "pos": [946, 633], "size": {"0": 315, "1": 82}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "mask", "type": "MASK", "link": 136}], "outputs": [{"name": "MASK", "type": "MASK", "links": [137], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GrowMask"}, "widgets_values": [10, true]}, {"id": 62, "type": "BrushNet", "pos": [1209, 14], "size": {"0": 315, "1": 246}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 124}, {"name": "vae", "type": "VAE", "link": 131}, {"name": "image", "type": "IMAGE", "link": 134}, {"name": "mask", "type": "MASK", "link": 137}, {"name": "brushnet", "type": "BRMODEL", "link": 133}, {"name": "positive", "type": "CONDITIONING", "link": 126}, {"name": "negative", "type": "CONDITIONING", "link": 127}, {"name": "clip", "type": "PPCLIP", "link": null}], "outputs": [{"name": "model", "type": "MODEL", "links": [125], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [128], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [129], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [130], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}, {"id": 45, "type": "BrushNetLoader", "pos": [49, 238], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [133], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}], "links": [[105, 55, 0, 57, 0, "LATENT"], [106, 57, 0, 12, 0, "IMAGE"], [107, 47, 2, 57, 1, "VAE"], [112, 60, 0, 58, 2, "IMAGE"], [114, 47, 1, 50, 0, "CLIP"], [115, 61, 1, 58, 1, "IPADAPTER"], [116, 47, 0, 61, 0, "MODEL"], [117, 61, 0, 58, 0, "MODEL"], [124, 58, 0, 62, 0, "MODEL"], [125, 62, 0, 55, 0, "MODEL"], [126, 49, 0, 62, 5, "CONDITIONING"], [127, 50, 0, 62, 6, "CONDITIONING"], [128, 62, 1, 55, 1, "CONDITIONING"], [129, 62, 2, 55, 2, "CONDITIONING"], [130, 62, 3, 55, 3, "LATENT"], [131, 47, 2, 62, 1, "VAE"], [132, 47, 1, 49, 0, "CLIP"], [133, 45, 0, 62, 4, "BRMODEL"], [134, 1, 0, 62, 2, "IMAGE"], [136, 63, 0, 64, 0, "MASK"], [137, 64, 0, 62, 3, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_with_IPA.png b/example/BrushNet_with_IPA.png
new file mode 100644
index 0000000000000000000000000000000000000000..31db4f34cc326c77887c48a4285984a4b08ad319
--- /dev/null
+++ b/example/BrushNet_with_IPA.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f82c6420f6987eb82434c03ff3c0e600de47d9854dab8053becb5dbe7f98ac67
+size 2320823
diff --git a/example/BrushNet_with_LoRA.json b/example/BrushNet_with_LoRA.json
new file mode 100644
index 0000000000000000000000000000000000000000..cdbea6abd824815f9bc5547d79351f0f05a0b710
--- /dev/null
+++ b/example/BrushNet_with_LoRA.json
@@ -0,0 +1 @@
+{"last_node_id": 59, "last_link_id": 123, "nodes": [{"id": 57, "type": "VAEDecode", "pos": [2009.6002197265625, 135.59999084472656], "size": {"0": 210, "1": 46}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 105}, {"name": "vae", "type": "VAE", "link": 107}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [106], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 12, "type": "PreviewImage", "pos": [1666, 438], "size": {"0": 523.5944213867188, "1": 547.4853515625}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 106}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 55, "type": "KSampler", "pos": [1628.000244140625, 69.19998931884766], "size": {"0": 315, "1": 262}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 113}, {"name": "positive", "type": "CONDITIONING", "link": 114}, {"name": "negative", "type": "CONDITIONING", "link": 115}, {"name": "latent_image", "type": "LATENT", "link": 116, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [105], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 30, 7, "euler_ancestral", "normal", 1]}, {"id": 51, "type": "LoraLoader", "pos": [641, 43], "size": {"0": 315, "1": 126}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 82}, {"name": "clip", "type": "CLIP", "link": 83}], "outputs": [{"name": "MODEL", "type": "MODEL", "links": [117], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [94, 95], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoraLoader"}, "widgets_values": ["glasssculpture_v8.safetensors", 1, 1]}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [109, 40], "size": {"0": 481, "1": 158}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [82], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [83], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [107, 118], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 50, "type": "CLIPTextEncode", "pos": [883, 427], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 95}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [121], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 1, "type": "LoadImage", "pos": [101, 386], "size": {"0": 470.19439697265625, "1": 578.6854248046875}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [122], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image.jpg", "image"]}, {"id": 58, "type": "LoadImageMask", "pos": [611, 646], "size": {"0": 315, "1": 318.0000305175781}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "MASK", "type": "MASK", "links": [123], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "LoadImageMask"}, "widgets_values": ["test_mask (5).jpg", "red", "image"]}, {"id": 49, "type": "CLIPTextEncode", "pos": [886, 282], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 94}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [120], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["a glasssculpture of burger transparent, translucent, reflections"], "color": "#232", "bgcolor": "#353"}, {"id": 59, "type": "BrushNet", "pos": [1259, 61], "size": {"0": 315, "1": 246}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 117}, {"name": "vae", "type": "VAE", "link": 118}, {"name": "image", "type": "IMAGE", "link": 122}, {"name": "mask", "type": "MASK", "link": 123}, {"name": "brushnet", "type": "BRMODEL", "link": 119}, {"name": "positive", "type": "CONDITIONING", "link": 120}, {"name": "negative", "type": "CONDITIONING", "link": 121}, {"name": "clip", "type": "PPCLIP", "link": null}], "outputs": [{"name": "model", "type": "MODEL", "links": [113], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [114], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [115], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [116], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "BrushNet"}, "widgets_values": [1, 0, 10000]}, {"id": 45, "type": "BrushNetLoader", "pos": [49, 238], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [119], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["brushnet/random_mask.safetensors", "float16"]}], "links": [[82, 47, 0, 51, 0, "MODEL"], [83, 47, 1, 51, 1, "CLIP"], [94, 51, 1, 49, 0, "CLIP"], [95, 51, 1, 50, 0, "CLIP"], [105, 55, 0, 57, 0, "LATENT"], [106, 57, 0, 12, 0, "IMAGE"], [107, 47, 2, 57, 1, "VAE"], [113, 59, 0, 55, 0, "MODEL"], [114, 59, 1, 55, 1, "CONDITIONING"], [115, 59, 2, 55, 2, "CONDITIONING"], [116, 59, 3, 55, 3, "LATENT"], [117, 51, 0, 59, 0, "MODEL"], [118, 47, 2, 59, 1, "VAE"], [119, 45, 0, 59, 4, "BRMODEL"], [120, 49, 0, 59, 5, "CONDITIONING"], [121, 50, 0, 59, 6, "CONDITIONING"], [122, 1, 0, 59, 2, "IMAGE"], [123, 58, 0, 59, 3, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/BrushNet_with_LoRA.png b/example/BrushNet_with_LoRA.png
new file mode 100644
index 0000000000000000000000000000000000000000..a2d286e9f3f7331987d18fa24b24173bda6b8e4c
--- /dev/null
+++ b/example/BrushNet_with_LoRA.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a0bfb3a0e793b2ef28568e55be7c206a58bd958e439d95f903b0ebccfe3e5c3
+size 1900323
diff --git a/example/PowerPaint_object_removal.json b/example/PowerPaint_object_removal.json
new file mode 100644
index 0000000000000000000000000000000000000000..4fb1935d17e816fe7bd9a1a551a93bbea9c2386e
--- /dev/null
+++ b/example/PowerPaint_object_removal.json
@@ -0,0 +1 @@
+{"last_node_id": 78, "last_link_id": 164, "nodes": [{"id": 54, "type": "VAEDecode", "pos": [1921, 38], "size": {"0": 210, "1": 46}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 50, "type": "CLIPTextEncode", "pos": [651, 168], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [146], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 45, "type": "BrushNetLoader", "pos": [8, 251], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [148], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["powerpaint/diffusion_pytorch_model.safetensors"]}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [3, 44], "size": {"0": 481, "1": 158}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [139], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 151], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 52, "type": "KSampler", "pos": [1571, 117], "size": {"0": 315, "1": 262}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 138}, {"name": "positive", "type": "CONDITIONING", "link": 142}, {"name": "negative", "type": "CONDITIONING", "link": 143}, {"name": "latent_image", "type": "LATENT", "link": 144, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 7.5, "euler", "normal", 1]}, {"id": 65, "type": "PowerPaint", "pos": [1154, 136], "size": {"0": 315, "1": 294}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 139}, {"name": "vae", "type": "VAE", "link": 151}, {"name": "image", "type": "IMAGE", "link": 158}, {"name": "mask", "type": "MASK", "link": 164}, {"name": "powerpaint", "type": "BRMODEL", "link": 148}, {"name": "clip", "type": "CLIP", "link": 147}, {"name": "positive", "type": "CONDITIONING", "link": 145}, {"name": "negative", "type": "CONDITIONING", "link": 146}], "outputs": [{"name": "model", "type": "MODEL", "links": [138], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [142], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [143], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [144], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "PowerPaint"}, "widgets_values": [1, "object removal", 1, 0, 10000]}, {"id": 49, "type": "CLIPTextEncode", "pos": [649, 21], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [145], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["empty scene blur"], "color": "#232", "bgcolor": "#353"}, {"id": 58, "type": "LoadImage", "pos": [10, 404], "size": [542.1735076904297, 630.6464691162109], "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [158, 159], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image (1).jpg", "image"]}, {"id": 76, "type": "SAMModelLoader (segment anything)", "pos": [30, 1107], "size": {"0": 315, "1": 58}, "flags": {}, "order": 4, "mode": 0, "outputs": [{"name": "SAM_MODEL", "type": "SAM_MODEL", "links": [163], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "SAMModelLoader (segment anything)"}, "widgets_values": ["sam_vit_h (2.56GB)"]}, {"id": 74, "type": "GroundingDinoModelLoader (segment anything)", "pos": [384, 1105], "size": [401.77337646484375, 63.24662780761719], "flags": {}, "order": 5, "mode": 0, "outputs": [{"name": "GROUNDING_DINO_MODEL", "type": "GROUNDING_DINO_MODEL", "links": [160], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "GroundingDinoModelLoader (segment anything)"}, "widgets_values": ["GroundingDINO_SwinT_OGC (694MB)"]}, {"id": 12, "type": "PreviewImage", "pos": [1502, 455], "size": [552.7734985351562, 568.0465545654297], "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 75, "type": "GroundingDinoSAMSegment (segment anything)", "pos": [642, 587], "size": [368.77362060546875, 122], "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "sam_model", "type": "SAM_MODEL", "link": 163}, {"name": "grounding_dino_model", "type": "GROUNDING_DINO_MODEL", "link": 160}, {"name": "image", "type": "IMAGE", "link": 159}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": null, "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [164], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "GroundingDinoSAMSegment (segment anything)"}, "widgets_values": ["leaves", 0.3]}, {"id": 66, "type": "PowerPaintCLIPLoader", "pos": [654, 343], "size": {"0": 315, "1": 82}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "clip", "type": "CLIP", "links": [147], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "PowerPaintCLIPLoader"}, "widgets_values": ["model.fp16.safetensors", "powerpaint/pytorch_model.bin"]}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [138, 65, 0, 52, 0, "MODEL"], [139, 47, 0, 65, 0, "MODEL"], [142, 65, 1, 52, 1, "CONDITIONING"], [143, 65, 2, 52, 2, "CONDITIONING"], [144, 65, 3, 52, 3, "LATENT"], [145, 49, 0, 65, 6, "CONDITIONING"], [146, 50, 0, 65, 7, "CONDITIONING"], [147, 66, 0, 65, 5, "CLIP"], [148, 45, 0, 65, 4, "BRMODEL"], [151, 47, 2, 65, 1, "VAE"], [158, 58, 0, 65, 2, "IMAGE"], [159, 58, 0, 75, 2, "IMAGE"], [160, 74, 0, 75, 1, "GROUNDING_DINO_MODEL"], [163, 76, 0, 75, 0, "SAM_MODEL"], [164, 75, 1, 65, 3, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/PowerPaint_object_removal.png b/example/PowerPaint_object_removal.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a8b8bce958f7102ebb688d3fda180e66d2e67c7
--- /dev/null
+++ b/example/PowerPaint_object_removal.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be37c27e73e00de2291c59da162b66694f2de9cd5199c61503b636efa026618d
+size 1794003
diff --git a/example/PowerPaint_outpaint.json b/example/PowerPaint_outpaint.json
new file mode 100644
index 0000000000000000000000000000000000000000..1bfecfc7944bb4b27abd6466459cc91a80aa5ecd
--- /dev/null
+++ b/example/PowerPaint_outpaint.json
@@ -0,0 +1 @@
+{"last_node_id": 73, "last_link_id": 157, "nodes": [{"id": 54, "type": "VAEDecode", "pos": [1921, 38], "size": {"0": 210, "1": 46}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 91}, {"name": "vae", "type": "VAE", "link": 92}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [93], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 50, "type": "CLIPTextEncode", "pos": [651, 168], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 80}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [146], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 45, "type": "BrushNetLoader", "pos": [8, 251], "size": {"0": 576.2000122070312, "1": 104}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "brushnet", "type": "BRMODEL", "links": [148], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "BrushNetLoader"}, "widgets_values": ["powerpaint/diffusion_pytorch_model.safetensors"]}, {"id": 47, "type": "CheckpointLoaderSimple", "pos": [3, 44], "size": {"0": 481, "1": 158}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [139], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [78, 80], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [92, 151], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["realisticVisionV60B1_v51VAE.safetensors"]}, {"id": 58, "type": "LoadImage", "pos": [10, 404], "size": [542.1735076904297, 630.6464691162109], "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [152], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["test_image (1).jpg", "image"]}, {"id": 49, "type": "CLIPTextEncode", "pos": [649, 21], "size": {"0": 339.20001220703125, "1": 96.39999389648438}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 78}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [145], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["empty scene"], "color": "#232", "bgcolor": "#353"}, {"id": 66, "type": "PowerPaintCLIPLoader", "pos": [674, 345], "size": {"0": 315, "1": 82}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "clip", "type": "CLIP", "links": [147], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "PowerPaintCLIPLoader"}, "widgets_values": ["model.fp16.safetensors", "powerpaint/pytorch_model.bin"]}, {"id": 70, "type": "ImagePadForOutpaint", "pos": [678, 511], "size": {"0": 315, "1": 174}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 152}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [156], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": [157], "shape": 3, "slot_index": 1}], "properties": {"Node name for S&R": "ImagePadForOutpaint"}, "widgets_values": [200, 0, 200, 0, 0]}, {"id": 12, "type": "PreviewImage", "pos": [1213, 477], "size": [930.6534439086913, 553.5264953613282], "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 93}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 65, "type": "PowerPaint", "pos": [1154, 136], "size": {"0": 315, "1": 294}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 139}, {"name": "vae", "type": "VAE", "link": 151}, {"name": "image", "type": "IMAGE", "link": 156}, {"name": "mask", "type": "MASK", "link": 157}, {"name": "powerpaint", "type": "BRMODEL", "link": 148}, {"name": "clip", "type": "CLIP", "link": 147}, {"name": "positive", "type": "CONDITIONING", "link": 145}, {"name": "negative", "type": "CONDITIONING", "link": 146}], "outputs": [{"name": "model", "type": "MODEL", "links": [138], "shape": 3, "slot_index": 0}, {"name": "positive", "type": "CONDITIONING", "links": [142], "shape": 3, "slot_index": 1}, {"name": "negative", "type": "CONDITIONING", "links": [143], "shape": 3, "slot_index": 2}, {"name": "latent", "type": "LATENT", "links": [144], "shape": 3, "slot_index": 3}], "properties": {"Node name for S&R": "PowerPaint"}, "widgets_values": [1, "image outpainting", 1, 0, 10000]}, {"id": 52, "type": "KSampler", "pos": [1571, 117], "size": {"0": 315, "1": 262}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 138}, {"name": "positive", "type": "CONDITIONING", "link": 142}, {"name": "negative", "type": "CONDITIONING", "link": 143}, {"name": "latent_image", "type": "LATENT", "link": 144, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [91], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 20, 7.5, "euler", "normal", 1]}], "links": [[78, 47, 1, 49, 0, "CLIP"], [80, 47, 1, 50, 0, "CLIP"], [91, 52, 0, 54, 0, "LATENT"], [92, 47, 2, 54, 1, "VAE"], [93, 54, 0, 12, 0, "IMAGE"], [138, 65, 0, 52, 0, "MODEL"], [139, 47, 0, 65, 0, "MODEL"], [142, 65, 1, 52, 1, "CONDITIONING"], [143, 65, 2, 52, 2, "CONDITIONING"], [144, 65, 3, 52, 3, "LATENT"], [145, 49, 0, 65, 6, "CONDITIONING"], [146, 50, 0, 65, 7, "CONDITIONING"], [147, 66, 0, 65, 5, "CLIP"], [148, 45, 0, 65, 4, "BRMODEL"], [151, 47, 2, 65, 1, "VAE"], [152, 58, 0, 70, 0, "IMAGE"], [156, 70, 0, 65, 2, "IMAGE"], [157, 70, 1, 65, 3, "MASK"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/PowerPaint_outpaint.png b/example/PowerPaint_outpaint.png
new file mode 100644
index 0000000000000000000000000000000000000000..22927b47a3c67db8022b9de4fb9640b673879e0c
--- /dev/null
+++ b/example/PowerPaint_outpaint.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8601b20a7709c03c39b33a8cccd1ef5ba45843926cdd382d58aeaca4a51f8351
+size 2176269
diff --git a/example/RAUNet1.png b/example/RAUNet1.png
new file mode 100644
index 0000000000000000000000000000000000000000..304632498f6f74fb4957b93095d50b30fcd722f8
--- /dev/null
+++ b/example/RAUNet1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23e34063979e843557034a65abd14b76fd1ecd05be3551c25c5c2a32ed21888f
+size 1813441
diff --git a/example/RAUNet2.png b/example/RAUNet2.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ab3f0e4f9abd8b4cf80d6ab44527c42aa892403
--- /dev/null
+++ b/example/RAUNet2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1422246f3c2c42bb15c247024f5274792a827eff69f2b3854d537038d416e9b2
+size 1875458
diff --git a/example/RAUNet_basic.json b/example/RAUNet_basic.json
new file mode 100644
index 0000000000000000000000000000000000000000..ef2c07d006d264ad29990a7bc3ada41d413a9921
--- /dev/null
+++ b/example/RAUNet_basic.json
@@ -0,0 +1 @@
+{"last_node_id": 26, "last_link_id": 48, "nodes": [{"id": 7, "type": "KSamplerAdvanced", "pos": [1281, 461], "size": {"0": 315, "1": 334}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 48}, {"name": "positive", "type": "CONDITIONING", "link": 8}, {"name": "negative", "type": "CONDITIONING", "link": 9}, {"name": "latent_image", "type": "LATENT", "link": 10, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [22], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSamplerAdvanced"}, "widgets_values": ["disable", 0, "fixed", 25, 8, "ddpm", "normal", 0, 10000, "disable"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 23, "type": "KSamplerAdvanced", "pos": [1280, 872], "size": {"0": 315, "1": 334}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 46}, {"name": "positive", "type": "CONDITIONING", "link": 45}, {"name": "negative", "type": "CONDITIONING", "link": 43}, {"name": "latent_image", "type": "LATENT", "link": 44, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [42], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSamplerAdvanced"}, "widgets_values": ["disable", 0, "fixed", 25, 8, "ddpm", "normal", 0, 10000, "disable"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 1, "type": "CheckpointLoaderSimple", "pos": [452, 461], "size": {"0": 320.2000732421875, "1": 108.99996948242188}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [46, 47], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [1, 6], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [23, 40], "shape": 3, "slot_index": 2}], "title": "Load Base Checkpoint", "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["SDXL/zavychromaxl_v70.safetensors"], "color": "#2a363b", "bgcolor": "#3f5159"}, {"id": 5, "type": "CLIPTextEncodeSDXL", "pos": [854, 844], "size": [319.27423095703125, 311.4324369430542], "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 6}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [9, 43], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncodeSDXL"}, "widgets_values": [4096, 4096, 0, 0, 1024, 1024, "ugly, deformed, noisy, low poly, blurry, text, duplicate, poorly drawn, mosaic", "ugly, deformed, noisy, low poly, blurry, text, duplicate, poorly drawn, mosaic"], "color": "#322", "bgcolor": "#533"}, {"id": 8, "type": "EmptyLatentImage", "pos": [868, 1205], "size": [269.2342041015627, 106], "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "LATENT", "type": "LATENT", "links": [10, 44], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "EmptyLatentImage"}, "widgets_values": [2048, 2048, 1]}, {"id": 2, "type": "CLIPTextEncodeSDXL", "pos": [852, 457], "size": [325.67423095703134, 332.3523832321167], "flags": {}, "order": 3, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 1}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [8, 45], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncodeSDXL"}, "widgets_values": [4096, 4096, 0, 0, 1024, 1024, "an exotic fox, cute, chibi kawaii. detailed fur, hyperdetailed, big reflective eyes", "thick strokes, bright colors, fairytale, artstation,centered composition, perfect composition, centered, vibrant colors, muted colors, high detailed, 8k"], "color": "#232", "bgcolor": "#353"}, {"id": 25, "type": "PreviewImage", "pos": [1628, 460], "size": {"0": 650.7540893554688, "1": 766.8323974609375}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 41}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 24, "type": "VAEDecode", "pos": [1683, 350], "size": {"0": 210, "1": 46}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 42, "slot_index": 0}, {"name": "vae", "type": "VAE", "link": 40, "slot_index": 1}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [41], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 14, "type": "VAEDecode", "pos": [1985, 347], "size": {"0": 210, "1": 46}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 22}, {"name": "vae", "type": "VAE", "link": 23}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [24], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 15, "type": "PreviewImage", "pos": [2298, 457], "size": [650.7540725708009, 766.8323699951172], "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 24}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 26, "type": "RAUNet", "pos": [857, 270], "size": {"0": 315, "1": 130}, "flags": {}, "order": 2, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 47}], "outputs": [{"name": "model", "type": "MODEL", "links": [48], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "RAUNet"}, "widgets_values": [0, 2, 2, 6]}], "links": [[1, 1, 1, 2, 0, "CLIP"], [6, 1, 1, 5, 0, "CLIP"], [8, 2, 0, 7, 1, "CONDITIONING"], [9, 5, 0, 7, 2, "CONDITIONING"], [10, 8, 0, 7, 3, "LATENT"], [22, 7, 0, 14, 0, "LATENT"], [23, 1, 2, 14, 1, "VAE"], [24, 14, 0, 15, 0, "IMAGE"], [40, 1, 2, 24, 1, "VAE"], [41, 24, 0, 25, 0, "IMAGE"], [42, 23, 0, 24, 0, "LATENT"], [43, 5, 0, 23, 2, "CONDITIONING"], [44, 8, 0, 23, 3, "LATENT"], [45, 2, 0, 23, 1, "CONDITIONING"], [46, 1, 0, 23, 0, "MODEL"], [47, 1, 0, 26, 0, "MODEL"], [48, 26, 0, 7, 0, "MODEL"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/RAUNet_with_CN.json b/example/RAUNet_with_CN.json
new file mode 100644
index 0000000000000000000000000000000000000000..e7f97cc1c9270aeeb60057c7482ec760759b06e7
--- /dev/null
+++ b/example/RAUNet_with_CN.json
@@ -0,0 +1 @@
+{"last_node_id": 20, "last_link_id": 34, "nodes": [{"id": 5, "type": "VAEDecode", "pos": [1916.4395019531253, 183.40589904785156], "size": {"0": 210, "1": 46}, "flags": {}, "order": 12, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 6}, {"name": "vae", "type": "VAE", "link": 8, "slot_index": 1}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [7], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 10, "type": "ControlNetLoader", "pos": [229, -378], "size": {"0": 432.609130859375, "1": 78.54664611816406}, "flags": {}, "order": 0, "mode": 0, "outputs": [{"name": "CONTROL_NET", "type": "CONTROL_NET", "links": [15], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ControlNetLoader"}, "widgets_values": ["control_canny-fp16.safetensors"]}, {"id": 11, "type": "LoadImage", "pos": [230, -237], "size": {"0": 393.4891357421875, "1": 460.0666809082031}, "flags": {}, "order": 1, "mode": 0, "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [18], "shape": 3, "slot_index": 0}, {"name": "MASK", "type": "MASK", "links": null, "shape": 3}], "properties": {"Node name for S&R": "LoadImage"}, "widgets_values": ["fox_with_sword (1).png", "image"]}, {"id": 2, "type": "CLIPTextEncode", "pos": [713, 274], "size": {"0": 392.8395080566406, "1": 142.005859375}, "flags": {}, "order": 6, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 1}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [16], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["black fox with big sword, standing in a small town cute, chibi kawaii. fairytale, artstation, centered composition, perfect composition, centered, vibrant colors, muted colors, high detailed, 8k"], "color": "#232", "bgcolor": "#353"}, {"id": 16, "type": "VAEDecode", "pos": [1897, -612], "size": {"0": 210, "1": 46}, "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "samples", "type": "LATENT", "link": 24}, {"name": "vae", "type": "VAE", "link": 26, "slot_index": 1}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [25], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "VAEDecode"}}, {"id": 9, "type": "ControlNetApply", "pos": [1153, 111], "size": {"0": 317.4000244140625, "1": 98}, "flags": {}, "order": 9, "mode": 0, "inputs": [{"name": "conditioning", "type": "CONDITIONING", "link": 16}, {"name": "control_net", "type": "CONTROL_NET", "link": 15}, {"name": "image", "type": "IMAGE", "link": 19}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [17, 27], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "ControlNetApply"}, "widgets_values": [0.8]}, {"id": 3, "type": "CLIPTextEncode", "pos": [714, 488], "size": {"0": 399.75958251953125, "1": 111.60586547851562}, "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "clip", "type": "CLIP", "link": 2}], "outputs": [{"name": "CONDITIONING", "type": "CONDITIONING", "links": [5, 28], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CLIPTextEncode"}, "widgets_values": ["hat, text, blurry, ugly, duplicate, poorly drawn, deformed, mosaic"], "color": "#322", "bgcolor": "#533"}, {"id": 1, "type": "CheckpointLoaderSimple", "pos": [199, 295], "size": {"0": 440.8395080566406, "1": 99.80586242675781}, "flags": {}, "order": 2, "mode": 0, "outputs": [{"name": "MODEL", "type": "MODEL", "links": [23, 29], "shape": 3, "slot_index": 0}, {"name": "CLIP", "type": "CLIP", "links": [1, 2], "shape": 3, "slot_index": 1}, {"name": "VAE", "type": "VAE", "links": [8, 26], "shape": 3, "slot_index": 2}], "properties": {"Node name for S&R": "CheckpointLoaderSimple"}, "widgets_values": ["SD15/revAnimated_v2Rebirth.safetensors"]}, {"id": 13, "type": "PreviewImage", "pos": [1057, -320], "size": {"0": 210, "1": 246}, "flags": {}, "order": 8, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 20}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 6, "type": "PreviewImage", "pos": [1883, 277], "size": {"0": 623.648193359375, "1": 645.5486450195312}, "flags": {}, "order": 14, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 7}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 17, "type": "PreviewImage", "pos": [1897, -508], "size": {"0": 602.8720703125, "1": 630.0126953125}, "flags": {}, "order": 15, "mode": 0, "inputs": [{"name": "images", "type": "IMAGE", "link": 25}], "properties": {"Node name for S&R": "PreviewImage"}}, {"id": 12, "type": "CannyEdgePreprocessor", "pos": [689, -123], "size": {"0": 315, "1": 106}, "flags": {}, "order": 4, "mode": 0, "inputs": [{"name": "image", "type": "IMAGE", "link": 18}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": [19, 20], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "CannyEdgePreprocessor"}, "widgets_values": [100, 200, 1024]}, {"id": 18, "type": "RAUNet", "pos": [1113, -500], "size": {"0": 315, "1": 130}, "flags": {}, "order": 5, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 29}], "outputs": [{"name": "model", "type": "MODEL", "links": [30], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "RAUNet"}, "widgets_values": [0, 2, 2, 8]}, {"id": 4, "type": "KSampler", "pos": [1529, 115], "size": {"0": 315, "1": 262}, "flags": {}, "order": 10, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 23}, {"name": "positive", "type": "CONDITIONING", "link": 17}, {"name": "negative", "type": "CONDITIONING", "link": 5}, {"name": "latent_image", "type": "LATENT", "link": 9, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [6], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 25, 8, "euler_ancestral", "normal", 1]}, {"id": 15, "type": "KSampler", "pos": [1523, -441], "size": {"0": 315, "1": 262}, "flags": {}, "order": 11, "mode": 0, "inputs": [{"name": "model", "type": "MODEL", "link": 30}, {"name": "positive", "type": "CONDITIONING", "link": 27}, {"name": "negative", "type": "CONDITIONING", "link": 28}, {"name": "latent_image", "type": "LATENT", "link": 31, "slot_index": 3}], "outputs": [{"name": "LATENT", "type": "LATENT", "links": [24], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "KSampler"}, "widgets_values": [0, "fixed", 25, 8, "euler_ancestral", "normal", 1]}, {"id": 7, "type": "EmptyLatentImage", "pos": [797, 652], "size": {"0": 315, "1": 106}, "flags": {}, "order": 3, "mode": 0, "outputs": [{"name": "LATENT", "type": "LATENT", "links": [9, 31], "shape": 3, "slot_index": 0}], "properties": {"Node name for S&R": "EmptyLatentImage"}, "widgets_values": [1024, 1024, 1]}], "links": [[1, 1, 1, 2, 0, "CLIP"], [2, 1, 1, 3, 0, "CLIP"], [5, 3, 0, 4, 2, "CONDITIONING"], [6, 4, 0, 5, 0, "LATENT"], [7, 5, 0, 6, 0, "IMAGE"], [8, 1, 2, 5, 1, "VAE"], [9, 7, 0, 4, 3, "LATENT"], [15, 10, 0, 9, 1, "CONTROL_NET"], [16, 2, 0, 9, 0, "CONDITIONING"], [17, 9, 0, 4, 1, "CONDITIONING"], [18, 11, 0, 12, 0, "IMAGE"], [19, 12, 0, 9, 2, "IMAGE"], [20, 12, 0, 13, 0, "IMAGE"], [23, 1, 0, 4, 0, "MODEL"], [24, 15, 0, 16, 0, "LATENT"], [25, 16, 0, 17, 0, "IMAGE"], [26, 1, 2, 16, 1, "VAE"], [27, 9, 0, 15, 1, "CONDITIONING"], [28, 3, 0, 15, 2, "CONDITIONING"], [29, 1, 0, 18, 0, "MODEL"], [30, 18, 0, 15, 0, "MODEL"], [31, 7, 0, 15, 3, "LATENT"]], "groups": [], "config": {}, "extra": {}, "version": 0.4}
\ No newline at end of file
diff --git a/example/goblin_toy.png b/example/goblin_toy.png
new file mode 100644
index 0000000000000000000000000000000000000000..e97bc003a1fafb5295a5bcc64414fb95c2919ee3
--- /dev/null
+++ b/example/goblin_toy.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f70ac70ce8cb28470e487668f780bb5458027eb6c01828d47fe4f76330a478d3
+size 1158017
diff --git a/example/inpaint_folder.png b/example/inpaint_folder.png
new file mode 100644
index 0000000000000000000000000000000000000000..3cdd0acc99de4591fc664b29cb3d66f7b3f3b54e
Binary files /dev/null and b/example/inpaint_folder.png differ
diff --git a/example/object_removal.png b/example/object_removal.png
new file mode 100644
index 0000000000000000000000000000000000000000..70fe13fa9427b9dd38e1647b4a9cae388fae323c
--- /dev/null
+++ b/example/object_removal.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14e5fc2a433607d5dbe986ba90e5675854f14bcebdaa8fac8875ca46ca9b22c4
+size 2119757
diff --git a/example/object_removal_fail.png b/example/object_removal_fail.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc5170281c0a4c1f51a8b47ea2f9cc02012e9f54
--- /dev/null
+++ b/example/object_removal_fail.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb2f277671b652b5f68ac90ca11abf8f926bb8137273e97a8732b73e7c395268
+size 1568702
diff --git a/example/params1.png b/example/params1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0aa561ca902d56411ad6953dace74496ecfc41de
--- /dev/null
+++ b/example/params1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bc123ca0317f7b608faf12368b584dfd2edad028765724a8386d3481ab3c739
+size 1083627
diff --git a/example/params10.png b/example/params10.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c122d00c5378c2c6a1733b9f1b22a23c736c3a0
Binary files /dev/null and b/example/params10.png differ
diff --git a/example/params11.png b/example/params11.png
new file mode 100644
index 0000000000000000000000000000000000000000..19ca280dc6683530cb5d53eccc7254a4d53325cf
Binary files /dev/null and b/example/params11.png differ
diff --git a/example/params12.png b/example/params12.png
new file mode 100644
index 0000000000000000000000000000000000000000..26aa22fc5cc9b50cf6bc1b795d18df74fcc85bc5
Binary files /dev/null and b/example/params12.png differ
diff --git a/example/params13.png b/example/params13.png
new file mode 100644
index 0000000000000000000000000000000000000000..f1e22d38cccb542b28ccf573075abeb4efbd629d
--- /dev/null
+++ b/example/params13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78765e50789603a67600dcfaa014710d587f246a2a83240f34fbf62ecf2e653c
+size 1708746
diff --git a/example/params14.png b/example/params14.png
new file mode 100644
index 0000000000000000000000000000000000000000..a701e04d0c137cb786c4d3c949832f274979a07c
Binary files /dev/null and b/example/params14.png differ
diff --git a/example/params15.png b/example/params15.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b215d37e52dcc69e716cc0b567ff68edee45090
Binary files /dev/null and b/example/params15.png differ
diff --git a/example/params16.png b/example/params16.png
new file mode 100644
index 0000000000000000000000000000000000000000..3c031613dae3f5af6bfc777d6809eb3bab504cef
Binary files /dev/null and b/example/params16.png differ
diff --git a/example/params17.png b/example/params17.png
new file mode 100644
index 0000000000000000000000000000000000000000..2fd0170fc40757ca5319d29bb99f65eca86497ec
Binary files /dev/null and b/example/params17.png differ
diff --git a/example/params18.png b/example/params18.png
new file mode 100644
index 0000000000000000000000000000000000000000..c736fbc589a1d08ccb5706bc2d3117310c06011f
Binary files /dev/null and b/example/params18.png differ
diff --git a/example/params19.png b/example/params19.png
new file mode 100644
index 0000000000000000000000000000000000000000..2eee10dfc5afd3d89381cc81e328f0868b108d8d
Binary files /dev/null and b/example/params19.png differ
diff --git a/example/params2.png b/example/params2.png
new file mode 100644
index 0000000000000000000000000000000000000000..79e7e540337f3ad1a53e126b16cef9a4de19a6d2
Binary files /dev/null and b/example/params2.png differ
diff --git a/example/params20.png b/example/params20.png
new file mode 100644
index 0000000000000000000000000000000000000000..8c9f56eff6f7c73b9ebffef3ff34fb095f37f9b5
Binary files /dev/null and b/example/params20.png differ
diff --git a/example/params21.png b/example/params21.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c2f540f9aa1ddd42dbc9d02b5757936744c1743
Binary files /dev/null and b/example/params21.png differ
diff --git a/example/params22.png b/example/params22.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e7b0b59ca79c7a2f7a671bda2ddf5a378bc6bc4
Binary files /dev/null and b/example/params22.png differ
diff --git a/example/params3.png b/example/params3.png
new file mode 100644
index 0000000000000000000000000000000000000000..336b16b8f032c11ee5fb62e73f96b567071ed6aa
Binary files /dev/null and b/example/params3.png differ
diff --git a/example/params4.png b/example/params4.png
new file mode 100644
index 0000000000000000000000000000000000000000..833f3c1a1a0fb66f8fc1fbc2f71e0bc775c76d1d
Binary files /dev/null and b/example/params4.png differ
diff --git a/example/params5.png b/example/params5.png
new file mode 100644
index 0000000000000000000000000000000000000000..5e5faed77537306b3bce9e0c6f50a66b71cf8a35
Binary files /dev/null and b/example/params5.png differ
diff --git a/example/params6.png b/example/params6.png
new file mode 100644
index 0000000000000000000000000000000000000000..9358f29b9eaf33a70c6c14aece6022d64936f800
Binary files /dev/null and b/example/params6.png differ
diff --git a/example/params7.png b/example/params7.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb967307547609a8fa08f779a35ada6c2673603a
Binary files /dev/null and b/example/params7.png differ
diff --git a/example/params8.png b/example/params8.png
new file mode 100644
index 0000000000000000000000000000000000000000..d17c0b5ee5a998bdb644d30fd5b207ba89baeb03
Binary files /dev/null and b/example/params8.png differ
diff --git a/example/params9.png b/example/params9.png
new file mode 100644
index 0000000000000000000000000000000000000000..2d85e86d1ecc17927cea32fb3714a07edbd6da15
Binary files /dev/null and b/example/params9.png differ
diff --git a/example/sleepeng_cat_ce.png b/example/sleepeng_cat_ce.png
new file mode 100644
index 0000000000000000000000000000000000000000..773520a9c5c8328daeba9f7f0f23239f371d9f64
Binary files /dev/null and b/example/sleepeng_cat_ce.png differ
diff --git a/example/sleeping_cat.png b/example/sleeping_cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..555253cbf795310b7b52823af17410eac93b7418
Binary files /dev/null and b/example/sleeping_cat.png differ
diff --git a/example/sleeping_cat_inpaint1.png b/example/sleeping_cat_inpaint1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5a20e2502db4d41e1dff14c77fed2981811409b
--- /dev/null
+++ b/example/sleeping_cat_inpaint1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:addbd6377ea207ffa0822c27b7b0a6c9607952e2b176c5553812f10decaab17b
+size 1501010
diff --git a/example/sleeping_cat_inpaint2.png b/example/sleeping_cat_inpaint2.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e686c8abc3b73f9ec82cb3f5d147b48edf4e0b3
Binary files /dev/null and b/example/sleeping_cat_inpaint2.png differ
diff --git a/example/sleeping_cat_inpaint3.png b/example/sleeping_cat_inpaint3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a5c2a021f7f94ea0231feb1a6cf84da34fd948d
--- /dev/null
+++ b/example/sleeping_cat_inpaint3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcf46ecf61e188ef8780d0974ca3c172aebdf943cf052fc29ae8899afae6893e
+size 1490776
diff --git a/example/sleeping_cat_inpaint4.png b/example/sleeping_cat_inpaint4.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2b8d474b8cbcf653f612a9a3dc0cc9023cba9cf
Binary files /dev/null and b/example/sleeping_cat_inpaint4.png differ
diff --git a/example/sleeping_cat_inpaint5.png b/example/sleeping_cat_inpaint5.png
new file mode 100644
index 0000000000000000000000000000000000000000..e345cc9200cd4383474312a6053b146cda1ab62f
--- /dev/null
+++ b/example/sleeping_cat_inpaint5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb2e8b479c62808cc657a322d0bcd2d4828ab5ab687c90c08c565fa7447aecb9
+size 1526816
diff --git a/example/sleeping_cat_inpaint6.png b/example/sleeping_cat_inpaint6.png
new file mode 100644
index 0000000000000000000000000000000000000000..e37d2f8ab163bfa84a2a6e0e14394ec85ac46bea
--- /dev/null
+++ b/example/sleeping_cat_inpaint6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3de0513493326f434ac0a18fcad95371fbe4ade85e12208bbd6ab4e1b8cac097
+size 1500036
diff --git a/example/sleeping_cat_mask.png b/example/sleeping_cat_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..f1ba4d476fac7d486a1a4130e435e3b15f10f8c3
Binary files /dev/null and b/example/sleeping_cat_mask.png differ
diff --git a/example/test_cn.png b/example/test_cn.png
new file mode 100644
index 0000000000000000000000000000000000000000..deeb02f122b588716c7acc200441409493dc800a
Binary files /dev/null and b/example/test_cn.png differ
diff --git a/example/test_image.jpg b/example/test_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dba55ac86d3368caa93c0b382b6ec04bafad60af
Binary files /dev/null and b/example/test_image.jpg differ
diff --git a/example/test_image2.png b/example/test_image2.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab05e4d9a303f294fe4074801f3c0506fee936d3
Binary files /dev/null and b/example/test_image2.png differ
diff --git a/example/test_image3.png b/example/test_image3.png
new file mode 100644
index 0000000000000000000000000000000000000000..772edf145b27656d3d26f14ab08d7e803794777b
--- /dev/null
+++ b/example/test_image3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adf7d9df10c530ca9104fe5c94d796b70be4d8459a12a0662c4310171dce9f50
+size 1530331
diff --git a/example/test_mask2.png b/example/test_mask2.png
new file mode 100644
index 0000000000000000000000000000000000000000..91611341936466bd9142fa7bdafe7a2e40fd144c
Binary files /dev/null and b/example/test_mask2.png differ
diff --git a/example/test_mask3.png b/example/test_mask3.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b30cb0077406a1face8918bfa7976e93e5325cc
Binary files /dev/null and b/example/test_mask3.png differ
diff --git a/example/test_mask4.png b/example/test_mask4.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef0c7811848111beded51ab4e6cd12c3a0332764
Binary files /dev/null and b/example/test_mask4.png differ
diff --git a/model_patch.py b/model_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe55e167c7f717662983a3ff38bf13977bf7fac
--- /dev/null
+++ b/model_patch.py
@@ -0,0 +1,144 @@
+import torch
+import comfy
+
+
+# Check and add 'model_patch' to model.model_options['transformer_options']
+def add_model_patch_option(model):
+ if 'transformer_options' not in model.model_options:
+ model.model_options['transformer_options'] = {}
+ to = model.model_options['transformer_options']
+ if "model_patch" not in to:
+ to["model_patch"] = {}
+ return to
+
+
+# Patch model with model_function_wrapper
+def patch_model_function_wrapper(model, forward_patch, remove=False):
+
+ def brushnet_model_function_wrapper(apply_model_method, options_dict):
+ to = options_dict['c']['transformer_options']
+
+ control = None
+ if 'control' in options_dict['c']:
+ control = options_dict['c']['control']
+
+ x = options_dict['input']
+ timestep = options_dict['timestep']
+
+ # check if there are patches to execute
+ if 'model_patch' not in to or 'forward' not in to['model_patch']:
+ return apply_model_method(x, timestep, **options_dict['c'])
+
+ mp = to['model_patch']
+ unet = mp['unet']
+
+
+
+ #print(model.get_model_object("model_sampling").sigmas, len(model.get_model_object("model_sampling").sigmas))
+ #print(mp['all_sigmas'], len(mp['all_sigmas']))
+
+
+ all_sigmas = mp['all_sigmas']
+ sigma = to['sigmas'][0].item()
+ total_steps = all_sigmas.shape[0] - 1
+ step = torch.argmin((all_sigmas - sigma).abs()).item()
+
+ mp['step'] = step
+ mp['total_steps'] = total_steps
+
+ # comfy.model_base.apply_model
+ xc = model.model.model_sampling.calculate_input(timestep, x)
+ if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
+ xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
+ t = model.model.model_sampling.timestep(timestep).float()
+ # execute all patches
+ for method in mp['forward']:
+ method(unet, xc, t, to, control)
+
+ return apply_model_method(x, timestep, **options_dict['c'])
+
+ if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
+ print('BrushNet is going to replace existing model_function_wrapper:', model.model_options["model_function_wrapper"])
+ model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
+
+ to = add_model_patch_option(model)
+ mp = to['model_patch']
+
+ if isinstance(model.model.model_config, comfy.supported_models.SD15):
+ mp['SDXL'] = False
+ elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
+ mp['SDXL'] = True
+ else:
+ print('Base model type: ', type(model.model.model_config))
+ raise Exception("Unsupported model type: ", type(model.model.model_config))
+
+ if 'forward' not in mp:
+ mp['forward'] = []
+
+ if remove:
+ if forward_patch in mp['forward']:
+ mp['forward'].remove(forward_patch)
+ else:
+ mp['forward'].append(forward_patch)
+
+ mp['unet'] = model.model.diffusion_model
+ mp['step'] = 0
+ mp['total_steps'] = 1
+
+ # apply patches to code
+ if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
+ comfy.samplers.original_sample = comfy.samplers.sample
+ comfy.samplers.sample = modified_sample
+
+ if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
+ 'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
+ comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
+ comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
+
+
+# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
+# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
+def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
+ latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+ '''
+ Modified by BrushNet nodes
+ '''
+ cfg_guider = comfy.samplers.CFGGuider(model)
+ cfg_guider.set_conds(positive, negative)
+ cfg_guider.set_cfg(cfg)
+
+ ### Modified part ######################################################################
+ #
+ to = add_model_patch_option(model)
+ to['model_patch']['all_sigmas'] = sigmas
+ #
+ #sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
+ #sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
+ #
+ #
+ #if math.isclose(cfg, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
+ # to['model_patch']['free_guidance'] = False
+ #else:
+ # to['model_patch']['free_guidance'] = True
+ #
+ #######################################################################################
+
+ return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+
+
+# To use Controlnet with RAUNet it is much easier to modify apply_control a little
+def modified_apply_control(h, control, name):
+ '''
+ Modified by BrushNet nodes
+ '''
+ if control is not None and name in control and len(control[name]) > 0:
+ ctrl = control[name].pop()
+ if ctrl is not None:
+ if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
+ ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(h.dtype).to(h.device)
+ try:
+ h += ctrl
+ except:
+ print.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
+ return h
+
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..32b86a16c239a0017cb4a8deb1126f9ea406ba0a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,15 @@
+[project]
+name = "comfyui-brushnet"
+description = "These are custom nodes for ComfyUI native implementation of BrushNet (inpaint), PowerPaint (inpaint, object removal) and HiDiffusion (higher resolution for SD15 and SDXL)"
+version = "1.0.1"
+license = { text = "Apache License 2.0" }
+dependencies = ["diffusers>=0.27.0", "accelerate>=0.29.0", "peft>=0.7.0"]
+
+[project.urls]
+Repository = "https://github.com/nullquant/ComfyUI-BrushNet"
+# Used by Comfy Registry https://comfyregistry.org
+
+[tool.comfy]
+PublisherId = "nullquant"
+DisplayName = "ComfyUI-BrushNet"
+Icon = ""
diff --git a/raunet_nodes.py b/raunet_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c1b0d66708eded76997bd25a1f693922ad49bad
--- /dev/null
+++ b/raunet_nodes.py
@@ -0,0 +1,158 @@
+import torch.nn.functional as F
+import comfy
+
+from .model_patch import add_model_patch_option, patch_model_function_wrapper
+
+
+
+class RAUNet:
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {
+ "model": ("MODEL",),
+ "du_start": ("INT", {"default": 0, "min": 0, "max": 10000}),
+ "du_end": ("INT", {"default": 4, "min": 0, "max": 10000}),
+ "xa_start": ("INT", {"default": 4, "min": 0, "max": 10000}),
+ "xa_end": ("INT", {"default": 10, "min": 0, "max": 10000}),
+ },
+ }
+
+ CATEGORY = "inpaint"
+ RETURN_TYPES = ("MODEL",)
+ RETURN_NAMES = ("model",)
+
+ FUNCTION = "model_update"
+
+ def model_update(self, model, du_start, du_end, xa_start, xa_end):
+
+ model = model.clone()
+
+ add_raunet_patch(model,
+ du_start,
+ du_end,
+ xa_start,
+ xa_end)
+
+ return (model,)
+
+
+# This is main patch function
+def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
+
+ def raunet_forward(model, x, timesteps, transformer_options, control):
+ if 'model_patch' not in transformer_options:
+ print("RAUNet: 'model_patch' not in transformer_options, skip")
+ return
+
+ mp = transformer_options['model_patch']
+ is_SDXL = mp['SDXL']
+
+ if is_SDXL and type(model.input_blocks[6][0]) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
+ print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
+ return
+
+ if not is_SDXL and type(model.input_blocks[3][0]) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
+ print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
+ return
+
+ if 'raunet' not in mp:
+ print('RAUNet: "raunet" not in model_patch options, skip')
+ return
+
+ if is_SDXL:
+ block = model.input_blocks[6][0]
+ else:
+ block = model.input_blocks[3][0]
+
+ total_steps = mp['total_steps']
+ step = mp['step']
+
+ ro = mp['raunet']
+ du_start = ro['du_start']
+ du_end = ro['du_end']
+
+ if step >= du_start and step < du_end:
+ block.op.stride = (4, 4)
+ block.op.padding = (2, 2)
+ block.op.dilation = (2, 2)
+ else:
+ block.op.stride = (2, 2)
+ block.op.padding = (1, 1)
+ block.op.dilation = (1, 1)
+
+ patch_model_function_wrapper(model, raunet_forward)
+ model.set_model_input_block_patch(in_xattn_patch)
+ model.set_model_output_block_patch(out_xattn_patch)
+
+ to = add_model_patch_option(model)
+ mp = to['model_patch']
+ if 'raunet' not in mp:
+ mp['raunet'] = {}
+ ro = mp['raunet']
+
+ ro['du_start'] = du_start
+ ro['du_end'] = du_end
+ ro['xa_start'] = xa_start
+ ro['xa_end'] = xa_end
+
+
+def in_xattn_patch(h, transformer_options):
+ # both SDXL and SD15 = (input,4)
+ if transformer_options["block"] != ("input", 4):
+ # wrong block
+ return h
+ if 'model_patch' not in transformer_options:
+ print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
+ return h
+ mp = transformer_options['model_patch']
+ if 'raunet' not in mp:
+ print("RAUNet (i-x-p): 'raunet' not in model_patch options")
+ return h
+
+ step = mp['step']
+ ro = mp['raunet']
+ xa_start = ro['xa_start']
+ xa_end = ro['xa_end']
+
+ if step < xa_start or step >= xa_end:
+ return h
+ h = F.avg_pool2d(h, kernel_size=(2,2))
+ return h
+
+
+def out_xattn_patch(h, hsp, transformer_options):
+ if 'model_patch' not in transformer_options:
+ print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
+ return h, hsp
+ mp = transformer_options['model_patch']
+ if 'raunet' not in mp:
+ print("RAUNet (o-x-p): 'raunet' not in model_patch options")
+ return h
+
+ step = mp['step']
+ is_SDXL = mp['SDXL']
+ ro = mp['raunet']
+ xa_start = ro['xa_start']
+ xa_end = ro['xa_end']
+
+ if is_SDXL:
+ if transformer_options["block"] != ("output", 5):
+ # wrong block
+ return h, hsp
+ else:
+ if transformer_options["block"] != ("output", 8):
+ # wrong block
+ return h, hsp
+
+ if step < xa_start or step >= xa_end:
+ return h, hsp
+ #error in hidiffusion codebase, size * 2 for particular sizes only
+ #re_size = (int(h.shape[-2] * 2), int(h.shape[-1] * 2))
+ re_size = (hsp.shape[-2], hsp.shape[-1])
+ h = F.interpolate(h, size=re_size, mode='bicubic')
+
+ return h, hsp
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b74f0e3520b6644cbab4c8b223bcb6b8f484b04b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+diffusers>=0.27.0
+accelerate>=0.29.0,<0.32.0
+peft>=0.7.0