diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..1fdd488125124c37eb44f7cf298a53b9de04690a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,52 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +release-doc/asset/github_video.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2338-51.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2340-40.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2349-09.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2350-12.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2353-51.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2355-54.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-01-2357-39.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-02-0000-23.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-02-0002-02.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-1935-28.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-1936-51.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-1937-52.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-1939-28.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-1951-55.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2007-38.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2020-44.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2024-00.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2024-01.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2026-48.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2026-50.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2037-28.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2042-05.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2047-11.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2047-13.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-05-2050-26.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0124-52.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0136-07.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0143-46.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0146-41.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0146-45.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0149-29.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0152-29.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0153-19.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0154-20.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0155-38.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0156-15.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0156-34.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0157-09.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0157-52.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0159-25.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0200-31.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0200-33.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0202-12.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0202-13.png filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0215-08.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0217-26.gif 
filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0219-21.gif filter=lfs diff=lfs merge=lfs -text +results/2023-12-08-0223-15.gif filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0fb65848bbeaf8d226502a4377fc09fc47fcf201 --- /dev/null +++ b/LICENSE @@ -0,0 +1,218 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +======================================================================= +Apache DragDiffusion Subcomponents: + +The Apache DragDiffusion project contains subcomponents with separate copyright +notices and license terms. Your use of the source code for the these +subcomponents is subject to the terms and conditions of the following +licenses. + +======================================================================== +Apache 2.0 licenses +======================================================================== + +The following components are provided under the Apache License. See project link for details. +The text of each license is the standard Apache 2.0 license. + + files from lora: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py apache 2.0 \ No newline at end of file diff --git a/README.md b/README.md index 7bc3eed3eaf4d8d92b66b1b81b6ebc1f1cd62acb..8aa3b06528222145625205f3b7ed60fe8874b15d 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,152 @@ --- title: DragDiffusion -emoji: 🦀 -colorFrom: indigo -colorTo: purple +app_file: drag_ui.py sdk: gradio -sdk_version: 4.8.0 -app_file: app.py -pinned: false +sdk_version: 3.41.1 --- +

+# DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing
+
+Yujun Shi · Chuhui Xue · Jiachun Pan · Wenqing Zhang · Vincent Y. F. Tan · Song Bai
+
+arXiv · page · Twitter
+

+## Disclaimer
+This is a research project, NOT a commercial product.
+
+## News and Update
+* [Sept 3rd] v0.1.0 Release.
+  * Enable **dragging diffusion-generated images.**
+  * Introduce a new guidance mechanism that **greatly improves the quality of dragging results** (inspired by [MasaCtrl](https://ljzycmd.github.io/projects/MasaCtrl/)).
+  * Enable dragging images with arbitrary aspect ratios.
+  * Add support for DPM++ Solver (generated images).
+* [July 18th] v0.0.1 Release.
+  * Integrate LoRA training into the user interface. **No need to use a training script; everything can be conveniently done in the UI!**
+  * Optimize the user interface layout.
+  * Enable using a better VAE for eyes and faces (see [this](https://stable-diffusion-art.com/how-to-use-vae/)).
+* [July 8th] v0.0.0 Release.
+  * Implement the basic functionality of DragDiffusion.
+
+## Installation
+
+It is recommended to run our code on an NVIDIA GPU with a Linux system. We have not yet tested other configurations. Currently, our method requires around 14 GB of GPU memory to run. We will continue to optimize memory efficiency.
+
+To install the required libraries, simply run the following command:
+```
+conda env create -f environment.yaml
+conda activate dragdiff
+```
+
+## Run DragDiffusion
+To start, run the following in the command line to launch the gradio user interface:
+```
+python3 drag_ui.py
+```
+
+You may check our [GIF above](https://github.com/Yujun-Shi/DragDiffusion/blob/main/release-doc/asset/github_video.gif), which demonstrates the usage of the UI in a step-by-step manner.
+
+Basically, it consists of the following steps:
+
+#### Step 1: train a LoRA
+1) Drop your input image into the left-most box.
+2) Input a prompt describing the image in the "prompt" field.
+3) Click the "Train LoRA" button to train a LoRA given the input image.
+
+#### Step 2: do "drag" editing
+1) Draw a mask in the left-most box to specify the editable areas.
+2) Click handle and target points in the middle box. You may also reset all points by clicking "Undo point".
+3) Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
+
+
+## Explanation for parameters in the user interface:
+#### General Parameters
+|Parameter|Explanation|
+|-----|------|
+|prompt|The prompt describing the user input image (this will be used to train the LoRA and conduct "drag" editing).|
+|lora_path|The directory where the trained LoRA will be saved.|
+
+
+#### Algorithm Parameters
+These parameters are collapsed by default, as we normally do not have to tune them. Here are the explanations (a short sketch of how the drag parameters enter the drag optimization is given at the end of this README):
+* Base Model Config
+
+|Parameter|Explanation|
+|-----|------|
+|Diffusion Model Path|The path to the diffusion model. By default we use "runwayml/stable-diffusion-v1-5". We will add support for more models in the future.|
+|VAE Choice|The choice of VAE. There are currently two options: "default", which uses the original VAE, and "stabilityai/sd-vae-ft-mse", which can improve results on images with human eyes and faces (see [explanation](https://stable-diffusion-art.com/how-to-use-vae/)).|
+
+* Drag Parameters
+
+|Parameter|Explanation|
+|-----|------|
+|n_pix_step|Maximum number of motion-supervision steps. **Increase this if handle points have not been "dragged" to the desired positions.**|
+|lam|The regularization coefficient that keeps the unmasked region unchanged. Increase this value if the unmasked region has changed more than desired (no need to tune in most cases).|
+|n_actual_inference_step|Number of DDIM inversion steps performed (no need to tune in most cases).|
+
+* LoRA Parameters
+
+|Parameter|Explanation|
+|-----|------|
+|LoRA training steps|Number of LoRA training steps (no need to tune in most cases).|
+|LoRA learning rate|Learning rate of the LoRA (no need to tune in most cases).|
+|LoRA rank|Rank of the LoRA (no need to tune in most cases).|
+
+
+## License
+Code related to the DragDiffusion algorithm is under the Apache 2.0 license.
+
+
+## BibTeX
+```bibtex
+@article{shi2023dragdiffusion,
+  title={DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing},
+  author={Shi, Yujun and Xue, Chuhui and Pan, Jiachun and Zhang, Wenqing and Tan, Vincent YF and Bai, Song},
+  journal={arXiv preprint arXiv:2306.14435},
+  year={2023}
+}
+```
+
+## TODO
+- [x] Upload trained LoRAs of our examples
+- [x] Integrate the LoRA training function into the user interface.
+- [ ] Support using more diffusion models
+- [ ] Support using LoRAs downloaded online
+
+## Contact
+For any questions on this project, please contact [Yujun](https://yujun-shi.github.io/) (shi.yujun@u.nus.edu).
+
+## Acknowledgement
+This work is inspired by the amazing [DragGAN](https://vcai.mpi-inf.mpg.de/projects/DragGAN/). The LoRA training code is modified from an [example](https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py) of diffusers. Image samples are collected from [unsplash](https://unsplash.com/), [pexels](https://www.pexels.com/zh-cn/), and [pixabay](https://pixabay.com/). Finally, a huge shout-out to all the amazing open-source diffusion models and libraries.
+
+## Related Links
+* [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
+* [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/)
+* [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/)
+* [DragonDiffusion: Enabling Drag-style Manipulation on Diffusion Models](https://mc-e.github.io/project/DragonDiffusion/)
+* [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://lin-chen.site/projects/freedrag/)
+
+
+## Common Issues and Solutions
+1) For users struggling to load models from Hugging Face due to network constraints: 1) follow this [link](https://zhuanlan.zhihu.com/p/475260268) and download the model into the directory "local\_pretrained\_models"; 2) run "drag\_ui.py" and select the directory of your pretrained model in "Algorithm Parameters -> Base Model Config -> Diffusion Model Path".
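If you want to sanity-check such a locally downloaded checkpoint outside the UI, the snippet below is a minimal sketch, not part of this repository; the directory `local_pretrained_models/stable-diffusion-v1-5` is a hypothetical example of where you might have placed a diffusers-format checkpoint.

```python
# Minimal sketch (illustration only): load a locally downloaded diffusers-format
# checkpoint instead of fetching it from the Hugging Face Hub.
import torch
from diffusers import StableDiffusionPipeline

local_model_dir = "local_pretrained_models/stable-diffusion-v1-5"  # hypothetical path

pipe = StableDiffusionPipeline.from_pretrained(
    local_model_dir,         # a local directory works the same way as a Hub model id
    torch_dtype=torch.float16,
    local_files_only=True,   # never try to reach the Hub
)
pipe = pipe.to("cuda")       # requires an NVIDIA GPU, as recommended above
```

If this loads without errors, pointing the "Diffusion Model Path" dropdown at the same directory should work as well.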
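To make the Drag Parameters above more concrete, the following is a highly simplified sketch of point-based drag optimization in the spirit of DragGAN/DragDiffusion motion supervision. It is an illustration only, not the code used by this repository (see `drag_pipeline.py` and the `utils` modules for the actual implementation): the feature extractor `feature_fn`, single-pixel supervision, and the omission of point tracking are all simplifying assumptions.

```python
# Toy drag optimization sketch (illustration only, not this repository's code).
# feature_fn(latent): a differentiable UNet feature map of shape (1, C, H, W)
# handle_points / target_points: integer (row, col) coordinates in that feature map
# mask: (1, 1, H, W) tensor at latent resolution, 1 inside the editable region
import torch
import torch.nn.functional as F

def drag_step_sketch(latent, latent_orig, feature_fn, handle_points, target_points,
                     mask, lam=0.1, latent_lr=0.01, n_pix_step=40):
    latent = latent.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([latent], lr=latent_lr)
    for _ in range(n_pix_step):
        feat = feature_fn(latent)
        loss = torch.zeros((), device=latent.device)
        for (hy, hx), (ty, tx) in zip(handle_points, target_points):
            d = torch.tensor([ty - hy, tx - hx], dtype=torch.float32)
            d = d / (d.norm() + 1e-8)                  # unit direction: handle -> target
            src = feat[:, :, hy, hx].detach()          # stop-gradient on the source feature
            dst = feat[:, :, int(round(hy + d[0].item())), int(round(hx + d[1].item()))]
            loss = loss + F.l1_loss(dst, src)          # motion supervision term
        # regularization: keep the unmasked region close to the original latent
        loss = loss + lam * ((latent - latent_orig) * (1.0 - mask)).abs().mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return latent.detach()
```

Roughly speaking, `n_pix_step` bounds the number of such optimization steps, `lam` weights the term that keeps the unmasked region close to the original latent, and the latent learning rate is the optimizer step size on the DDIM-inverted latent; the actual implementation additionally supervises a small feature patch around each handle point and re-localizes handle points by point tracking after each step.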
+ -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/__pycache__/drag_pipeline.cpython-38.pyc b/__pycache__/drag_pipeline.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2974ebc229347bfca84dfd3f1b7fc1c2f560457a Binary files /dev/null and b/__pycache__/drag_pipeline.cpython-38.pyc differ diff --git a/drag_pipeline.py b/drag_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..05f9a1d18c848af62ab49a5864486806eda562aa --- /dev/null +++ b/drag_pipeline.py @@ -0,0 +1,493 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ************************************************************************* + +import torch +import numpy as np + +import torch.nn.functional as F +from tqdm import tqdm +from PIL import Image +from typing import Any, Dict, List, Optional, Tuple, Union + +from diffusers import StableDiffusionPipeline + +# override unet forward +# The only difference from diffusers: +# return intermediate UNet features of all UpSample blocks +def override_forward(self): + + def forward( + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_intermediates: bool = False, + last_up_block_idx: int = None, + ): + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + emb = emb + aug_emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None: + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. 
up + # only difference from diffusers: + # save the intermediate features of unet upsample blocks + # the 0-th element is the mid-block output + all_intermediate_features = [sample] + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + all_intermediate_features.append(sample) + # return early to save computation time if needed + if last_up_block_idx is not None and i == last_up_block_idx: + return all_intermediate_features + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # only difference from diffusers, return intermediate results + if return_intermediates: + return sample, all_intermediate_features + else: + return sample + + return forward + + +class DragPipeline(StableDiffusionPipeline): + + # must call this function when initialize + def modify_unet_forward(self): + self.unet.forward = override_forward(self.unet) + + def inv_step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + eta=0., + verbose=False + ): + """ + Inverse sampling for DDIM Inversion + """ + if verbose: + print("timestep: ", timestep) + next_step = timestep + timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output + x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir + return x_next, pred_x0 + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + ): + """ + predict the sample of the next step in the denoise process. 
+ """ + prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output + x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir + return x_prev, pred_x0 + + @torch.no_grad() + def image2latent(self, image): + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + if type(image) is Image: + image = np.array(image) + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE) + # input image density range [-1, 1] + latents = self.vae.encode(image)['latent_dist'].mean + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def latent2image(self, latents, return_type='np'): + latents = 1 / 0.18215 * latents.detach() + image = self.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = (image * 255).astype(np.uint8) + elif return_type == "pt": + image = (image / 2 + 0.5).clamp(0, 1) + + return image + + def latent2image_grad(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents)['sample'] + + return image # range [-1, 1] + + @torch.no_grad() + def get_text_embeddings(self, prompt): + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.cuda())[0] + return text_embeddings + + # get all intermediate features and then do bilinear interpolation + # return features in the layer_idx list + def forward_unet_features(self, z, t, encoder_hidden_states, layer_idx=[0], interp_res_h=256, interp_res_w=256): + unet_output, all_intermediate_features = self.unet( + z, + t, + encoder_hidden_states=encoder_hidden_states, + return_intermediates=True + ) + + all_return_features = [] + for idx in layer_idx: + feat = all_intermediate_features[idx] + feat = F.interpolate(feat, (interp_res_h, interp_res_w), mode='bilinear') + all_return_features.append(feat) + return_features = torch.cat(all_return_features, dim=1) + return unet_output, return_features + + @torch.no_grad() + def __call__( + self, + prompt, + prompt_embeds=None, # whether text embedding is directly provided. 
+ batch_size=1, + height=512, + width=512, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=7.5, + latents=None, + unconditioning=None, + neg_prompt=None, + return_intermediates=False, + **kwds): + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + if prompt_embeds is None: + if isinstance(prompt, list): + batch_size = len(prompt) + elif isinstance(prompt, str): + if batch_size > 1: + prompt = [prompt] * batch_size + + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] + else: + batch_size = prompt_embeds.shape[0] + text_embeddings = prompt_embeds + print("input text embeddings :", text_embeddings.shape) + + # define initial latents if not predefined + if latents is None: + latents_shape = (batch_size, self.unet.in_channels, height//8, width//8) + latents = torch.randn(latents_shape, device=DEVICE, dtype=self.vae.dtype) + + # unconditional embedding for classifier free guidance + if guidance_scale > 1.: + if neg_prompt: + uc_text = neg_prompt + else: + uc_text = "" + unconditional_input = self.tokenizer( + [uc_text] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] + text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0) + + print("latents shape: ", latents.shape) + # iterative sampling + self.scheduler.set_timesteps(num_inference_steps) + # print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + latents_list = [latents] + for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")): + if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps: + continue + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + if unconditioning is not None and isinstance(unconditioning, list): + _, text_embeddings = text_embeddings.chunk(2) + text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings]) + # predict the noise + noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings) + if guidance_scale > 1.0: + noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t -> x_t-1 + # YUJUN: right now, the only difference between step here and step in scheduler + # is that scheduler version would clamp pred_x0 between [-1,1] + # don't know if that's gonna have huge impact + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents_list.append(latents) + + image = self.latent2image(latents, return_type="pt") + if return_intermediates: + return image, latents_list + return image + + @torch.no_grad() + def invert( + self, + image: torch.Tensor, + prompt, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=7.5, + eta=0.0, + return_intermediates=False, + **kwds): + """ + invert a real image into noise map with determinisc DDIM inversion + """ + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + batch_size = image.shape[0] + if isinstance(prompt, list): + if batch_size == 1: + image = image.expand(len(prompt), -1, -1, -1) + elif isinstance(prompt, str): + if batch_size > 1: + 
prompt = [prompt] * batch_size + + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] + print("input text embeddings :", text_embeddings.shape) + # define initial latents + latents = self.image2latent(image) + + # unconditional embedding for classifier free guidance + if guidance_scale > 1.: + max_length = text_input.input_ids.shape[-1] + unconditional_input = self.tokenizer( + [""] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] + text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0) + + print("latents shape: ", latents.shape) + # interative sampling + self.scheduler.set_timesteps(num_inference_steps) + print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + # print("attributes: ", self.scheduler.__dict__) + latents_list = [latents] + pred_x0_list = [latents] + for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): + if num_actual_inference_steps is not None and i >= num_actual_inference_steps: + continue + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + + # predict the noise + noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings) + if guidance_scale > 1.: + noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t-1 -> x_t + latents, pred_x0 = self.inv_step(noise_pred, t, latents) + latents_list.append(latents) + pred_x0_list.append(pred_x0) + + if return_intermediates: + # return the intermediate laters during inversion + # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list] + return latents, latents_list + return latents diff --git a/drag_ui.py b/drag_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..de9cf24f9f304665ac5581b07ff6e0e4a18b52c1 --- /dev/null +++ b/drag_ui.py @@ -0,0 +1,335 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ************************************************************************* + +import os +import gradio as gr + +from utils.ui_utils import get_points, undo_points +from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag +from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen + +LENGTH=480 # length of the square area displaying/editing images + +with gr.Blocks() as demo: + # layout definition + with gr.Row(): + gr.Markdown(""" + # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435) + """) + + # UI components for editing real images + with gr.Tab(label="Editing Real Image"): + mask = gr.State(value=None) # store mask + selected_points = gr.State([]) # store points + original_image = gr.State(value=None) # store original input image + with gr.Row(): + with gr.Column(): + gr.Markdown("""

Draw Mask

""") + canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask", + show_label=True, height=LENGTH, width=LENGTH) # for mask painting + train_lora_button = gr.Button("Train LoRA") + with gr.Column(): + gr.Markdown("""

Click Points

""") + input_image = gr.Image(type="numpy", label="Click Points", + show_label=True, height=LENGTH, width=LENGTH) # for points clicking + undo_button = gr.Button("Undo point") + with gr.Column(): + gr.Markdown("""

Editing Results

""") + output_image = gr.Image(type="numpy", label="Editing Results", + show_label=True, height=LENGTH, width=LENGTH) + with gr.Row(): + run_button = gr.Button("Run") + clear_all_button = gr.Button("Clear All") + + # general parameters + with gr.Row(): + prompt = gr.Textbox(label="Prompt") + lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path") + lora_status_bar = gr.Textbox(label="display LoRA training status") + + # algorithm specific parameters + with gr.Tab("Drag Config"): + with gr.Row(): + n_pix_step = gr.Number( + value=40, + label="number of pixel steps", + info="Number of gradient descent (motion supervision) steps on latent.", + precision=0) + lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas") + # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0) + inversion_strength = gr.Slider(0, 1.0, + value=0.75, + label="inversion strength", + info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") + latent_lr = gr.Number(value=0.01, label="latent lr") + start_step = gr.Number(value=0, label="start_step", precision=0, visible=False) + start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False) + + with gr.Tab("Base Model Config"): + with gr.Row(): + local_models_dir = 'local_pretrained_models' + local_models_choice = \ + [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] + model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", + label="Diffusion Model Path", + choices=[ + "runwayml/stable-diffusion-v1-5", + ] + local_models_choice + ) + vae_path = gr.Dropdown(value="default", + label="VAE choice", + choices=["default", + "stabilityai/sd-vae-ft-mse"] + local_models_choice + ) + + with gr.Tab("LoRA Parameters"): + with gr.Row(): + lora_step = gr.Number(value=200, label="LoRA training steps", precision=0) + lora_lr = gr.Number(value=0.0002, label="LoRA learning rate") + lora_rank = gr.Number(value=16, label="LoRA rank", precision=0) + + # UI components for editing generated images + with gr.Tab(label="Editing Generated Image"): + mask_gen = gr.State(value=None) # store mask + selected_points_gen = gr.State([]) # store points + original_image_gen = gr.State(value=None) # store the diffusion-generated image + intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation + with gr.Row(): + with gr.Column(): + gr.Markdown("""

Draw Mask

""") + canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask", + show_label=True, height=LENGTH, width=LENGTH) # for mask painting + gen_img_button = gr.Button("Generate Image") + with gr.Column(): + gr.Markdown("""

Click Points

""") + input_image_gen = gr.Image(type="numpy", label="Click Points", + show_label=True, height=LENGTH, width=LENGTH) # for points clicking + undo_button_gen = gr.Button("Undo point") + with gr.Column(): + gr.Markdown("""

Editing Results

""") + output_image_gen = gr.Image(type="numpy", label="Editing Results", + show_label=True, height=LENGTH, width=LENGTH) + with gr.Row(): + run_button_gen = gr.Button("Run") + clear_all_button_gen = gr.Button("Clear All") + + # general parameters + with gr.Row(): + pos_prompt_gen = gr.Textbox(label="Positive Prompt") + neg_prompt_gen = gr.Textbox(label="Negative Prompt") + + with gr.Tab("Generation Config"): + with gr.Row(): + local_models_dir = 'local_pretrained_models' + local_models_choice = \ + [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] + model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", + label="Diffusion Model Path", + choices=[ + "runwayml/stable-diffusion-v1-5", + "gsdf/Counterfeit-V2.5", + "emilianJR/majicMIX_realistic", + "SG161222/Realistic_Vision_V2.0", + "stablediffusionapi/landscapesupermix", + "huangzhe0803/ArchitectureRealMix", + "stablediffusionapi/interiordesignsuperm" + ] + local_models_choice + ) + vae_path_gen = gr.Dropdown(value="default", + label="VAE choice", + choices=["default", + "stabilityai/sd-vae-ft-mse"] + local_models_choice + ) + lora_path_gen = gr.Textbox(value="", label="LoRA path") + gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0) + height = gr.Number(value=512, label="Height", precision=0) + width = gr.Number(value=512, label="Width", precision=0) + guidance_scale = gr.Number(value=7.5, label="CFG Scale") + scheduler_name_gen = gr.Dropdown( + value="DDIM", + label="Scheduler", + choices=[ + "DDIM", + "DPM++2M", + "DPM++2M_karras" + ] + ) + n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0) + + with gr.Tab(label="Drag Config"): + with gr.Row(): + n_pix_step_gen = gr.Number( + value=40, + label="Number of Pixel Steps", + info="Number of gradient descent (motion supervision) steps on latent.", + precision=0) + lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas") + # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0) + inversion_strength_gen = gr.Slider(0, 1.0, + value=0.75, + label="Inversion Strength", + info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") + latent_lr_gen = gr.Number(value=0.01, label="latent lr") + start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False) + start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False) + # Add a checkbox for users to select if they want a GIF of the process + with gr.Row(): + create_gif_checkbox = gr.Checkbox(label="create_GIF", value=False) + create_tracking_point_checkbox = gr.Checkbox(label="create_tracking_point", value=False) + gif_interval = gr.Number(value=10, label="interval_GIF", precision=0, info="The interval of the GIF, i.e. the number of steps between each frame of the GIF.") + gif_fps = gr.Number(value=1, label="fps_GIF", precision=0, info="The fps of the GIF, i.e. 
the number of frames per second of the GIF.") + + # event definition + # event for dragging user-input real image + canvas.edit( + store_img, + [canvas], + [original_image, selected_points, input_image, mask] + ) + input_image.select( + get_points, + [input_image, selected_points], + [input_image], + ) + undo_button.click( + undo_points, + [original_image, mask], + [input_image, selected_points] + ) + train_lora_button.click( + train_lora_interface, + [original_image, + prompt, + model_path, + vae_path, + lora_path, + lora_step, + lora_lr, + lora_rank], + [lora_status_bar] + ) + run_button.click( + run_drag, + [original_image, + input_image, + mask, + prompt, + selected_points, + inversion_strength, + lam, + latent_lr, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + create_gif_checkbox, + gif_interval, + ], + [output_image] + ) + clear_all_button.click( + clear_all, + [gr.Number(value=LENGTH, visible=False, precision=0)], + [canvas, + input_image, + output_image, + selected_points, + original_image, + mask] + ) + + # event for dragging generated image + canvas_gen.edit( + store_img_gen, + [canvas_gen], + [original_image_gen, selected_points_gen, input_image_gen, mask_gen] + ) + input_image_gen.select( + get_points, + [input_image_gen, selected_points_gen], + [input_image_gen], + ) + gen_img_button.click( + gen_img, + [ + gr.Number(value=LENGTH, visible=False, precision=0), + height, + width, + n_inference_step_gen, + scheduler_name_gen, + gen_seed, + guidance_scale, + pos_prompt_gen, + neg_prompt_gen, + model_path_gen, + vae_path_gen, + lora_path_gen, + ], + [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen] + ) + undo_button_gen.click( + undo_points, + [original_image_gen, mask_gen], + [input_image_gen, selected_points_gen] + ) + run_button_gen.click( + run_drag_gen, + [ + n_inference_step_gen, + scheduler_name_gen, + original_image_gen, # the original image generated by the diffusion model + input_image_gen, # image with clicking, masking, etc. 
+ intermediate_latents_gen, + guidance_scale, + mask_gen, + pos_prompt_gen, + neg_prompt_gen, + selected_points_gen, + inversion_strength_gen, + lam_gen, + latent_lr_gen, + n_pix_step_gen, + model_path_gen, + vae_path_gen, + lora_path_gen, + start_step_gen, + start_layer_gen, + create_gif_checkbox, + create_tracking_point_checkbox, + gif_interval, + gif_fps + ], + [output_image_gen] + ) + clear_all_button_gen.click( + clear_all_gen, + [gr.Number(value=LENGTH, visible=False, precision=0)], + [canvas_gen, + input_image_gen, + output_image_gen, + selected_points_gen, + original_image_gen, + mask_gen, + intermediate_latents_gen, + ] + ) + + +demo.queue().launch(share=True, debug=True) diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf7499ebef57de89816d7743208d030656a89c8e --- /dev/null +++ b/environment.yaml @@ -0,0 +1,48 @@ +name: dragdiff +channels: + - pytorch + - defaults + - nvidia + - conda-forge +dependencies: + - python=3.8.5 + - pip=22.3.1 + - cudatoolkit=11.7 + - pip: + - torch==2.0.0 + - torchvision==0.15.1 + - gradio==3.41.1 + - pydantic==2.0.2 + - albumentations==1.3.0 + - opencv-contrib-python==4.3.0.38 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.5.0 + - omegaconf==2.3.0 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.6.0 + - transformers==4.27.0 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.16.0 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.6.0 + - timm==0.6.12 + - addict==2.4.0 + - yapf==0.32.0 + - prettytable==3.6.0 + - safetensors==0.2.7 + - basicsr==1.4.2 + - accelerate==0.17.0 + - decord==0.6.0 + - diffusers==0.17.1 + - moviepy==1.0.3 + - opencv_python==4.7.0.68 + - Pillow==9.4.0 + - scikit_image==0.19.3 + - scipy==1.10.1 + - tensorboardX==2.6 + - tqdm==4.64.1 + - numpy==1.24.1 diff --git a/local_pretrained_models/dummy.txt b/local_pretrained_models/dummy.txt new file mode 100644 index 0000000000000000000000000000000000000000..73833234858155dee9886dc709ffe44a353420bd --- /dev/null +++ b/local_pretrained_models/dummy.txt @@ -0,0 +1 @@ +You may put your pretrained model here. 
\ No newline at end of file diff --git a/lora/lora_ckpt/dummy.txt b/lora/lora_ckpt/dummy.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8e3a0cbc0b6926ec44473ddeff9e55372f058d1 --- /dev/null +++ b/lora/lora_ckpt/dummy.txt @@ -0,0 +1 @@ +lora checkpoints will be saved in this folder diff --git a/lora/samples/cat_dog/andrew-s-ouo1hbizWwo-unsplash.jpg b/lora/samples/cat_dog/andrew-s-ouo1hbizWwo-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..35bc91d4b34512867b21bc93c6a67c9431bfc1ac Binary files /dev/null and b/lora/samples/cat_dog/andrew-s-ouo1hbizWwo-unsplash.jpg differ diff --git a/lora/samples/oilpaint1/catherine-kay-greenup-6rhUen8Wrao-unsplash.jpg b/lora/samples/oilpaint1/catherine-kay-greenup-6rhUen8Wrao-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a21fba3a910f800c69eeabeb7bf5c8e9cdde7378 Binary files /dev/null and b/lora/samples/oilpaint1/catherine-kay-greenup-6rhUen8Wrao-unsplash.jpg differ diff --git a/lora/samples/oilpaint2/birmingham-museums-trust-wKlHsooRVbg-unsplash.jpg b/lora/samples/oilpaint2/birmingham-museums-trust-wKlHsooRVbg-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e93477ccb573415fda6d1a2cae7710642a14c186 Binary files /dev/null and b/lora/samples/oilpaint2/birmingham-museums-trust-wKlHsooRVbg-unsplash.jpg differ diff --git a/lora/samples/prompts.txt b/lora/samples/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..30889a1151d27334db2a9f7fbe6ae4a7340c85cf --- /dev/null +++ b/lora/samples/prompts.txt @@ -0,0 +1,6 @@ +# prompts we used when editing the given samples: + +cat_dog: a photo of a cat and a dog +oilpaint1: an oil painting of a mountain besides a lake +oilpaint2: an oil painting of a mountain and forest +sculpture: a photo of a sculpture diff --git a/lora/samples/sculpture/evan-lee-EdAVNRvUVH4-unsplash.jpg b/lora/samples/sculpture/evan-lee-EdAVNRvUVH4-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8d02f828b5e320320356c19958daf8dbeb33347b Binary files /dev/null and b/lora/samples/sculpture/evan-lee-EdAVNRvUVH4-unsplash.jpg differ diff --git a/lora/train_dreambooth_lora.py b/lora/train_dreambooth_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..4a85376604a79995cc681170b090f22084f87368 --- /dev/null +++ b/lora/train_dreambooth_lora.py @@ -0,0 +1,1324 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. 
+# ************************************************************************* + +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnAddedKVProcessor, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +instance_prompt: {prompt} +tags: +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA DreamBooth - {repo_id} + +These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n +{img_str} + +LoRA for the text encoder was enabled: {train_text_encoder}. 
+""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. 
If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="rank of lora." + ) + + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. 
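+ A rough usage sketch (illustrative only, not exercised by this script; the sample path, prompt and model id follow lora/train_lora.sh): + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", use_fast=False + ) + dataset = DreamBoothDataset( + instance_data_root="lora/samples/sculpture", + instance_prompt="a photo of a sculpture", + tokenizer=tokenizer, + size=512, + ) + example = dataset[0] # dict with "instance_images", "instance_prompt_ids", "instance_attention_mask"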
+ """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num=None, + size=512, + center_crop=False, + encoder_hidden_states=None, + instance_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.instance_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] 
for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + try: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + except OSError: + # IF does not 
have a VAE so let's just set it to None + # We don't have to error out here + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + # We only train the additional adapter LoRA layers + if vae is not None: + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. 
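+ # (For the runwayml/stable-diffusion-v1-5 UNet used in lora/train_lora.sh, block_out_channels is (320, 640, 1280, 1280) and cross_attention_dim is 768; those are the values the `hidden_size` / `cross_attention_dim` lookups below resolve to.)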
+ # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Set correct lora layers + unet_lora_attn_procs = {} + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + unet_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_rank + ) + + unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, + # we first load a dummy pipeline with the text encoder and then do the monkey-patching. + text_encoder_lora_layers = None + if args.train_text_encoder: + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_proj.out_features, cross_attention_dim=None + ) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, text_encoder=text_encoder + ) + temp_pipeline._modify_text_encoder(text_lora_attn_procs) + text_encoder = temp_pipeline.text_encoder + del temp_pipeline + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. 
Either there are just the unet attn processor layers + # or there are the unet and text encoder attention layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + if args.train_text_encoder: + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + + for model in models: + state_dict = model.state_dict() + + if ( + text_encoder_lora_layers is not None + and text_encoder_keys is not None + and state_dict.keys() == text_encoder_keys + ): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_keys: + # unet + unet_lora_layers_to_save = state_dict + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + # Note we DON'T pass the unet and text encoder here on purpose + # so that we don't accidentally override the LoRA layers of + # unet_lora_layers and text_encoder_lora_layers which are stored in `models` + # with new torch.nn.Modules / weights. We simply use the pipeline class as + # an easy way to load the lora checkpoints + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) + + # load lora weights into models + models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) + if len(models) > 1: + models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + + # delete temporary pipeline and pop models + del temp_pipeline + for _ in range(len(models)): + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.instance_prompt is not None: + pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + else: + pre_computed_instance_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_instance_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. 
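+ # Note that only the small LoRA adapter containers (plus optimizer, dataloader and lr scheduler) are handed to `accelerator.prepare` below; the frozen unet/vae/text_encoder were already moved to the device above and stay outside DDP wrapping and optimizer state. A quick, optional sanity check is to compare parameter counts, e.g.: + # n_lora = sum(p.numel() for p in unet_lora_layers.parameters()) + # n_unet = sum(p.numel() for p in unet.parameters()) + # logger.info(f"trainable LoRA params: {n_lora} / frozen UNet params: {n_unet}")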
+ if args.train_text_encoder: + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + ) + else: + unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth-lora", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. 
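+ # The bar below counts optimizer steps (global_step up to max_train_steps), not dataloader batches, so with gradient_accumulation_steps > 1 it advances once per accumulated group of batches; with the lora/train_lora.sh settings (train_batch_size=1, gradient_accumulation_steps=1, max_train_steps=200) one tick corresponds to one training image.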
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist + model_input = model_input.sample() * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. 
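+ # i.e. loss = mse(model_pred, target) + prior_loss_weight * mse(model_pred_prior, target_prior); with the default --prior_loss_weight=1.0 the instance half and the class (prior) half of the concatenated batch contribute equally.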
+ loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), + revision=args.revision, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": validation_prompt_encoder_hidden_states, + "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + if args.validation_images is None: + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if 
accelerator.is_main_process: + unet = unet.to(torch.float32) + unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + + if text_encoder is not None: + text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/lora/train_lora.sh b/lora/train_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..7d091eab55888733978007e7370cd8b8cf246e2e --- /dev/null +++ b/lora/train_lora.sh @@ -0,0 +1,21 @@ +export SAMPLE_DIR="lora/samples/sculpture" +export OUTPUT_DIR="lora/lora_ckpt/sculpture_lora" + +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export LORA_RANK=16 + +accelerate launch lora/train_dreambooth_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$SAMPLE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of a sculpture" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --checkpointing_steps=100 \ + --learning_rate=2e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=200 \ + --lora_rank=$LORA_RANK \ + --seed="0" diff --git a/lora_tmp/pytorch_lora_weights.bin b/lora_tmp/pytorch_lora_weights.bin new file mode 100644 index 0000000000000000000000000000000000000000..00cb35287e5cafb52e4c4488ffc4af9d736cec07 --- /dev/null +++ b/lora_tmp/pytorch_lora_weights.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a91eca307a7df4b4af0d73f52f0cbf8ac8693da50388f28271fb33a4fcdd6df7 +size 12855259 diff --git a/release-doc/asset/accelerate_config.jpg b/release-doc/asset/accelerate_config.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24a10dc084a3a65e1e75063dab52438a0bbc6098 Binary files /dev/null and b/release-doc/asset/accelerate_config.jpg differ diff --git a/release-doc/asset/github_video.gif b/release-doc/asset/github_video.gif new file mode 100644 index 0000000000000000000000000000000000000000..160130db212affc1bf26442ed796d7dabec379b7 --- /dev/null +++ b/release-doc/asset/github_video.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d87b873576337e4066094050203b4d53d1aef728db7979b0f16a0ae2518ea705 +size 7622606 diff --git a/release-doc/licenses/LICENSE-lora.txt b/release-doc/licenses/LICENSE-lora.txt new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/release-doc/licenses/LICENSE-lora.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/results/2023-12-01-2318-20.png b/results/2023-12-01-2318-20.png new file mode 100644 index 0000000000000000000000000000000000000000..9044b777b8eaa219a185dc55afc06889646d247d Binary files /dev/null and b/results/2023-12-01-2318-20.png differ diff --git a/results/2023-12-01-2319-14.png b/results/2023-12-01-2319-14.png new file mode 100644 index 0000000000000000000000000000000000000000..e7988c042c36fd7e81379b235ec3d795c54a0b56 Binary files /dev/null and b/results/2023-12-01-2319-14.png differ diff --git a/results/2023-12-01-2320-47.png b/results/2023-12-01-2320-47.png new file mode 100644 index 0000000000000000000000000000000000000000..cd9204aa795ed2bfd799b665c1e9d5f2c6c75a25 Binary files /dev/null and b/results/2023-12-01-2320-47.png differ diff --git a/results/2023-12-01-2321-38.png b/results/2023-12-01-2321-38.png new file mode 100644 index 0000000000000000000000000000000000000000..8dfd631d8a32c134ce577b5eb6a42fb03c284bdc Binary files /dev/null and b/results/2023-12-01-2321-38.png differ diff --git a/results/2023-12-01-2322-25.png b/results/2023-12-01-2322-25.png new file mode 100644 index 0000000000000000000000000000000000000000..5d83f2dbbf82cd92d8dc5ba173dfedc18b3815db Binary files /dev/null and b/results/2023-12-01-2322-25.png differ diff --git a/results/2023-12-01-2324-23.png b/results/2023-12-01-2324-23.png new file mode 100644 index 0000000000000000000000000000000000000000..a4d56dc4a54fa51f05e9365b4cc3e1fb0bcf6b3b Binary files /dev/null and b/results/2023-12-01-2324-23.png differ diff --git a/results/2023-12-01-2326-06.png b/results/2023-12-01-2326-06.png new file mode 100644 index 0000000000000000000000000000000000000000..0912da60890af121174be3c57880633aace08413 Binary files /dev/null and b/results/2023-12-01-2326-06.png differ diff --git a/results/2023-12-01-2328-23.png b/results/2023-12-01-2328-23.png new file mode 100644 index 0000000000000000000000000000000000000000..130729c133a6fc051d06f5d4f2fde3b462b6acc3 Binary files /dev/null and b/results/2023-12-01-2328-23.png differ diff --git a/results/2023-12-01-2329-06.png b/results/2023-12-01-2329-06.png new file mode 100644 index 0000000000000000000000000000000000000000..c0fca42b16bd7c11c5235587a3ff6ff2c2bceb8e Binary files /dev/null and b/results/2023-12-01-2329-06.png differ diff --git a/results/2023-12-01-2330-14.png b/results/2023-12-01-2330-14.png new file mode 100644 index 0000000000000000000000000000000000000000..c54ef715319fd3d2ca870ca1fd52e1a1cbd7276b Binary files /dev/null and b/results/2023-12-01-2330-14.png differ diff --git a/results/2023-12-01-2331-09.png b/results/2023-12-01-2331-09.png new file mode 100644 index 0000000000000000000000000000000000000000..8a3ecfcc51a667ae73784e17c14ea31a66b1abe1 Binary files /dev/null and b/results/2023-12-01-2331-09.png differ diff --git a/results/2023-12-01-2331-41.png b/results/2023-12-01-2331-41.png new file mode 100644 index 0000000000000000000000000000000000000000..2a3d6467766dce78c6357080c5dc4f24ad7ce6b7 Binary files /dev/null and b/results/2023-12-01-2331-41.png differ diff --git a/results/2023-12-01-2332-17.png b/results/2023-12-01-2332-17.png new file mode 100644 index 0000000000000000000000000000000000000000..de6a767521d8bd08d823a31fce07a347a83f1b68 Binary files /dev/null and b/results/2023-12-01-2332-17.png differ diff --git a/results/2023-12-01-2336-40.png b/results/2023-12-01-2336-40.png new file mode 100644 index 0000000000000000000000000000000000000000..14869bd7e77e47475ae9bcf31ed93324b619c7fe Binary files /dev/null and 
b/results/2023-12-01-2336-40.png differ diff --git a/results/2023-12-01-2338-51.png b/results/2023-12-01-2338-51.png new file mode 100644 index 0000000000000000000000000000000000000000..5d28f693b942cd2c07a9254983d003dd48e3be66 --- /dev/null +++ b/results/2023-12-01-2338-51.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5af31b413c421167798e03ee03722bacaea27ad37bc9334cf8470ca219e26693 +size 1098265 diff --git a/results/2023-12-01-2340-40.png b/results/2023-12-01-2340-40.png new file mode 100644 index 0000000000000000000000000000000000000000..a6a04001a0f0616ed9eafbb97cda2dfbfaae8157 --- /dev/null +++ b/results/2023-12-01-2340-40.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb5a56e0eb3c48d3f158523678893c0ccba693cbdaecd049e5d34aea51845a60 +size 1105007 diff --git a/results/2023-12-01-2342-40.png b/results/2023-12-01-2342-40.png new file mode 100644 index 0000000000000000000000000000000000000000..628ccce127befdcceb08a1983979f82f2e58b96e Binary files /dev/null and b/results/2023-12-01-2342-40.png differ diff --git a/results/2023-12-01-2349-09.png b/results/2023-12-01-2349-09.png new file mode 100644 index 0000000000000000000000000000000000000000..e4729ba33758ace7729f1f6dff3a8ee200f32050 --- /dev/null +++ b/results/2023-12-01-2349-09.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12c80e3f8ab5f34f5bcd694692690d4ff9c95a5ddbdad898aeaef57bcc6afb3b +size 1126420 diff --git a/results/2023-12-01-2350-12.png b/results/2023-12-01-2350-12.png new file mode 100644 index 0000000000000000000000000000000000000000..28358f05345fbb972f71d90b02f73fba2f83219d --- /dev/null +++ b/results/2023-12-01-2350-12.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e753d9d6a79287979c0dd57ecc5393c854d550f853c6015283f6dfcd09c2e2a9 +size 1104724 diff --git a/results/2023-12-01-2353-51.png b/results/2023-12-01-2353-51.png new file mode 100644 index 0000000000000000000000000000000000000000..0583ac3d92b3203cb98a3f7117501310353c9b12 --- /dev/null +++ b/results/2023-12-01-2353-51.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eaec82d0bd8036687a7086ad5d21ad06716e78dfc2fc1a1b6344f87962919472 +size 1165031 diff --git a/results/2023-12-01-2355-54.png b/results/2023-12-01-2355-54.png new file mode 100644 index 0000000000000000000000000000000000000000..d480a3a8fa80edba44807ea9999bd6657b346235 --- /dev/null +++ b/results/2023-12-01-2355-54.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24426fc3aacea198e1877bde2d2606431150115cace81abb11378a285eb2be48 +size 1163829 diff --git a/results/2023-12-01-2357-39.png b/results/2023-12-01-2357-39.png new file mode 100644 index 0000000000000000000000000000000000000000..d6b267066af64563c4e8b7a9ea4c4e0fced9580f --- /dev/null +++ b/results/2023-12-01-2357-39.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e827e996b40e369f8ff852b4fdfdac7399431536b47b5416b3258c7251c0b353 +size 1000529 diff --git a/results/2023-12-02-0000-23.png b/results/2023-12-02-0000-23.png new file mode 100644 index 0000000000000000000000000000000000000000..0a64e4a40cb3302d5d1f1b93c09578e456d6f249 --- /dev/null +++ b/results/2023-12-02-0000-23.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed2fbfae67c236ff1f8439d93f0caefe401b546e5f8c405babbeedeb1d437d92 +size 1012812 diff --git a/results/2023-12-02-0002-02.png b/results/2023-12-02-0002-02.png new file mode 100644 index 
0000000000000000000000000000000000000000..a157f639aedf37f980b22fa011b469b0c2cf14e7 --- /dev/null +++ b/results/2023-12-02-0002-02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:864333336c74a2148ab2e77d3fffd23a61548544cf1186b1501914f43077dd14 +size 1012122 diff --git a/results/2023-12-02-0004-46.png b/results/2023-12-02-0004-46.png new file mode 100644 index 0000000000000000000000000000000000000000..aff887480b6705b4aabdf34b8b9d8d476ad55d7a Binary files /dev/null and b/results/2023-12-02-0004-46.png differ diff --git a/results/2023-12-05-1935-28.png b/results/2023-12-05-1935-28.png new file mode 100644 index 0000000000000000000000000000000000000000..c3644757f4461ca101eaad353ac14b24f873e21a --- /dev/null +++ b/results/2023-12-05-1935-28.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a9ce897a62299e552480abdb8f24fd83fbc6aec3a7a3e4a03d7c359f557e572 +size 1047071 diff --git a/results/2023-12-05-1936-51.png b/results/2023-12-05-1936-51.png new file mode 100644 index 0000000000000000000000000000000000000000..7a2a031feafd773101a1362599358288b54fac2a --- /dev/null +++ b/results/2023-12-05-1936-51.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee81144d27c60be9e1e0f1057173339f526d71eeac94f13db704300ace3f9f8a +size 1047850 diff --git a/results/2023-12-05-1937-52.png b/results/2023-12-05-1937-52.png new file mode 100644 index 0000000000000000000000000000000000000000..a8e5091f0c2a051bab0e2549efe10b3a559611ce --- /dev/null +++ b/results/2023-12-05-1937-52.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fc172ba5e6840f5dda2e18ae6d44e8216fc1799dc31b66e6111b8f3cfa9a2e2 +size 1046532 diff --git a/results/2023-12-05-1939-28.png b/results/2023-12-05-1939-28.png new file mode 100644 index 0000000000000000000000000000000000000000..253a4483e4a67ea8824ca596e2a9e1820f24e63e --- /dev/null +++ b/results/2023-12-05-1939-28.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88f16c0149568a9034ffec78834b180ee488eb26a8bbd2daa5e7199a5055c0d6 +size 1367435 diff --git a/results/2023-12-05-1944-37.png b/results/2023-12-05-1944-37.png new file mode 100644 index 0000000000000000000000000000000000000000..614d8426f216b89d1297c44fdb185ed641f90b72 Binary files /dev/null and b/results/2023-12-05-1944-37.png differ diff --git a/results/2023-12-05-1951-55.png b/results/2023-12-05-1951-55.png new file mode 100644 index 0000000000000000000000000000000000000000..83dc4176c5517eba95511f536568899be4e0e81e --- /dev/null +++ b/results/2023-12-05-1951-55.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aef4f14f2f37a75601e1846e1d1d42dfb6b2b78dd566f3aa5a4fda0278bd96b4 +size 1263498 diff --git a/results/2023-12-05-2007-38.png b/results/2023-12-05-2007-38.png new file mode 100644 index 0000000000000000000000000000000000000000..a0c956fcf3a81bfc60c2fc239d1e2271347434b9 --- /dev/null +++ b/results/2023-12-05-2007-38.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:508ed683cba6c3939ed829835aef09a10c8f003aebd96a6ab770c4382f03023e +size 1509430 diff --git a/results/2023-12-05-2017-02.png b/results/2023-12-05-2017-02.png new file mode 100644 index 0000000000000000000000000000000000000000..e6dc0169a72de92b5449646d6b7e3855839184a9 Binary files /dev/null and b/results/2023-12-05-2017-02.png differ diff --git a/results/2023-12-05-2020-44.png b/results/2023-12-05-2020-44.png new file mode 100644 index 
0000000000000000000000000000000000000000..01e3751b896c1f7a0e04cd8ad6c7ba5699000721 --- /dev/null +++ b/results/2023-12-05-2020-44.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd7dbdc4ef2e06d95d73750ed7fb99b579cc5e221f7a31fd622efb6f5d6bbca4 +size 1020964 diff --git a/results/2023-12-05-2024-00.gif b/results/2023-12-05-2024-00.gif new file mode 100644 index 0000000000000000000000000000000000000000..d51dcc95ff90546be8b7c70325416e0e18aa7885 --- /dev/null +++ b/results/2023-12-05-2024-00.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75cda608fe0a28679122a4fee8341610e2fa47cea0924c80cd38a0208907c32d +size 1128803 diff --git a/results/2023-12-05-2024-01.png b/results/2023-12-05-2024-01.png new file mode 100644 index 0000000000000000000000000000000000000000..9576e5bb4bfa18f2c7837904d822de76f2e8e167 --- /dev/null +++ b/results/2023-12-05-2024-01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32e01710b29d45979b98a55041bb4ea4065ed69e09a4e36a5fe42a1a098b0544 +size 1476746 diff --git a/results/2023-12-05-2026-48.gif b/results/2023-12-05-2026-48.gif new file mode 100644 index 0000000000000000000000000000000000000000..586652c178ab96d043cbf1f655ee6bcac0e99fb0 --- /dev/null +++ b/results/2023-12-05-2026-48.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0f4d6f950ed11e7ac1037f7c391603785a6ecac2939644238bfd0531e461d69 +size 2381302 diff --git a/results/2023-12-05-2026-50.png b/results/2023-12-05-2026-50.png new file mode 100644 index 0000000000000000000000000000000000000000..f6dcb422407a227eb961bf72d5fe5be7b603e677 --- /dev/null +++ b/results/2023-12-05-2026-50.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86712640ae1713cf7cf56490199f9230abc46c1279b17c257913be2fef174912 +size 1279358 diff --git a/results/2023-12-05-2028-45.gif b/results/2023-12-05-2028-45.gif new file mode 100644 index 0000000000000000000000000000000000000000..7a3f256bd388b124e71e02567ce264fc54aba18f Binary files /dev/null and b/results/2023-12-05-2028-45.gif differ diff --git a/results/2023-12-05-2028-45.png b/results/2023-12-05-2028-45.png new file mode 100644 index 0000000000000000000000000000000000000000..eb36d36f2dd260c723feed398a5b7c845c88e67a Binary files /dev/null and b/results/2023-12-05-2028-45.png differ diff --git a/results/2023-12-05-2035-02.gif b/results/2023-12-05-2035-02.gif new file mode 100644 index 0000000000000000000000000000000000000000..4b8c8be1450a1d1ef1decb58830b9cb36834bcb2 Binary files /dev/null and b/results/2023-12-05-2035-02.gif differ diff --git a/results/2023-12-05-2035-03.png b/results/2023-12-05-2035-03.png new file mode 100644 index 0000000000000000000000000000000000000000..80bedb90e1a1048fd7e9c682490c3f05313cc97b Binary files /dev/null and b/results/2023-12-05-2035-03.png differ diff --git a/results/2023-12-05-2037-28.gif b/results/2023-12-05-2037-28.gif new file mode 100644 index 0000000000000000000000000000000000000000..ac2cbe899fc415279bfcfc8bd1df71f2475a9912 --- /dev/null +++ b/results/2023-12-05-2037-28.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d90216f90c54628cb179c5e8f42fcaa0789011a561ca6691f516e24be351eff +size 1909854 diff --git a/results/2023-12-05-2037-29.png b/results/2023-12-05-2037-29.png new file mode 100644 index 0000000000000000000000000000000000000000..c37a628be49b3d95a7bdda417d4562c20af45320 Binary files /dev/null and b/results/2023-12-05-2037-29.png differ diff --git a/results/2023-12-05-2042-05.gif 
b/results/2023-12-05-2042-05.gif new file mode 100644 index 0000000000000000000000000000000000000000..b01a172d834b560c147a08417c407811e9543636 --- /dev/null +++ b/results/2023-12-05-2042-05.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2b8ded0d65da3ceef924fb89ca8a770f1c4bbad4e49ae77d21f92940c85db15 +size 3765519 diff --git a/results/2023-12-05-2042-08.png b/results/2023-12-05-2042-08.png new file mode 100644 index 0000000000000000000000000000000000000000..e2a2761e92cf28b57358d7bfabe6b43f2d2713d5 Binary files /dev/null and b/results/2023-12-05-2042-08.png differ diff --git a/results/2023-12-05-2047-11.gif b/results/2023-12-05-2047-11.gif new file mode 100644 index 0000000000000000000000000000000000000000..32697647cd41cdd7b6fa75a4b4f8a04b02e54cda --- /dev/null +++ b/results/2023-12-05-2047-11.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f1d6d16d074051fda1050823b880c5344777cd8b4eb3c2b9bdc5649a17ee916 +size 4209663 diff --git a/results/2023-12-05-2047-13.png b/results/2023-12-05-2047-13.png new file mode 100644 index 0000000000000000000000000000000000000000..b0ea2dc86489e09ab72042bbb6e5bb6f37bbfc6e --- /dev/null +++ b/results/2023-12-05-2047-13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d41009fe59a9cc3a8d80f53ad27cbf43859bb48b53eb430a9e11755e071a1864 +size 1000182 diff --git a/results/2023-12-05-2050-26.gif b/results/2023-12-05-2050-26.gif new file mode 100644 index 0000000000000000000000000000000000000000..fcd71def50ad72bf739daf1dd3daf1fc78f75653 --- /dev/null +++ b/results/2023-12-05-2050-26.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e26c90dbe9d6afc5256b28ae8834f4d856ddca30a12c95bf4e0ecf281ebfe5ec +size 3828925 diff --git a/results/2023-12-05-2050-29.png b/results/2023-12-05-2050-29.png new file mode 100644 index 0000000000000000000000000000000000000000..f2047f49992f426c6f47aeaedc3d0986441283a0 Binary files /dev/null and b/results/2023-12-05-2050-29.png differ diff --git a/results/2023-12-08-0113-27.gif b/results/2023-12-08-0113-27.gif new file mode 100644 index 0000000000000000000000000000000000000000..ac284b2108b2555e3405f2d984f90178d5f55757 Binary files /dev/null and b/results/2023-12-08-0113-27.gif differ diff --git a/results/2023-12-08-0116-17.gif b/results/2023-12-08-0116-17.gif new file mode 100644 index 0000000000000000000000000000000000000000..b7e1689d506b57869de33576f64e587ab7f15287 Binary files /dev/null and b/results/2023-12-08-0116-17.gif differ diff --git a/results/2023-12-08-0118-50.gif b/results/2023-12-08-0118-50.gif new file mode 100644 index 0000000000000000000000000000000000000000..0321bbecdd9d999d4de9316f42e79c00e4bd9944 Binary files /dev/null and b/results/2023-12-08-0118-50.gif differ diff --git a/results/2023-12-08-0124-50.gif b/results/2023-12-08-0124-50.gif new file mode 100644 index 0000000000000000000000000000000000000000..4e7ff97e7e4fdc5c4d5c0fdf09af83e32ef96256 Binary files /dev/null and b/results/2023-12-08-0124-50.gif differ diff --git a/results/2023-12-08-0124-51traking_points.gif b/results/2023-12-08-0124-51traking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..4e7ff97e7e4fdc5c4d5c0fdf09af83e32ef96256 Binary files /dev/null and b/results/2023-12-08-0124-51traking_points.gif differ diff --git a/results/2023-12-08-0124-52.png b/results/2023-12-08-0124-52.png new file mode 100644 index 0000000000000000000000000000000000000000..ad5009c661f2012671e81dcd822969ab64a66d7b --- /dev/null +++ 
b/results/2023-12-08-0124-52.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:477ab9a461ac6b84659a5d4397c90b8a6d176a5053cfe4bb61bddd57e3ca041f +size 1312616 diff --git a/results/2023-12-08-0131-22.gif b/results/2023-12-08-0131-22.gif new file mode 100644 index 0000000000000000000000000000000000000000..4dd51a40a0053de76c1251c8e6f8d557f6ab7179 Binary files /dev/null and b/results/2023-12-08-0131-22.gif differ diff --git a/results/2023-12-08-0136-06.gif b/results/2023-12-08-0136-06.gif new file mode 100644 index 0000000000000000000000000000000000000000..6f86d912bc6d0c5a74dcb5f18961e0ecaf2481bf Binary files /dev/null and b/results/2023-12-08-0136-06.gif differ diff --git a/results/2023-12-08-0136-07.png b/results/2023-12-08-0136-07.png new file mode 100644 index 0000000000000000000000000000000000000000..a0333c3524940588c9b95526bd0dd8564ada4680 --- /dev/null +++ b/results/2023-12-08-0136-07.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bd243578e951fb192e5c1da605489df52647bf71b40bc0b716a017fed15241e +size 1323291 diff --git a/results/2023-12-08-0136-07traking_points.gif b/results/2023-12-08-0136-07traking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..60daa33541da2005a4c039d2b0347a0e50a97814 Binary files /dev/null and b/results/2023-12-08-0136-07traking_points.gif differ diff --git a/results/2023-12-08-0138-56.gif b/results/2023-12-08-0138-56.gif new file mode 100644 index 0000000000000000000000000000000000000000..a7017bbfc9a0d789f2c2ef99152e153074c220e3 Binary files /dev/null and b/results/2023-12-08-0138-56.gif differ diff --git a/results/2023-12-08-0143-45.gif b/results/2023-12-08-0143-45.gif new file mode 100644 index 0000000000000000000000000000000000000000..8df7361592724a153d448030db04d797527df335 Binary files /dev/null and b/results/2023-12-08-0143-45.gif differ diff --git a/results/2023-12-08-0143-46.png b/results/2023-12-08-0143-46.png new file mode 100644 index 0000000000000000000000000000000000000000..960d9fedca4f4bf5cafbef9ce37cc30e9663eacd --- /dev/null +++ b/results/2023-12-08-0143-46.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b6d0b1d1b57a0a758778ce0a37467abe910d5fc0559aa22eacd1695924f9411 +size 1303041 diff --git a/results/2023-12-08-0143-46_tracking_points.gif b/results/2023-12-08-0143-46_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..4af9c30d6aa6473c058f8561eee75b73f8b39feb Binary files /dev/null and b/results/2023-12-08-0143-46_tracking_points.gif differ diff --git a/results/2023-12-08-0146-41.gif b/results/2023-12-08-0146-41.gif new file mode 100644 index 0000000000000000000000000000000000000000..636f17c5c3c1d815f217e86646fadb2c6b3498ef --- /dev/null +++ b/results/2023-12-08-0146-41.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:feb806618bcc45e393494e67a856f03150446042b8dfb7a7f19863247315348a +size 4057552 diff --git a/results/2023-12-08-0146-45.png b/results/2023-12-08-0146-45.png new file mode 100644 index 0000000000000000000000000000000000000000..fb2efcec17a30d7ca6ea09602437f93084808e40 --- /dev/null +++ b/results/2023-12-08-0146-45.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b963dd693a0437857a8adfaaeb6971f62cbbec04917baa5b10b9f0a46041a05d +size 1300456 diff --git a/results/2023-12-08-0146-45_tracking_points.gif b/results/2023-12-08-0146-45_tracking_points.gif new file mode 100644 index 
0000000000000000000000000000000000000000..c6545e775f9b0332d66cb399e0ab75a08d698eab Binary files /dev/null and b/results/2023-12-08-0146-45_tracking_points.gif differ diff --git a/results/2023-12-08-0149-29.png b/results/2023-12-08-0149-29.png new file mode 100644 index 0000000000000000000000000000000000000000..3fc039483b0c16bc1aef12c2297d4c3aeeb47cb6 --- /dev/null +++ b/results/2023-12-08-0149-29.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51a8054b309677afda9e1a60d12a46154a6f080f194fe0a2b0e29d7c51c50787 +size 1293201 diff --git a/results/2023-12-08-0152-29.png b/results/2023-12-08-0152-29.png new file mode 100644 index 0000000000000000000000000000000000000000..061b2db1c7a3119b56ab8ea9de1cb8d06d8fb47b --- /dev/null +++ b/results/2023-12-08-0152-29.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:febfabf10259145bb7b05571716c7b352e5b779bdd2a6897b2eef3b7d1a1aa5f +size 1317790 diff --git a/results/2023-12-08-0153-19.png b/results/2023-12-08-0153-19.png new file mode 100644 index 0000000000000000000000000000000000000000..ecb3c44485daf87d9c0bb6e4e94794106d02f143 --- /dev/null +++ b/results/2023-12-08-0153-19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7eb490cdef2c212d5e605b333f71dc728c90d98b514692952c5db10b8f5d2695 +size 1312139 diff --git a/results/2023-12-08-0154-20.png b/results/2023-12-08-0154-20.png new file mode 100644 index 0000000000000000000000000000000000000000..d2224b16e2a6d3e302026f457123ee2978a502d9 --- /dev/null +++ b/results/2023-12-08-0154-20.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10873c500928427789394a9a0b7f60d7800fbc313adb5db0871768f525898828 +size 1275887 diff --git a/results/2023-12-08-0155-38.png b/results/2023-12-08-0155-38.png new file mode 100644 index 0000000000000000000000000000000000000000..1fee3a2e8e178802dff2fca619449d677b3e4e85 --- /dev/null +++ b/results/2023-12-08-0155-38.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39542aaec8b548e1fb02c12967b9015dda1869c98facf700982d0db0ab53d18b +size 1225888 diff --git a/results/2023-12-08-0156-15.png b/results/2023-12-08-0156-15.png new file mode 100644 index 0000000000000000000000000000000000000000..139dea78e4131ae09a4688723baf3014ee841069 --- /dev/null +++ b/results/2023-12-08-0156-15.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b90f6fb0c0c7152d52584709331c2d45c046fe728949da33ae81174febb9536 +size 1267087 diff --git a/results/2023-12-08-0156-34.png b/results/2023-12-08-0156-34.png new file mode 100644 index 0000000000000000000000000000000000000000..e0d6f16f4326ac5506a15df58dddb6aa69625321 --- /dev/null +++ b/results/2023-12-08-0156-34.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6ab9e9617362b9450bf6950c7a231d87b9d304712b40850647185f33e662f2c +size 1223023 diff --git a/results/2023-12-08-0157-09.png b/results/2023-12-08-0157-09.png new file mode 100644 index 0000000000000000000000000000000000000000..886a7756fb030471e73d191d49cafbc5e42dd804 --- /dev/null +++ b/results/2023-12-08-0157-09.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51e7eddbe742edaea713bd709c0c77bd81a5561e2377f9898d75226eeeadcb1b +size 1198168 diff --git a/results/2023-12-08-0157-52.png b/results/2023-12-08-0157-52.png new file mode 100644 index 0000000000000000000000000000000000000000..75f6f4078658f670043c69d6fd4960deab9bb43c --- /dev/null +++ b/results/2023-12-08-0157-52.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:caeafab282c90d0893adb88286463035f802033d6a0033d03f4052770ba919e2 +size 1212222 diff --git a/results/2023-12-08-0159-25.png b/results/2023-12-08-0159-25.png new file mode 100644 index 0000000000000000000000000000000000000000..3578c143156c1b278ee8fe6203a7995c81e83a8c --- /dev/null +++ b/results/2023-12-08-0159-25.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01cfc8d2816fc0beb45337dbae53a212b40a72027e91f8e41680739b8bfe7499 +size 1270081 diff --git a/results/2023-12-08-0200-31.gif b/results/2023-12-08-0200-31.gif new file mode 100644 index 0000000000000000000000000000000000000000..3cd0e1fe0cece8fb99cc461ef645bd4ec798a3da --- /dev/null +++ b/results/2023-12-08-0200-31.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:070658d4beeec5d882ce32c58f31417559a7c4a49f98e1834d66e3faad319855 +size 1730430 diff --git a/results/2023-12-08-0200-33.png b/results/2023-12-08-0200-33.png new file mode 100644 index 0000000000000000000000000000000000000000..2dfef951a89a6955256bfad12d9c8ac57a23a218 --- /dev/null +++ b/results/2023-12-08-0200-33.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47c6c3ef473f262dfeeda3e57d1cf70552724f840ca2160db64f62a1bf0fd0df +size 1240707 diff --git a/results/2023-12-08-0200-33_tracking_points.gif b/results/2023-12-08-0200-33_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..6dd9738583d3a69a9f01ec5a22ca9bdec4edc0cd Binary files /dev/null and b/results/2023-12-08-0200-33_tracking_points.gif differ diff --git a/results/2023-12-08-0202-12.gif b/results/2023-12-08-0202-12.gif new file mode 100644 index 0000000000000000000000000000000000000000..5fff5f873fa8549e3c06c71a3e7ec0e8eef9d61c --- /dev/null +++ b/results/2023-12-08-0202-12.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5db3e6c6bda4f72b67ad10aaffd38168f9814a9e563113aac92080938eb17f07 +size 1863171 diff --git a/results/2023-12-08-0202-12_tracking_points.gif b/results/2023-12-08-0202-12_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..75bb7a11be144cb3b1b2791e68a109950ff24d3c Binary files /dev/null and b/results/2023-12-08-0202-12_tracking_points.gif differ diff --git a/results/2023-12-08-0202-13.png b/results/2023-12-08-0202-13.png new file mode 100644 index 0000000000000000000000000000000000000000..5ea0dd900e4ba70361f12684753a1f4c7d785565 --- /dev/null +++ b/results/2023-12-08-0202-13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d606ee2ffb3f293a026e27a6e9495b71803c98836cd220d1060f994110e49bc2 +size 1153626 diff --git a/results/2023-12-08-0215-08.gif b/results/2023-12-08-0215-08.gif new file mode 100644 index 0000000000000000000000000000000000000000..1dde031745d07df26543b8fa52e445486c955d22 --- /dev/null +++ b/results/2023-12-08-0215-08.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecb5b58136b60c44caccdd34f1063602157855a1a7db53815f79f6f7cb446b1b +size 2093595 diff --git a/results/2023-12-08-0215-10.png b/results/2023-12-08-0215-10.png new file mode 100644 index 0000000000000000000000000000000000000000..06c5a0596486fdec948bed532262da643c3bdafc Binary files /dev/null and b/results/2023-12-08-0215-10.png differ diff --git a/results/2023-12-08-0215-10_tracking_points.gif b/results/2023-12-08-0215-10_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..c2bd7f5cfef35980d804b8db784c185074c89cfd Binary files 
/dev/null and b/results/2023-12-08-0215-10_tracking_points.gif differ diff --git a/results/2023-12-08-0217-26.gif b/results/2023-12-08-0217-26.gif new file mode 100644 index 0000000000000000000000000000000000000000..370842180700724b7f76c22a715e6826ec377c1b --- /dev/null +++ b/results/2023-12-08-0217-26.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2661cb08c2ac14327d1a2d564d003b909a55ee53cd36d64a7e3b43fe975da41c +size 2031554 diff --git a/results/2023-12-08-0217-28.png b/results/2023-12-08-0217-28.png new file mode 100644 index 0000000000000000000000000000000000000000..95e0ce2e2743779fcd341ae66069d014852146b0 Binary files /dev/null and b/results/2023-12-08-0217-28.png differ diff --git a/results/2023-12-08-0217-28_tracking_points.gif b/results/2023-12-08-0217-28_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..3d5f75bed5fb74b829953ae6d511654801ff9551 Binary files /dev/null and b/results/2023-12-08-0217-28_tracking_points.gif differ diff --git a/results/2023-12-08-0219-21.gif b/results/2023-12-08-0219-21.gif new file mode 100644 index 0000000000000000000000000000000000000000..4e4b176e94802853495530175eef1c1a227b31ef --- /dev/null +++ b/results/2023-12-08-0219-21.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6957e404450d71f6a6f289275213810857f7abf1240b5a6362aa736c6a03aa6e +size 1514933 diff --git a/results/2023-12-08-0219-22.png b/results/2023-12-08-0219-22.png new file mode 100644 index 0000000000000000000000000000000000000000..0c644e5b4242c16cba40ae80ecbc36a2374691a1 Binary files /dev/null and b/results/2023-12-08-0219-22.png differ diff --git a/results/2023-12-08-0219-22_tracking_points.gif b/results/2023-12-08-0219-22_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..073880228c7af439b98f47f92998ea5a1745a287 Binary files /dev/null and b/results/2023-12-08-0219-22_tracking_points.gif differ diff --git a/results/2023-12-08-0223-15.gif b/results/2023-12-08-0223-15.gif new file mode 100644 index 0000000000000000000000000000000000000000..68b05e599d10ee9446bbee9801b2f08a406da5ab --- /dev/null +++ b/results/2023-12-08-0223-15.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c49c52b5c0b187bf8018a489f8a51ce39db7250750966b3e5a586aa639b73493 +size 1925846 diff --git a/results/2023-12-08-0223-17.png b/results/2023-12-08-0223-17.png new file mode 100644 index 0000000000000000000000000000000000000000..e8b1551099856fcac8e2cd104b7c86d12ac56cb6 Binary files /dev/null and b/results/2023-12-08-0223-17.png differ diff --git a/results/2023-12-08-0223-17_tracking_points.gif b/results/2023-12-08-0223-17_tracking_points.gif new file mode 100644 index 0000000000000000000000000000000000000000..489624d384091b552e1f5204c740cc9f54070068 Binary files /dev/null and b/results/2023-12-08-0223-17_tracking_points.gif differ diff --git a/test_lora.py b/test_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..b382ccb917712b308732115d120f4c1265b287b9 --- /dev/null +++ b/test_lora.py @@ -0,0 +1,28 @@ +import torch + +from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler + +pipeline = StableDiffusionPipeline.from_pretrained( + "andite/anything-v4.0", torch_dtype=torch.float16, safety_checker=None +).to("cuda") +pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, use_karras_sigmas=True +) + +pipeline.load_lora_weights(".", weight_name="genshin.safetensors") +prompt = ("masterpiece, 
best quality, absurdres, 1girl, school uniform, kangel, smile, standing, contrapposto, " + "bedroom, leaning forward, , , genshin,") +negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), " + "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2)") + +images = pipeline(prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=512, + num_inference_steps=20, + num_images_per_prompt=1, + guidance_scale=7.0, + generator=torch.manual_seed(9625644) +).images + +images[0].save("test.jpg") diff --git a/utils/__pycache__/attn_utils.cpython-38.pyc b/utils/__pycache__/attn_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbc6ccab6a06c898d93bf7b04f2d791e3756a311 Binary files /dev/null and b/utils/__pycache__/attn_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/drag_utils.cpython-38.pyc b/utils/__pycache__/drag_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c133fb770b459ba7799162d919947e3367924cf Binary files /dev/null and b/utils/__pycache__/drag_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/freeu_utils.cpython-38.pyc b/utils/__pycache__/freeu_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b41f71d703f95d30b49e581c8c7f8f53ab176a Binary files /dev/null and b/utils/__pycache__/freeu_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/lora_utils.cpython-38.pyc b/utils/__pycache__/lora_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea63a26bc82a9fa3d07a31255f82574b31f87163 Binary files /dev/null and b/utils/__pycache__/lora_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/ui_utils.cpython-38.pyc b/utils/__pycache__/ui_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eb30653d19636835285d9f1a1a1c0bd25110820 Binary files /dev/null and b/utils/__pycache__/ui_utils.cpython-38.pyc differ diff --git a/utils/attn_utils.py b/utils/attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9b8cf4d218b57a9f44683b1e0c32766d628a4c --- /dev/null +++ b/utils/attn_utils.py @@ -0,0 +1,221 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. 
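+#
+# attn_utils.py, in brief: AttentionBase is a minimal attention "editor" that runs
+# standard scaled-dot-product attention while tracking the current denoising step and
+# attention layer; MutualSelfAttentionControl extends it so that, on selected
+# steps/layers, the target branch attends to the source branch's keys and values;
+# override_attn_proc_forward / override_lora_attn_proc_forward are replacement forward
+# functions that route q/k/v through a registered editor; and
+# register_attention_editor_diffusers walks model.unet and installs the editor on every
+# spatial Attention module.
+#
+# Hypothetical usage sketch (the `pipeline` object is an assumption, not defined here):
+#   editor = MutualSelfAttentionControl(start_step=4, start_layer=10, total_steps=50)
+#   register_attention_editor_diffusers(pipeline, editor, attn_processor='attn_proc')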
+# ************************************************************************* + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + + +class AttentionBase: + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + def after_step(self): + pass + + def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): + out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + # after step + self.after_step() + return out + + def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = rearrange(out, 'b h n d -> b n (h d)') + return out + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + +class MutualSelfAttentionControl(AttentionBase): + + def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5): + """ + Mutual self-attention control for Stable-Diffusion model + Args: + start_step: the step to start mutual self-attention control + start_layer: the layer to start mutual self-attention control + layer_idx: list of the layers to apply mutual self-attention control + step_idx: list the steps to apply mutual self-attention control + total_steps: the total number of steps + """ + super().__init__() + self.total_steps = total_steps + self.start_step = start_step + self.start_layer = start_layer + self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16)) + self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) + # store the guidance scale to decide whether there are unconditional branch + self.guidance_scale = guidance_scale + print("step_idx: ", self.step_idx) + print("layer_idx: ", self.layer_idx) + + def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): + """ + Attention forward function + """ + if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: + return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) + + if self.guidance_scale > 1.0: + qu, qc = q[0:2], q[2:4] + ku, kc = k[0:2], k[2:4] + vu, vc = v[0:2], v[2:4] + + # merge queries of source and target branch into one so we can use torch API + qu = torch.cat([qu[0:1], qu[1:2]], dim=2) + qc = torch.cat([qc[0:1], qc[1:2]], dim=2) + + out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) + out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch + out_u = rearrange(out_u, 'b h n d -> b n (h d)') + + out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) + out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch + out_c = rearrange(out_c, 'b h n d -> b n (h d)') + + out = torch.cat([out_u, out_c], dim=0) + else: + q = torch.cat([q[0:1], q[1:2]], dim=2) + out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) + out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch + out = rearrange(out, 'b h n d -> b n (h d)') + return out + +# forward function for default attention processor +# modified from 
__call__ function of AttnProcessor in diffusers +def override_attn_proc_forward(attn, editor, place_in_unet): + def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): + """ + The attention is similar to the original implementation of LDM CrossAttention class + except adding some modifications on the attention + """ + if encoder_hidden_states is not None: + context = encoder_hidden_states + if attention_mask is not None: + mask = attention_mask + + to_out = attn.to_out + if isinstance(to_out, nn.modules.container.ModuleList): + to_out = attn.to_out[0] + else: + to_out = attn.to_out + + h = attn.heads + q = attn.to_q(x) + is_cross = context is not None + context = context if is_cross else x + k = attn.to_k(context) + v = attn.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + # the only difference + out = editor( + q, k, v, is_cross, place_in_unet, + attn.heads, scale=attn.scale) + + return to_out(out) + + return forward + +# forward function for lora attention processor +# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1 +def override_lora_attn_proc_forward(attn, editor, place_in_unet): + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0): + residual = hidden_states + input_ndim = hidden_states.ndim + is_cross = encoder_hidden_states is not None + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states) + + query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value)) + + # the only difference + hidden_states = editor( + query, key, value, is_cross, place_in_unet, + attn.heads, scale=attn.scale) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + return forward + +def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'): + """ + Register a attention editor to Diffuser Pipeline, refer 
from [Prompt-to-Prompt] + """ + def register_editor(net, count, place_in_unet): + for name, subnet in net.named_children(): + if net.__class__.__name__ == 'Attention': # spatial Transformer layer + if attn_processor == 'attn_proc': + net.forward = override_attn_proc_forward(net, editor, place_in_unet) + elif attn_processor == 'lora_attn_proc': + net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet) + else: + raise NotImplementedError("not implemented") + return count + 1 + elif hasattr(net, 'children'): + count = register_editor(subnet, count, place_in_unet) + return count + + cross_att_count = 0 + for net_name, net in model.unet.named_children(): + if "down" in net_name: + cross_att_count += register_editor(net, 0, "down") + elif "mid" in net_name: + cross_att_count += register_editor(net, 0, "mid") + elif "up" in net_name: + cross_att_count += register_editor(net, 0, "up") + editor.num_att_layers = cross_att_count diff --git a/utils/drag_utils.py b/utils/drag_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b05afb8b3b30cd55b847b5830c2dc787452303 --- /dev/null +++ b/utils/drag_utils.py @@ -0,0 +1,267 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
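+#
+# drag_utils.py, in brief: point_tracking re-locates each handle point by nearest-
+# neighbour feature matching inside a (2*r_p+1)^2 window of the current UNet feature
+# map F1 against the original feature f0; check_handle_reach_target stops the loop once
+# every handle point is within 2 pixels of its target; interpolate_feature_patch
+# bilinearly interpolates a feature patch at a sub-pixel location; and
+# drag_diffusion_update / drag_diffusion_update_gen run the latent optimization: at each
+# step the loss is the motion-supervision term
+#   sum_i (2*r_m+1)^2 * L1( F1[patch around p_i].detach(), F1[patch around p_i + d_i] )
+# with d_i the unit direction from handle p_i to target t_i, plus a mask-weighted term
+# args.lam * |(x_prev_updated - x_prev_0) * (1 - interp_mask)| that penalizes changes
+# where the mask is zero, optimized over init_code with Adam under fp16 autocast.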
+# ************************************************************************* + +import copy +import torch +import torch.nn.functional as F + + +def point_tracking(F0, + F1, + handle_points, + handle_points_init, + args): + with torch.no_grad(): + for i in range(len(handle_points)): + pi0, pi = handle_points_init[i], handle_points[i] + f0 = F0[:, :, int(pi0[0]), int(pi0[1])] + + r1, r2 = int(pi[0])-args.r_p, int(pi[0])+args.r_p+1 + c1, c2 = int(pi[1])-args.r_p, int(pi[1])+args.r_p+1 + F1_neighbor = F1[:, :, r1:r2, c1:c2] + all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1) + all_dist = all_dist.squeeze(dim=0) + # WARNING: no boundary protection right now + row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1]) + handle_points[i][0] = pi[0] - args.r_p + row + handle_points[i][1] = pi[1] - args.r_p + col + return handle_points + +def check_handle_reach_target(handle_points, + target_points): + # dist = (torch.cat(handle_points,dim=0) - torch.cat(target_points,dim=0)).norm(dim=-1) + all_dist = list(map(lambda p,q: (p-q).norm(), handle_points, target_points)) + return (torch.tensor(all_dist) < 2.0).all() + +# obtain the bilinear interpolated feature patch centered around (x, y) with radius r +def interpolate_feature_patch(feat, + y, + x, + r): + x0 = torch.floor(x).long() + x1 = x0 + 1 + + y0 = torch.floor(y).long() + y1 = y0 + 1 + + wa = (x1.float() - x) * (y1.float() - y) + wb = (x1.float() - x) * (y - y0.float()) + wc = (x - x0.float()) * (y1.float() - y) + wd = (x - x0.float()) * (y - y0.float()) + + Ia = feat[:, :, y0-r:y0+r+1, x0-r:x0+r+1] + Ib = feat[:, :, y1-r:y1+r+1, x0-r:x0+r+1] + Ic = feat[:, :, y0-r:y0+r+1, x1-r:x1+r+1] + Id = feat[:, :, y1-r:y1+r+1, x1-r:x1+r+1] + + return Ia * wa + Ib * wb + Ic * wc + Id * wd + +def drag_diffusion_update(model, + init_code, + t, + handle_points, + target_points, + mask, + args): + + assert len(handle_points) == len(target_points), \ + "number of handle point must equals target points" + + text_emb = model.get_text_embeddings(args.prompt).detach() + # the init output feature of unet + with torch.no_grad(): + unet_output, F0 = model.forward_unet_features(init_code, t, encoder_hidden_states=text_emb, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res, interp_res_w=args.sup_res) + x_prev_0,_ = model.step(unet_output, t, init_code) + # init_code_orig = copy.deepcopy(init_code) + + # prepare optimizable init_code and optimizer + init_code.requires_grad_(True) + optimizer = torch.optim.Adam([init_code], lr=args.lr) + + # prepare for point tracking and background regularization + handle_points_init = copy.deepcopy(handle_points) + interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest') + + # prepare amp scaler for mixed-precision training + scaler = torch.cuda.amp.GradScaler() + gif_init_code = [] + for step_idx in range(args.n_pix_step): + with torch.autocast(device_type='cuda', dtype=torch.float16): + unet_output, F1 = model.forward_unet_features(init_code, t, encoder_hidden_states=text_emb, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res, interp_res_w=args.sup_res) + x_prev_updated,_ = model.step(unet_output, t, init_code) + + # do point tracking to update handle points before computing motion supervision loss + if step_idx != 0: + handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args) + print('new handle points', handle_points) + + # break if all handle points have reached the targets + if check_handle_reach_target(handle_points, 
target_points): + break + + loss = 0.0 + for i in range(len(handle_points)): + pi, ti = handle_points[i], target_points[i] + # skip if the distance between target and source is less than 1 + if (ti - pi).norm() < 2.: + continue + + di = (ti - pi) / (ti - pi).norm() + + # motion supervision + f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach() + f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m) + loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch) + + # masked region must stay unchanged + loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() + # loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum() + print('loss total=%f'%(loss.item())) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + if args.create_gif_checkbox: + if step_idx % args.gif_interval == 0 or step_idx == args.n_pix_step - 1: + gif_init_code.append(init_code) + return init_code, gif_init_code + +def drag_diffusion_update_gen(model, + init_code, + t, + handle_points, + target_points, + mask, + args): + + assert len(handle_points) == len(target_points), \ + "number of handle point must equals target points" + + # positive prompt embedding + text_emb = model.get_text_embeddings(args.prompt).detach() + if args.guidance_scale > 1.0: + unconditional_input = model.tokenizer( + [args.neg_prompt], + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_emb = model.text_encoder(unconditional_input.input_ids.to(text_emb.device))[0].detach() + text_emb = torch.cat([unconditional_emb, text_emb], dim=0) + + # the init output feature of unet + with torch.no_grad(): + if args.guidance_scale > 1.: + model_inputs_0 = copy.deepcopy(torch.cat([init_code] * 2)) + else: + model_inputs_0 = copy.deepcopy(init_code) + unet_output, F0 = model.forward_unet_features(model_inputs_0, t, encoder_hidden_states=text_emb, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w) + if args.guidance_scale > 1.: + # strategy 1: discard the unconditional branch feature maps + # F0 = F0[1].unsqueeze(dim=0) + # strategy 2: concat pos and neg branch feature maps for motion-sup and point tracking + # F0 = torch.cat([F0[0], F0[1]], dim=0).unsqueeze(dim=0) + # strategy 3: concat pos and neg branch feature maps with guidance_scale consideration + coef = args.guidance_scale / (2*args.guidance_scale - 1.0) + F0 = torch.cat([(1-coef)*F0[0], coef*F0[1]], dim=0).unsqueeze(dim=0) + + unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0) + unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon) + x_prev_0,_ = model.step(unet_output, t, init_code) + # init_code_orig = copy.deepcopy(init_code) + + # prepare optimizable init_code and optimizer + init_code.requires_grad_(True) + optimizer = torch.optim.Adam([init_code], lr=args.lr) + + # prepare for point tracking and background regularization + handle_points_init = copy.deepcopy(handle_points) + interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest') + + # prepare amp scaler for mixed-precision training + scaler = torch.cuda.amp.GradScaler() + + gif_init_code = [] + handle_points_list = [] + for step_idx in range(args.n_pix_step): + with torch.autocast(device_type='cuda', dtype=torch.float16): + if args.guidance_scale > 1.: + model_inputs = init_code.repeat(2,1,1,1) + else: + model_inputs = 
init_code + unet_output, F1 = model.forward_unet_features(model_inputs, t, encoder_hidden_states=text_emb, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w) + if args.guidance_scale > 1.: + # strategy 1: discard the unconditional branch feature maps + # F1 = F1[1].unsqueeze(dim=0) + # strategy 2: concat positive and negative branch feature maps for motion-sup and point tracking + # F1 = torch.cat([F1[0], F1[1]], dim=0).unsqueeze(dim=0) + # strategy 3: concat pos and neg branch feature maps with guidance_scale consideration + coef = args.guidance_scale / (2*args.guidance_scale - 1.0) + F1 = torch.cat([(1-coef)*F1[0], coef*F1[1]], dim=0).unsqueeze(dim=0) + + unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0) + unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon) + x_prev_updated,_ = model.step(unet_output, t, init_code) + + # do point tracking to update handle points before computing motion supervision loss + if step_idx != 0: + handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args) + print('new handle points', handle_points) + print('target points', target_points) + + + # break if all handle points have reached the targets + if check_handle_reach_target(handle_points, target_points): + break + + loss = 0.0 + for i in range(len(handle_points)): + pi, ti = handle_points[i], target_points[i] + # skip if the distance between target and source is less than 1 + if (ti - pi).norm() < 2.: + continue + + di = (ti - pi) / (ti - pi).norm() + + # motion supervision + f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach() + f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m) + loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch) + + # masked region must stay unchanged + loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() + # loss += args.lam * ((init_code_orig - init_code)*(1.0-interp_mask)).abs().sum() + print('loss total=%f'%(loss.item())) + if args.create_gif_checkbox: + if step_idx % args.gif_interval == 0 or step_idx == args.n_pix_step - 1: + gif_init_code.append(init_code) + + if args.create_tracking_points_checkbox: + if step_idx % args.gif_interval == 0 or step_idx == args.n_pix_step - 1: + handle_points_list.append(handle_points) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + return init_code, gif_init_code, handle_points_list \ No newline at end of file diff --git a/utils/lora_utils.py b/utils/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3eec03a63f69e9216548007006fe51ea7dab71be --- /dev/null +++ b/utils/lora_utils.py @@ -0,0 +1,269 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. 
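+#
+# lora_utils.py, in brief: import_model_class_from_model_name_or_path resolves the
+# text-encoder class (CLIP, Roberta-series, or T5) from a model config;
+# tokenize_prompt / encode_prompt turn the user prompt into text embeddings; and
+# train_lora loads the tokenizer, scheduler, text encoder, VAE and UNet of the given
+# Stable Diffusion checkpoint, freezes them in fp16, attaches LoRA attention processors
+# of rank `lora_rank` to the UNet, and fine-tunes only those LoRA layers on a single
+# 512x512 crop of the input image with AdamW (lr=lora_lr) for lora_step steps under an
+# fp16 Accelerator, reporting progress via progress.tqdm.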
+# ************************************************************************* + +from PIL import Image +import os +import numpy as np +from einops import rearrange +import torch +import torch.nn.functional as F +from torchvision import transforms +from accelerate import Accelerator +from accelerate.utils import set_seed +from PIL import Image + +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnAddedKVProcessor, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0") + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + +# model_path: path of the model +# image: input image, have not been pre-processed +# save_lora_path: the path to save the lora +# prompt: the user input prompt +# lora_step: number of lora training step +# lora_lr: learning rate of lora training +# lora_rank: the rank of lora +def train_lora(image, prompt, model_path, vae_path, save_lora_path, lora_step, lora_lr, lora_rank, progress): + # initialize accelerator + accelerator = Accelerator( + gradient_accumulation_steps=1, + mixed_precision='fp16' + ) + set_seed(0) + + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder="tokenizer", + revision=None, + use_fast=False, + ) + # initialize the model + noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + text_encoder_cls = 
import_model_class_from_model_name_or_path(model_path, revision=None) + text_encoder = text_encoder_cls.from_pretrained( + model_path, subfolder="text_encoder", revision=None + ) + if vae_path == "default": + vae = AutoencoderKL.from_pretrained( + model_path, subfolder="vae", revision=None + ) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + unet = UNet2DConditionModel.from_pretrained( + model_path, subfolder="unet", revision=None + ) + + # set device and dtype + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + unet.to(device, dtype=torch.float16) + vae.to(device, dtype=torch.float16) + text_encoder.to(device, dtype=torch.float16) + + # initialize UNet LoRA + unet_lora_attn_procs = {} + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + else: + raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks") + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + unet_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank + ) + + unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # Optimizer creation + params_to_optimize = (unet_lora_layers.parameters()) + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=lora_lr, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-08, + ) + + lr_scheduler = get_scheduler( + "constant", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=lora_step, + num_cycles=1, + power=1.0, + ) + + # prepare accelerator + unet_lora_layers = accelerator.prepare_model(unet_lora_layers) + optimizer = accelerator.prepare_optimizer(optimizer) + lr_scheduler = accelerator.prepare_scheduler(lr_scheduler) + + # initialize text embeddings + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None) + text_embedding = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=False + ) + + # initialize latent distribution + image_transforms = transforms.Compose( + [ + # transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.RandomCrop(512), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + image = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16) + image = image.unsqueeze(dim=0) + latents_dist = vae.encode(image).latent_dist + for _ in progress.tqdm(range(lora_step), desc="training LoRA"): + unet.train() + model_input = latents_dist.sample() * vae.config.scaling_factor + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) 
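# A sketch of the name-to-hidden-size mapping used when building unet_lora_attn_procs above,
# with SD 1.5's default block_out_channels (320, 640, 1280, 1280) assumed for the examples.
def lora_hidden_size(name, block_out_channels=(320, 640, 1280, 1280)):
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        return list(reversed(block_out_channels))[int(name[len("up_blocks.")])]
    if name.startswith("down_blocks"):
        return block_out_channels[int(name[len("down_blocks.")])]
    raise NotImplementedError("name must start with up_blocks, mid_block, or down_blocks")

# e.g. "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor" -> 320
#      "mid_block.attentions.0.transformer_blocks.0.attn2.processor"     -> 1280
#      "up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor"   -> 320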
+ bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_pred = unet(noisy_model_input, timesteps, text_embedding).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # save the trained lora + unet = unet.to(torch.float32) + # unwrap_model is used to remove all special modules added when doing distributed training + # so here, there is no need to call unwrap_model + # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + LoraLoaderMixin.save_lora_weights( + save_directory=save_lora_path, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=None, + ) + + return diff --git a/utils/ui_utils.py b/utils/ui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..04256e40f81d99b4f98c1dccda9f4dcedeeac403 --- /dev/null +++ b/utils/ui_utils.py @@ -0,0 +1,625 @@ +import os +import cv2 +import numpy as np +import gradio as gr +from copy import deepcopy +from einops import rearrange +from types import SimpleNamespace + +import datetime +import PIL +from PIL import Image +from PIL.ImageOps import exif_transpose +import torch +import torch.nn.functional as F + +from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler +from drag_pipeline import DragPipeline + +from torchvision.utils import save_image +from pytorch_lightning import seed_everything + +from .drag_utils import drag_diffusion_update, drag_diffusion_update_gen +from .lora_utils import train_lora +from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl + +import imageio + + +# -------------- general UI functionality -------------- +def clear_all(length=480): + return gr.Image.update(value=None, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + [], None, None + +def clear_all_gen(length=480): + return gr.Image.update(value=None, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + [], None, None, None + +def mask_image(image, + mask, + color=[255,0,0], + alpha=0.5): + """ Overlay mask on image for visualization purpose. + Args: + image (H, W, 3) or (H, W): input image + mask (H, W): mask to be overlaid + color: the color of overlaid mask + alpha: the transparency of the mask + """ + out = deepcopy(image) + img = deepcopy(image) + img[mask == 1] = color + out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) + return out + +def store_img(img, length=512): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. 
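# A sketch of the target selection above for both scheduler prediction types: with "epsilon"
# the UNet regresses the injected noise, with "v_prediction" it regresses the velocity
# v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0 (DDPMScheduler.get_velocity).
# The scheduler settings below mirror SD 1.5 defaults and are only illustrative.
import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
x0 = torch.randn(1, 4, 64, 64)                        # clean latent
noise = torch.randn_like(x0)                          # injected noise
t = torch.randint(0, sched.config.num_train_timesteps, (1,)).long()
noisy = sched.add_noise(x0, noise, t)                 # forward diffusion to timestep t
target = noise if sched.config.prediction_type == "epsilon" else sched.get_velocity(x0, noise, t)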
+ height,width,_ = image.shape + image = Image.fromarray(image) + image = exif_transpose(image) + image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR) + mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST) + image = np.array(image) + + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + # when new image is uploaded, `selected_points` should be empty + return image, [], masked_img, mask + +# once user upload an image, the original image is stored in `original_image` +# the same image is displayed in `input_image` for point clicking purpose +def store_img_gen(img): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. + image = Image.fromarray(image) + image = exif_transpose(image) + image = np.array(image) + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + # when new image is uploaded, `selected_points` should be empty + return image, [], masked_img, mask + +# user click the image to get points, and show the points on the image +def get_points(img, + sel_pix, + evt: gr.SelectData): + img_copy = img.copy() if isinstance(img, np.ndarray) else np.array(img) + # collect the selected point + sel_pix.append(evt.index) + # draw points + points = [] + for idx, point in enumerate(sel_pix): + if idx % 2 == 0: + # draw a red circle at the handle point + cv2.circle(img_copy, tuple(point), 10, (255, 0, 0), -1) + else: + # draw a blue circle at the handle point + cv2.circle(img_copy, tuple(point), 10, (0, 0, 255), -1) + points.append(tuple(point)) + # draw an arrow from handle point to target point + if len(points) == 2: + cv2.arrowedLine(img_copy, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) + points = [] + return img_copy if isinstance(img, np.ndarray) else np.array(img_copy) + +# clear all handle/target points +def undo_points(original_image, + mask): + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = original_image.copy() + return masked_img, [] +# ------------------------------------------------------ + +# ----------- dragging user-input image utils ----------- +def train_lora_interface(original_image, + prompt, + model_path, + vae_path, + lora_path, + lora_step, + lora_lr, + lora_rank, + progress=gr.Progress()): + train_lora( + original_image, + prompt, + model_path, + vae_path, + lora_path, + lora_step, + lora_lr, + lora_rank, + progress) + return "Training LoRA Done!" 
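# A sketch of the point-list convention used by get_points above and consumed by run_drag /
# run_drag_gen below: clicks are stored flat as [handle_0, target_0, handle_1, target_1, ...],
# so even indices are handle points (drawn red) and odd indices are target points (drawn blue).
def split_clicks(selected_points):
    handles = [p for i, p in enumerate(selected_points) if i % 2 == 0]
    targets = [p for i, p in enumerate(selected_points) if i % 2 == 1]
    return handles, targets

# e.g. split_clicks([(100, 120), (140, 120)]) -> ([(100, 120)], [(140, 120)])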
+ +def preprocess_image(image, + device): + image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] + image = rearrange(image, "h w c -> 1 c h w") + image = image.to(device) + return image + +def run_drag(source_image, + image_with_clicks, + mask, + prompt, + points, + inversion_strength, + lam, + latent_lr, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + create_gif_checkbox, + gif_interval, + save_dir="./results" + ): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + + # initialize parameters + seed = 42 # random seed used by a lot of people for unknown reason + seed_everything(seed) + + args = SimpleNamespace() + args.prompt = prompt + args.points = points + args.n_inference_step = 50 + args.n_actual_inference_step = round(inversion_strength * args.n_inference_step) + args.guidance_scale = 1.0 + + args.unet_feature_idx = [3] + + args.sup_res = 256 + + args.r_m = 1 + args.r_p = 3 + args.lam = lam + + args.lr = latent_lr + + args.n_pix_step = n_pix_step + args.create_gif_checkbox = create_gif_checkbox + args.gif_interval = gif_interval + print(args) + + full_h, full_w = source_image.shape[:2] + + source_image = preprocess_image(source_image, device) + image_with_clicks = preprocess_image(image_with_clicks, device) + + # set lora + if lora_path == "": + print("applying default parameters") + model.unet.set_default_attn_processor() + else: + print("applying lora: " + lora_path) + model.unet.load_attn_procs(lora_path) + + # invert the source image + # the latent code resolution is too small, only 64*64 + invert_code = model.invert(source_image, + prompt, + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step) + + mask = torch.from_numpy(mask).float() / 255. 
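# A worked sketch of how inversion_strength maps to the timestep the drag update runs at,
# mirroring the arithmetic in run_drag above; the concrete numbers are illustrative.
n_inference_step = 50
inversion_strength = 0.7
n_actual_inference_step = round(inversion_strength * n_inference_step)   # 35
start_idx = n_inference_step - n_actual_inference_step                    # 15
# scheduler.timesteps holds 50 timesteps from noisiest to cleanest, so the optimisation is
# applied at timesteps[15], i.e. the latent obtained by inverting 70% of the way to noise.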
+ mask[mask > 0.0] = 1.0 + mask = rearrange(mask, "h w -> 1 1 h w").cuda() + mask = F.interpolate(mask, (args.sup_res, args.sup_res), mode="nearest") + + handle_points = [] + target_points = [] + # here, the point is in x,y coordinate + for idx, point in enumerate(points): + cur_point = torch.tensor([point[1] / full_h, point[0] / full_w]) * args.sup_res + cur_point = torch.round(cur_point) + if idx % 2 == 0: + handle_points.append(cur_point) + else: + target_points.append(cur_point) + print('handle points:', handle_points) + print('target points:', target_points) + + init_code = invert_code + init_code_orig = deepcopy(init_code) + model.scheduler.set_timesteps(args.n_inference_step) + t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] + + # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] + # update according to the given supervision + updated_init_code, gif_updated_init_code = drag_diffusion_update(model, init_code, t, + handle_points, target_points, mask, args) + + # hijack the attention module + # inject the reference branch to guide the generation + editor = MutualSelfAttentionControl(start_step=start_step, + start_layer=start_layer, + total_steps=args.n_inference_step, + guidance_scale=args.guidance_scale) + if lora_path == "": + register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') + else: + register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') + + # inference the synthesized image + gen_image = model( + prompt=args.prompt, + batch_size=2, + latents=torch.cat([init_code_orig, updated_init_code], dim=0), + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + )[1].unsqueeze(dim=0) + + # if gif, inference the synthesized image for each step and save them to gif + if args.create_gif_checkbox: + out_frames = [] + for step_updated_init_code in gif_updated_init_code: + gen_image = model( + prompt=args.prompt, + batch_size=1, + latents=step_updated_init_code, + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + ).unsqueeze(dim=0) + out_frame = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_frame = (out_frame * 255).astype(np.uint8) + out_frames.append(out_frame) + #save the gif + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + imageio.mimsave(os.path.join(save_dir, save_prefix + '.gif'), out_frames, fps=10) + + + # save the original image, user editing instructions, synthesized image + save_result = torch.cat([ + source_image * 0.5 + 0.5, + torch.ones((1,3,512,25)).cuda(), + image_with_clicks * 0.5 + 0.5, + torch.ones((1,3,512,25)).cuda(), + gen_image[0:1] + ], dim=-1) + + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_image = (out_image * 255).astype(np.uint8) + return out_image + +# ------------------------------------------------------- + +# ----------- dragging generated image utils ----------- +# once the user generated an image +# it will be displayed on mask drawing-areas and point-clicking area +def gen_img( + length, # length of the window displaying the image + height, # height of the generated image + width, 
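# A sketch of the click-to-grid coordinate conversion in run_drag above: Gradio reports
# clicks as (x, y) = (column, row) in the full-resolution image, while the drag update works
# in (row, col) coordinates on the supervision grid, hence the swap and rescale.
import torch

def click_to_sup_coords(point_xy, full_h, full_w, sup_res=256):
    x, y = point_xy
    return torch.round(torch.tensor([y / full_h, x / full_w]) * sup_res)

# e.g. a click at (x=256, y=128) on a 512x512 image -> tensor([64., 128.]) on the 256x256 grid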
# width of the generated image + n_inference_step, + scheduler_name, + seed, + guidance_scale, + prompt, + neg_prompt, + model_path, + vae_path, + lora_path): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device) + if scheduler_name == "DDIM": + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + elif scheduler_name == "DPM++2M": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config + ) + elif scheduler_name == "DPM++2M_karras": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config, use_karras_sigmas=True + ) + else: + raise NotImplementedError("scheduler name not correct") + model.scheduler = scheduler + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + # set lora + #if lora_path != "": + # print("applying lora for image generation: " + lora_path) + # model.unet.load_attn_procs(lora_path) + if lora_path != "": + print("applying lora: " + lora_path) + model.load_lora_weights(lora_path, weight_name="lora.safetensors") + + # initialize init noise + seed_everything(seed) + init_noise = torch.randn([1, 4, height // 8, width // 8], device=device, dtype=model.vae.dtype) + gen_image, intermediate_latents = model(prompt=prompt, + neg_prompt=neg_prompt, + num_inference_steps=n_inference_step, + latents=init_noise, + guidance_scale=guidance_scale, + return_intermediates=True) + gen_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + gen_image = (gen_image * 255).astype(np.uint8) + + if height < width: + # need to do this due to Gradio's bug + return gr.Image.update(value=gen_image, height=int(length*height/width), width=length), \ + gr.Image.update(height=int(length*height/width), width=length), \ + gr.Image.update(height=int(length*height/width), width=length), \ + None, \ + intermediate_latents + else: + return gr.Image.update(value=gen_image, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + None, \ + intermediate_latents + +def run_drag_gen( + n_inference_step, + scheduler_name, + source_image, + image_with_clicks, + intermediate_latents_gen, + guidance_scale, + mask, + prompt, + neg_prompt, + points, + inversion_strength, + lam, + latent_lr, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + create_gif_checkbox, + create_tracking_points_checkbox, + gif_interval, + gif_fps, + save_dir="./results"): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device) + if scheduler_name == "DDIM": + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + elif scheduler_name == "DPM++2M": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config + ) + elif scheduler_name == "DPM++2M_karras": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config, 
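# A sketch of why the initial noise in gen_img above has shape [1, 4, height // 8, width // 8]:
# Stable Diffusion's VAE downsamples images by a factor of 8 and its latent space has 4
# channels. The image size below is an illustrative example.
import torch

height, width = 512, 768
init_noise = torch.randn(1, 4, height // 8, width // 8)
print(init_noise.shape)  # torch.Size([1, 4, 64, 96])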
use_karras_sigmas=True + ) + else: + raise NotImplementedError("scheduler name not correct") + model.scheduler = scheduler + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + + # initialize parameters + seed = 42 # random seed used by a lot of people for unknown reason + seed_everything(seed) + + args = SimpleNamespace() + args.prompt = prompt + args.neg_prompt = neg_prompt + args.points = points + args.n_inference_step = n_inference_step + args.n_actual_inference_step = round(n_inference_step * inversion_strength) + args.guidance_scale = guidance_scale + + args.unet_feature_idx = [3] + + full_h, full_w = source_image.shape[:2] + + args.sup_res_h = int(0.5*full_h) + args.sup_res_w = int(0.5*full_w) + + args.r_m = 1 + args.r_p = 3 + args.lam = lam + + args.lr = latent_lr + + args.n_pix_step = n_pix_step + args.create_gif_checkbox = create_gif_checkbox + args.create_tracking_points_checkbox = create_tracking_points_checkbox + args.gif_interval = gif_interval + print(args) + + source_image = preprocess_image(source_image, device) + image_with_clicks = preprocess_image(image_with_clicks, device) + + # set lora + #if lora_path == "": + # print("applying default parameters") + # model.unet.set_default_attn_processor() + #else: + # print("applying lora: " + lora_path) + # model.unet.load_attn_procs(lora_path) + if lora_path != "": + print("applying lora: " + lora_path) + model.load_lora_weights(lora_path, weight_name="lora.safetensors") + + mask = torch.from_numpy(mask).float() / 255. + mask[mask > 0.0] = 1.0 + mask = rearrange(mask, "h w -> 1 1 h w").cuda() + mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") + + handle_points = [] + target_points = [] + # here, the point is in x,y coordinate + for idx, point in enumerate(points): + cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) + cur_point = torch.round(cur_point) + if idx % 2 == 0: + handle_points.append(cur_point) + else: + target_points.append(cur_point) + print('handle points:', handle_points) + print('target points:', target_points) + + model.scheduler.set_timesteps(args.n_inference_step) + t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] + init_code = deepcopy(intermediate_latents_gen[args.n_inference_step - args.n_actual_inference_step]) + init_code_orig = deepcopy(init_code) + + # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] + # update according to the given supervision + init_code = init_code.to(torch.float32) + model = model.to(device, torch.float32) + updated_init_code, gif_updated_init_code, handle_points_list = drag_diffusion_update_gen(model, init_code, t, + handle_points, target_points, mask, args) + updated_init_code = updated_init_code.to(torch.float16) + model = model.to(device, torch.float16) + + # hijack the attention module + # inject the reference branch to guide the generation + editor = MutualSelfAttentionControl(start_step=start_step, + start_layer=start_layer, + total_steps=args.n_inference_step, + guidance_scale=args.guidance_scale) + if lora_path == "": + register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') + else: + register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') + + # inference the synthesized 
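# A sketch of how run_drag_gen above picks its starting latent: gen_img stores one latent per
# denoising step (noisiest first), and the drag optimisation restarts from the latent that
# still has roughly n_actual_inference_step denoising steps remaining. Numbers are illustrative.
n_inference_step = 50
inversion_strength = 0.75
n_actual_inference_step = round(n_inference_step * inversion_strength)   # 38
start_idx = n_inference_step - n_actual_inference_step                    # 12
# init_code = intermediate_latents_gen[start_idx], then denoised for the remaining steps
# after the handle points have been dragged towards their targets.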
image + gen_image = model( + prompt=args.prompt, + neg_prompt=args.neg_prompt, + batch_size=2, # batch size is 2 because we have reference init_code and updated init_code + latents=torch.cat([init_code_orig, updated_init_code], dim=0), + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + )[1].unsqueeze(dim=0) + # if gif, inference the synthesized image for each step and save them to gif + if args.create_gif_checkbox: + out_frames = [] + print('Start Generate GIF') + for step_updated_init_code in gif_updated_init_code: + step_updated_init_code = step_updated_init_code.to(torch.float16) + gen_image = model( + prompt=args.prompt, + batch_size=2, + latents=torch.cat([init_code_orig, step_updated_init_code], dim=0), + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + )[1].unsqueeze(dim=0) + out_frame = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_frame = (out_frame * 255).astype(np.uint8) + out_frames.append(out_frame) + #save the gif + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + imageio.mimsave(os.path.join(save_dir, save_prefix + '.gif'), out_frames, fps=gif_fps) + + if args.create_tracking_points_checkbox: + white_image_base = np.ones((full_h, full_w, 3), dtype=np.uint8) * 255 + out_points_frames = [] + previous_points = {i: None for i in range(len(handle_points))} # To store the previous locations of points + print('Start Generate Tracking Points GIF', len(handle_points_list), handle_points_list) + for step_idx, step_handle_points in enumerate(handle_points_list): + out_points_frame = white_image_base.copy() + + for idx, point in enumerate(step_handle_points): + current_point = (int(point[1].item()), int(point[0].item())) + # Draw a circle at the handle point + cv2.circle(out_points_frame, current_point, 4, (0, 0, 255), -1) + # Optionally, add text labels + cv2.putText(out_points_frame, f'P{idx}', (current_point[0] + 5, current_point[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + + # Draw lines to show trajectory + if previous_points[idx] is not None: + cv2.line(out_points_frame, previous_points[idx], current_point, (0, 255, 0), 2) + previous_points[idx] = current_point + + out_points_frame = out_points_frame.astype(np.uint8) + out_points_frames.append(out_points_frame) + + # Save the gif + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + imageio.mimsave(os.path.join(save_dir, save_prefix + '_tracking_points.gif'), out_points_frames, fps=gif_fps) + + + # save the original image, user editing instructions, synthesized image + save_result = torch.cat([ + source_image * 0.5 + 0.5, + torch.ones((1,3,full_h,25)).cuda(), + image_with_clicks * 0.5 + 0.5, + torch.ones((1,3,full_h,25)).cuda(), + gen_image[0:1] + ], dim=-1) + + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_image = (out_image * 255).astype(np.uint8) + return out_image + +# ------------------------------------------------------
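# A sketch of the side-by-side result image saved at the end of run_drag_gen above: the source,
# the click visualisation, and the edited output are concatenated along width with 25-pixel
# white separators. Shapes below assume an illustrative 512x512 edit.
import torch

full_h, full_w = 512, 512
source = torch.rand(1, 3, full_h, full_w)   # stands in for source_image * 0.5 + 0.5
clicks = torch.rand(1, 3, full_h, full_w)   # stands in for image_with_clicks * 0.5 + 0.5
edited = torch.rand(1, 3, full_h, full_w)   # stands in for gen_image[0:1]
sep = torch.ones(1, 3, full_h, 25)
montage = torch.cat([source, sep, clicks, sep, edited], dim=-1)
print(montage.shape)  # torch.Size([1, 3, 512, 1586])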