diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..ff4956d706393bfe3dd5a3c835a181af1fed1ac0 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..7d435297a69c69210358f0f05876f842776d2f0a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,100 @@ +name: Bug Report +description: You think somethings is broken in the UI +title: "[Bug]: " +labels: ["bug-report"] + +body: + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit. + options: + - label: I have searched the existing issues and checked the recent builds/commits + required: true + - type: markdown + attributes: + value: | + *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible** + - type: textarea + id: what-did + attributes: + label: What happened? + description: Tell us what happened in a very clear and simple way + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce the problem + description: Please provide us with precise step by step information on how to reproduce the bug + value: | + 1. Go to .... + 2. Press .... + 3. ... + validations: + required: true + - type: textarea + id: what-should + attributes: + label: What should have happened? + description: Tell what you think the normal behavior should be + validations: + required: true + - type: input + id: commit + attributes: + label: Commit where the problem happens + description: Which commit are you running ? 
(Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) + validations: + required: true + - type: dropdown + id: platforms + attributes: + label: What platforms do you use to access the UI ? + multiple: true + options: + - Windows + - Linux + - MacOS + - iOS + - Android + - Other/Cloud + - type: dropdown + id: browsers + attributes: + label: What browsers do you use to access the UI ? + multiple: true + options: + - Mozilla Firefox + - Google Chrome + - Brave + - Apple Safari + - Microsoft Edge + - type: textarea + id: cmdargs + attributes: + label: Command Line Arguments + description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise. + render: Shell + validations: + required: true + - type: textarea + id: extensions + attributes: + label: List of extensions + description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise. + validations: + required: true + - type: textarea + id: logs + attributes: + label: Console logs + description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. + render: Shell + validations: + required: true + - type: textarea + id: misc + attributes: + label: Additional information + description: Please provide us with any relevant additional info or context. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..f58c94a9be6847193a971ac67aa83e9a6d75c0ae --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: WebUI Community Support + url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions + about: Please ask and answer questions here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..35a887408c1a0cb7d5bbf0a8444d0903a708be75 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,40 @@ +name: Feature request +description: Suggest an idea for this project +title: "[Feature Request]: " +labels: ["enhancement"] + +body: + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. + options: + - label: I have searched the existing issues and checked the recent builds/commits + required: true + - type: markdown + attributes: + value: | + *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* + - type: textarea + id: feature + attributes: + label: What would your feature do ? + description: Tell us about your feature in a very clear and simple way, and what problem it would solve + validations: + required: true + - type: textarea + id: workflow + attributes: + label: Proposed workflow + description: Please provide us with step by step information on how you'd like the feature to be accessed and used + value: | + 1. Go to .... + 2. Press .... 
+ 3. ... + validations: + required: true + - type: textarea + id: misc + attributes: + label: Additional information + description: Add any other context or screenshots about the feature request here. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000000000000000000000000000000000..69056331b8f56cb7a0f8bcff3010eb2a67c141c7 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,28 @@ +# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! + +If you have a large change, pay special attention to this paragraph: + +> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. + +Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. + +**Describe what this pull request is trying to achieve.** + +A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. + +**Additional notes and description of your changes** + +More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. + +**Environment this was tested in** + +List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box. + - OS: [e.g. Windows, Linux] + - Browser: [e.g. chrome, safari] + - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] + +**Screenshots or videos of your changes** + +If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. + +This is **required** for anything that touches the user interface. 
\ No newline at end of file diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a168be5b8be804f657bd109355dc18542f03ba6c --- /dev/null +++ b/.github/workflows/on_pull_request.yaml @@ -0,0 +1,39 @@ +# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file +name: Run Linting/Formatting on Pull Requests + +on: + - push + - pull_request + # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs + # if you want to filter out branches, delete the `- pull_request` and uncomment these lines : + # pull_request: + # branches: + # - master + # branches-ignore: + # - development + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: 3.10.6 + cache: pip + cache-dependency-path: | + **/requirements*txt + - name: Install PyLint + run: | + python -m pip install --upgrade pip + pip install pylint + # This lets PyLint check to see if it can resolve imports + - name: Install dependencies + run: | + export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" + python launch.py + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be7ffa23b91e3112daf42fbd04e4c848dcf5e21b --- /dev/null +++ b/.github/workflows/run_tests.yaml @@ -0,0 +1,29 @@ +name: Run basic features tests on CPU with empty SD model + +on: + - push + - pull_request + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: 3.10.6 + cache: pip + cache-dependency-path: | + **/requirements*txt + - name: Run tests + run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + - name: Upload main app stdout-stderr + uses: actions/upload-artifact@v3 + if: always() + with: + name: stdout-stderr + path: | + test/stdout.txt + test/stderr.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f5557719e5a4ff762ba9b3a76b46c483ab2c323d --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +__pycache__ +*.ckpt +*.safetensors +*.pth +/ESRGAN/* +/SwinIR/* +/repositories +/venv +/tmp +/model.ckpt +/models/**/* +/GFPGANv1.3.pth +/gfpgan/weights/*.pth +/ui-config.json +/outputs +/config.json +/log +/webui.settings.bat +/embeddings +/styles.csv +/params.txt +/styles.csv.bak +/webui-user.bat +/webui-user.sh +/interrogate +/user.css +/.idea +notification.mp3 +/SwinIR +/textual_inversion +.vscode +# /extensions +/test/stdout.txt +/test/stderr.txt +/cache.json diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..53254e5dcfd871c8c0f0f4dec9dceeb1ba967eda --- /dev/null +++ b/.pylintrc @@ -0,0 +1,3 @@ +# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html +[MESSAGES CONTROL] +disable=C,R,W,E,I diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..7438c9bc69d2a53df3dffd053ca82a689f97adb2 --- /dev/null +++ b/CODEOWNERS @@ 
-0,0 +1,12 @@ +* @AUTOMATIC1111 + +# if you were managing a localization and were removed from this file, this is because +# the intended way to do localizations now is via extensions. See: +# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions +# Make a repo with your localization and since you are still listed as a collaborator +# you can add it to the wiki page yourself. This change is because some people complained +# the git commit log is cluttered with things unrelated to almost everyone and +# because I believe this is the best overall for the project to handle localizations almost +# entirely without my oversight. + + diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..1457754386456039688ccf0f6268dfcc2a94587c --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,663 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (c) 2023 AUTOMATIC1111 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. 
This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..55105e500ec05540f173f433aee7b7aa36c4d5df --- /dev/null +++ b/README.md @@ -0,0 +1,173 @@ +--- +title: Stable Diffusion Webui +emoji: 🌖 +colorFrom: pink +colorTo: blue +sdk: gradio +sdk_version: 3.19.1 +app_file: launch.py +pinned: false +duplicated_from: user238921933/stable-diffusion-webui +--- + +# Stable Diffusion web UI +A browser interface based on Gradio library for Stable Diffusion. 
+
+![](screenshot.png)
+
+## Features
+[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
+- Original txt2img and img2img modes
+- One click install and run script (but you still must install python and git)
+- Outpainting
+- Inpainting
+- Color Sketch
+- Prompt Matrix
+- Stable Diffusion Upscale
+- Attention, specify parts of text that the model should pay more attention to
+  - a man in a ((tuxedo)) - will pay more attention to tuxedo
+  - a man in a (tuxedo:1.21) - alternative syntax
+  - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
+- Loopback, run img2img processing multiple times
+- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
+- Textual Inversion
+  - have as many embeddings as you want and use any names you like for them
+  - use multiple embeddings with different numbers of vectors per token
+  - works with half precision floating point numbers
+  - train embeddings on 8GB (also reports of 6GB working)
+- Extras tab with:
+  - GFPGAN, neural network that fixes faces
+  - CodeFormer, face restoration tool as an alternative to GFPGAN
+  - RealESRGAN, neural network upscaler
+  - ESRGAN, neural network upscaler with a lot of third party models
+  - SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
+  - LDSR, Latent diffusion super resolution upscaling
+- Resizing aspect ratio options
+- Sampling method selection
+  - Adjust sampler eta values (noise multiplier)
+  - More advanced noise setting options
+- Interrupt processing at any time
+- 4GB video card support (also reports of 2GB working)
+- Correct seeds for batches
+- Live prompt token length validation
+- Generation parameters
+  - parameters you used to generate images are saved with that image
+  - in PNG chunks for PNG, in EXIF for JPEG
+  - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
+  - can be disabled in settings
+  - drag and drop an image/text-parameters to promptbox
+- Read Generation Parameters Button, loads parameters in promptbox to UI
+- Settings page
+- Running arbitrary python code from UI (must run with --allow-code to enable)
+- Mouseover hints for most UI elements
+- Possible to change defaults/min/max/step values for UI elements via text config
+- Tiling support, a checkbox to create images that can be tiled like textures
+- Progress bar and live image generation preview
+  - Can use a separate neural network to produce previews with almost no VRAM or compute requirement
+- Negative prompt, an extra text field that allows you to list what you don't want to see in the generated image
+- Styles, a way to save parts of a prompt and easily apply them via dropdown later
+- Variations, a way to generate the same image but with tiny differences
+- Seed resizing, a way to generate the same image but at a slightly different resolution
+- CLIP interrogator, a button that tries to guess the prompt from an image
+- Prompt Editing, a way to change the prompt mid-generation, say to start making a watermelon and switch to anime girl midway
+- Batch Processing, process a group of files using img2img
+- Img2img Alternative, reverse Euler method of cross attention control
+- Highres Fix, a convenience option to produce high resolution pictures in one click without the usual distortions
+- Reloading checkpoints on the fly
+- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
+- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from the community
+- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
+  - separate prompts using uppercase `AND`
+  - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
+- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
+- DeepDanbooru integration, creates danbooru style tags for anime prompts
+- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards (add --xformers to commandline args)
+- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
+- Generate forever option
+- Training tab
+  - hypernetworks and embeddings options
+  - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
+- Clip skip
+- Hypernetworks
+- Loras (same as Hypernetworks but prettier)
+- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
+- Can select to load a different VAE from settings screen
+- Estimated completion time in progress bar
+- API
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
+- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using CLIP image embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
+- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
+- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
+- Now without any bad letters!
+- Load checkpoints in safetensors format
+- Eased resolution restriction: generated image dimensions must be multiples of 8 rather than 64
+- Now with a license!
+- Reorder elements in the UI from settings screen
+
+## Installation and Running
+Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
+
+Alternatively, use online services (like Google Colab):
+
+- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
+
+### Automatic Installation on Windows
+1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
+2. Install [git](https://git-scm.com/download/win).
+3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
+4. Run `webui-user.bat` from Windows Explorer as a normal, non-administrator user.
+
+### Automatic Installation on Linux
+1. Install the dependencies:
+```bash
+# Debian-based:
+sudo apt install wget git python3 python3-venv
+# Red Hat-based:
+sudo dnf install wget git python3
+# Arch-based:
+sudo pacman -S wget git python3
+```
+2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
+```bash
+bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
+```
+3. Run `webui.sh`.
+### Installation on Apple Silicon
+
+Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
+
+## Contributing
+Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
+
+## Documentation
+The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
+
+## Credits
+Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
+
+- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
+- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
+- GFPGAN - https://github.com/TencentARC/GFPGAN.git
+- CodeFormer - https://github.com/sczhou/CodeFormer
+- ESRGAN - https://github.com/xinntao/ESRGAN
+- SwinIR - https://github.com/JingyunLiang/SwinIR
+- Swin2SR - https://github.com/mv-lab/swin2sr
+- LDSR - https://github.com/Hafiidz/latent-diffusion
+- MiDaS - https://github.com/isl-org/MiDaS
+- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
+- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
+- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
+- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
+- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
+- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
+- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
+- xformers - https://github.com/facebookresearch/xformers
+- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
+- Security advice - RyotaK
+- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
+- (You) diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cfbee72d71bfd7deed2075e423ca51bd1da0521c --- /dev/null +++ b/configs/alt-diffusion-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: modules.xlmr.BertSeriesModelWithTransformation + params: + name: "XLMR-Large" \ No newline at end of file diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4e896879dd7ac5697b89cb323ec43eb41c03596c --- /dev/null +++ b/configs/instruct-pix2pix.yaml @@ -0,0 +1,98 @@ +# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. + +model: + base_learning_rate: 1.0e-04 + target: modules.models.diffusion.ddpm_edit.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: edited + cond_stage_key: edit + # image_size: 64 + # image_size: 32 + image_size: 16 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: false + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 0 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 128 + num_workers: 1 + wrap: false + validation: + target: edit_dataset.EditDataset + params: + path: data/clip-filtered-dataset + cache_dir: data/ + cache_name: data_10k + split: val + min_text_sim: 0.2 + min_image_sim: 0.75 + min_direction_sim: 0.2 + max_samples_per_prompt: 1 + min_resize_res: 512 + max_resize_res: 512 + crop_res: 512 + output_as_edit: False + real_input: True diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4effe569e897369918625f9d8be5603a0e6a0d6 --- /dev/null +++ b/configs/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9eec37d24bce33ce92320a782d16ae72308190a --- /dev/null +++ b/configs/v1-inpainting-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/embeddings/Place Textual Inversion embeddings here.txt b/embeddings/Place Textual Inversion embeddings here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/environment-wsl2.yaml b/environment-wsl2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f88727507835d96bfbbfae3ece2996e8506e3760 --- /dev/null +++ b/environment-wsl2.yaml @@ -0,0 +1,11 @@ +name: automatic +channels: + - pytorch + - defaults +dependencies: + - python=3.10 + - pip=22.2.2 + - cudatoolkit=11.3 + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 \ No newline at end of file diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..bc11cc6e4821bf110ada3c61c0fef35be7f548e8 --- /dev/null +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -0,0 +1,253 @@ +import os +import gc +import time + +import numpy as np +import torch +import torchvision +from PIL import Image +from einops import rearrange, repeat +from omegaconf import OmegaConf +import safetensors.torch + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config, ismap +from modules import shared, sd_hijack + +cached_ldsr_model: torch.nn.Module = None + + +# Create LDSR Class +class LDSR: + def load_model_from_config(self, half_attention): + global cached_ldsr_model + + if shared.opts.ldsr_cached and cached_ldsr_model is not None: + print("Loading model from cache") + model: torch.nn.Module = cached_ldsr_model + else: + print(f"Loading model from {self.modelPath}") + _, extension = os.path.splitext(self.modelPath) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu") + else: + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd + config = OmegaConf.load(self.yamlPath) + config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" + model: torch.nn.Module = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model = model.to(shared.device) + if half_attention: + model = model.half() + if shared.cmd_opts.opt_channelslast: + model = model.to(memory_format=torch.channels_last) + + sd_hijack.model_hijack.hijack(model) # apply optimization + model.eval() + + if shared.opts.ldsr_cached: + cached_ldsr_model = model + + return {"model": model} + + def __init__(self, model_path, yaml_path): + self.modelPath = model_path + self.yamlPath = yaml_path + + @staticmethod + def run(model, selected_path, 
custom_steps, eta): + example = get_cond(selected_path) + + n_runs = 1 + guider = None + ckwargs = None + ddim_use_x0_pred = False + temperature = 1. + eta = eta + custom_shape = None + + height, width = example["image"].shape[1:3] + split_input = height >= 128 and width >= 128 + + if split_input: + ks = 128 + stride = 64 + vqf = 4 # + model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01} + else: + if hasattr(model, "split_input_params"): + delattr(model, "split_input_params") + + x_t = None + logs = None + for n in range(n_runs): + if custom_shape is not None: + x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) + x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) + + logs = make_convolutional_sample(example, model, + custom_steps=custom_steps, + eta=eta, quantize_x0=False, + custom_shape=custom_shape, + temperature=temperature, noise_dropout=0., + corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, + ddim_use_x0_pred=ddim_use_x0_pred + ) + return logs + + def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): + model = self.load_model_from_config(half_attention) + + # Run settings + diffusion_steps = int(steps) + eta = 1.0 + + down_sample_method = 'Lanczos' + + gc.collect() + if torch.cuda.is_available: + torch.cuda.empty_cache() + + im_og = image + width_og, height_og = im_og.size + # If we can adjust the max upscale size, then the 4 below should be our variable + down_sample_rate = target_scale / 4 + wd = width_og * down_sample_rate + hd = height_og * down_sample_rate + width_downsampled_pre = int(np.ceil(wd)) + height_downsampled_pre = int(np.ceil(hd)) + + if down_sample_rate != 1: + print( + f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') + im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) + else: + print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") + + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts + pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size + im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) + + logs = self.run(model["model"], im_padded, diffusion_steps, eta) + + sample = logs["sample"] + sample = sample.detach().cpu() + sample = torch.clamp(sample, -1., 1.) + sample = (sample + 1.) / 2. * 255 + sample = sample.numpy().astype(np.uint8) + sample = np.transpose(sample, (0, 2, 3, 1)) + a = Image.fromarray(sample[0]) + + # remove padding + a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4)) + + del model + gc.collect() + if torch.cuda.is_available: + torch.cuda.empty_cache() + + return a + + +def get_cond(selected_path): + example = dict() + up_f = 4 + c = selected_path.convert('RGB') + c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) + c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], + antialias=True) + c_up = rearrange(c_up, '1 c h w -> 1 h w c') + c = rearrange(c, '1 c h w -> 1 h w c') + c = 2. * c - 1. 
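+    # c: the low-resolution input image, rescaled here from [0, 1] to [-1, 1]
+    # c_up: the same image resized by up_f = 4 with antialiasing, kept in [0, 1]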
+ + c = c.to(shared.device) + example["LR_image"] = c + example["image"] = c_up + + return example + + +@torch.no_grad() +def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, + mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, + corrector_kwargs=None, x_t=None + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + print(f"Sampling with eta = {eta}; steps: {steps}") + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, + normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, + mask=mask, x0=x0, temperature=temperature, verbose=False, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, x_t=x_t) + + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, + corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): + log = dict() + + z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not (hasattr(model, 'split_input_params') + and model.cond_stage_key == 'coordinates_bbox'), + return_original_cond=True) + + if custom_shape is not None: + z = torch.randn(custom_shape) + print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") + + z0 = None + + log["input"] = x + log["reconstruction"] = xrec + + if ismap(xc): + log["original_conditioning"] = model.to_rgb(xc) + if hasattr(model, 'cond_stage_key'): + log[model.cond_stage_key] = model.to_rgb(xc) + + else: + log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_model: + log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_key == 'class_label': + log[model.cond_stage_key] = xc[model.cond_stage_key] + + with model.ema_scope("Plotting"): + t0 = time.time() + + sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, mask=None, x0=z0, + temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, + x_t=x_T) + t1 = time.time() + + if ddim_use_x0_pred: + sample = intermediates['pred_x0'][-1] + + x_sample = model.decode_first_stage(sample) + + try: + x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) + log["sample_noquant"] = x_sample_noquant + log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) + except: + pass + + log["sample"] = x_sample + log["time"] = t1 - t0 + + return log diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py new file mode 100644 index 0000000000000000000000000000000000000000..d746007c7cc83b037b6ae82b766475b56c3fd778 --- /dev/null +++ b/extensions-builtin/LDSR/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR')) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b8cff29b9f4ca56e3a9f4b1ac8e150abb1a0ff30 --- /dev/null +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -0,0 +1,69 @@ +import os +import sys +import traceback + +from basicsr.utils.download_util import 
load_file_from_url + +from modules.upscaler import Upscaler, UpscalerData +from ldsr_model_arch import LDSR +from modules import shared, script_callbacks +import sd_hijack_autoencoder, sd_hijack_ddpm_v1 + + +class UpscalerLDSR(Upscaler): + def __init__(self, user_path): + self.name = "LDSR" + self.user_path = user_path + self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" + self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + super().__init__() + scaler_data = UpscalerData("LDSR", None, self) + self.scalers = [scaler_data] + + def load_model(self, path: str): + # Remove incorrect project.yaml file if too big + yaml_path = os.path.join(self.model_path, "project.yaml") + old_model_path = os.path.join(self.model_path, "model.pth") + new_model_path = os.path.join(self.model_path, "model.ckpt") + safetensors_model_path = os.path.join(self.model_path, "model.safetensors") + if os.path.exists(yaml_path): + statinfo = os.stat(yaml_path) + if statinfo.st_size >= 10485760: + print("Removing invalid LDSR YAML file.") + os.remove(yaml_path) + if os.path.exists(old_model_path): + print("Renaming model from model.pth to model.ckpt") + os.rename(old_model_path, new_model_path) + if os.path.exists(safetensors_model_path): + model = safetensors_model_path + else: + model = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="model.ckpt", progress=True) + yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, + file_name="project.yaml", progress=True) + + try: + return LDSR(model, yaml) + + except Exception: + print("Error importing LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + def do_upscale(self, img, path): + ldsr = self.load_model(path) + if ldsr is None: + print("NO LDSR!") + return img + ddim_steps = shared.opts.ldsr_steps + return ldsr.super_resolution(img, ddim_steps, self.scale) + + +def on_ui_settings(): + import gradio as gr + + shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. 
Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8e03c7f898988c237c714ed949610f5035b30b50 --- /dev/null +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -0,0 +1,286 @@ +# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo +# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo +# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder + +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.util import instantiate_from_config + +import ldm.models.autoencoder + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + 
print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if 
version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
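+        # min-max rescale the 3-channel random projection to [-1, 1] before returning it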
+ return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + +setattr(ldm.models.autoencoder, "VQModel", VQModel) +setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0488e5f6fcea41e7f9fa25070e38fbfe656478 --- /dev/null +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -0,0 +1,1449 @@ +# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo) +# Original filename: ldm/models/diffusion/ddpm.py +# The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't +# Some models such as LDSR require VQ to work correctly +# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + +import ldm.models.diffusion.ddpm + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPMV1(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = 
parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapperV1(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
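+        In closed form: q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I);
+        the three extract_into_tensor lookups below gather these per-timestep coefficients.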
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + 
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise 
row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusionV1(DDPMV1): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = 
self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. 
reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, 
axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
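+ # optionally pass the reconstructed x0 through the first stage's quantizer before
+ # computing the posterior mean/variance; this path is only meaningful when the
+ # first stage model exposes a quantize() method (e.g. a VQ autoencoder)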
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + 
return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] 
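+ # during sampling, the masked-out (zero) center square is generated from noise while
+ # the remaining region is repeatedly reset to a q_sample-noised copy of x0=z[:N]
+ # (see the mask handling in p_sample_loop / progressive_denoising above)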
+ with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
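+ # rescale the random 3-channel projection to [-1, 1] so it can be logged like a regular image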
+ return x + + +class DiffusionWrapperV1(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusionV1(LatentDiffusionV1): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs + +setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) +setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) +setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) +setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..6be6ef73c66f2829a179d1bcaf4f1982f25ea452 --- /dev/null +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -0,0 +1,26 @@ +from modules import extra_networks, shared +import lora + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lora') + + def activate(self, p, params_list): + additional = shared.opts.sd_lora + + if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + lora.load_loras(names, multipliers) + + def deactivate(self, p): + pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py new file mode 100644 index 
0000000000000000000000000000000000000000..cb8f1d36a8a96eb1fed4f0fedd45f9b395c4bc0a --- /dev/null +++ b/extensions-builtin/Lora/lora.py @@ -0,0 +1,207 @@ +import glob +import os +import re +import torch + +from modules import shared, devices, sd_models + +re_digits = re.compile(r"\d+") +re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") +re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") +re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") +re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex): + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, re_unet_down_blocks): + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_mid_blocks): + return f"diffusion_model_middle_block_1_{m[1]}" + + if match(m, re_unet_up_blocks): + return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_text_block): + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +class LoraOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + + +class LoraModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class LoraUpDownModule: + def __init__(self): + self.up = None + self.down = None + self.alpha = None + + +def assign_lora_names_to_compvis_modules(sd_model): + lora_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + for name, module in shared.sd_model.model.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + sd_model.lora_layer_mapping = lora_layer_mapping + + +def load_lora(name, filename): + lora = LoraModule(name) + lora.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + + keys_failed_to_match = [] + + for key_diffusers, weight in sd.items(): + fullkey = convert_diffusers_name_to_compvis(key_diffusers) + key, lora_key = fullkey.split(".", 1) + + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: + keys_failed_to_match.append(key_diffusers) + continue + + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + + if type(sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.device, dtype=devices.dtype) + + if lora_key == "lora_up.weight": + lora_module.up = module + elif lora_key == "lora_down.weight": + lora_module.down = module + else: + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight 
or alpha' + + if len(keys_failed_to_match) > 0: + print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + + return lora + + +def load_loras(names, multipliers=None): + already_loaded = {} + + for lora in loaded_loras: + if lora.name in names: + already_loaded[lora.name] = lora + + loaded_loras.clear() + + loras_on_disk = [available_loras.get(name, None) for name in names] + if any([x is None for x in loras_on_disk]): + list_available_loras() + + loras_on_disk = [available_loras.get(name, None) for name in names] + + for i, name in enumerate(names): + lora = already_loaded.get(name, None) + + lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: + if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: + lora = load_lora(name, lora_on_disk.filename) + + if lora is None: + print(f"Couldn't find Lora with name {name}") + continue + + lora.multiplier = multipliers[i] if multipliers else 1.0 + loaded_loras.append(lora) + + +def lora_forward(module, input, res): + if len(loaded_loras) == 0: + return res + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None: + if shared.opts.lora_apply_to_outputs and res.shape == input.shape: + res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + else: + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return res + + +def lora_Linear_forward(self, input): + return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + + +def lora_Conv2d_forward(self, input): + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + + +def list_available_loras(): + available_loras.clear() + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) + + for filename in sorted(candidates): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_loras[name] = LoraOnDisk(name, filename) + + +available_loras = {} +loaded_loras = [] + +list_available_loras() diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py new file mode 100644 index 0000000000000000000000000000000000000000..863dc5c0b510a20e03ed68adec0108ebf1e03474 --- /dev/null +++ b/extensions-builtin/Lora/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py new file mode 100644 index 0000000000000000000000000000000000000000..2e860160e945daf5e1b46b7f7d7b9141ab6d3a3e --- /dev/null +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -0,0 +1,38 @@ +import torch +import gradio as gr + +import lora +import extra_networks_lora +import ui_extra_networks_lora +from modules import script_callbacks, ui_extra_networks, extra_networks, shared + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + 
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) + extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + +if not hasattr(torch.nn, 'Linear_forward_before_lora'): + torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): + torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward + +torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Conv2d.forward = lora.lora_Conv2d_forward + +script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) + + +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), + "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), + +})) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..22cabcb0f2ff2b9a724377f94b3273fbcdd127c0 --- /dev/null +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -0,0 +1,37 @@ +import json +import os +import lora + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + + def refresh(self): + lora.list_available_loras() + + def list_items(self): + for name, lora_on_disk in lora.available_loras.items(): + path, ext = os.path.splitext(lora_on_disk.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = self.link_preview(file) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "search_term": self.search_terms_from_path(lora_on_disk.filename), + "prompt": json.dumps(f""), + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.lora_dir] + diff --git a/extensions-builtin/ScuNET/preload.py b/extensions-builtin/ScuNET/preload.py new file mode 100644 index 0000000000000000000000000000000000000000..f12c5b90ed2984ef16d8d8dd30d1ebef34cbf7c3 --- /dev/null +++ b/extensions-builtin/ScuNET/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET')) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e0fbf3a33747f447d396dd0d564e92c904cfabac --- /dev/null +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -0,0 +1,87 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import devices, modelloader +from scunet_model_arch import SCUNet as net + + +class UpscalerScuNET(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "ScuNET" + self.model_name = 
"ScuNET GAN" + self.model_name2 = "ScuNET PSNR" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" + self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pth"]) + scalers = [] + add_model2 = True + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + if name == self.model_name2 or file == self.model_url2: + add_model2 = False + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading ScuNET model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + if add_model2: + scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) + scalers.append(scaler_data2) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + + model = self.load_model(selected_file) + if model is None: + return img + + device = devices.get_device_for('scunet') + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(device) + + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + device = devices.get_device_for('scunet') + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: + print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) + return None + + model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + + return model + diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..43ca8d36fe57a12dcad58e8b06ee2e0774494b0e --- /dev/null +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_, DropPath + + +class WMSA(nn.Module): + """ Self-attention module in Swin Transformer + """ + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim ** -0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) + + self.relative_position_params = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, 
std=.02) + self.relative_position_params = torch.nn.Parameter( + self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, + 2).transpose( + 0, 1)) + + def generate_mask(self, h, w, p, shift): + """ generating the mask of SW-MSA + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: should be (1 1 w p p), + """ + # supporting square. + attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) + if self.type == 'W': + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') + return attn_mask + + def forward(self, x): + """ Forward pass of Window Multi-head Self-attention module. + Args: + x: input tensor with shape of [b h w c]; + attn_mask: attention mask, fill -inf where the value is True; + Returns: + output: tensor shape [b h w c] + """ + if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) + h_windows = x.size(1) + w_windows = x.size(2) + # square validation + # assert h_windows == w_windows + + x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) + qkv = self.embedding_layer(x) + q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) + sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale + # Adding learnable relative embedding + sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') + # Using Attn Mask to distinguish different subwindows. + if self.type != 'W': + attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) + sim = sim.masked_fill_(attn_mask, float("-inf")) + + probs = nn.functional.softmax(sim, dim=-1) + output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) + output = rearrange(output, 'h b w p c -> b w p (h c)') + output = self.linear(output) + output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) + + if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), + dims=(1, 2)) + return output + + def relative_embedding(self): + cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) + relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 + # negative is allowed + return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] + + +class Block(nn.Module): + def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer Block + """ + super(Block, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + assert type in ['W', 'SW'] + self.type = type + if input_resolution <= window_size: + self.type = 'W' + + self.ln1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.ln2 = nn.LayerNorm(input_dim) + self.mlp = nn.Sequential( + nn.Linear(input_dim, 4 * input_dim), + nn.GELU(), + nn.Linear(4 * input_dim, output_dim), + ) + + def forward(self, x): + x = x + self.drop_path(self.msa(self.ln1(x))) + x = x + self.drop_path(self.mlp(self.ln2(x))) + return x + + +class ConvTransBlock(nn.Module): + def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer and Conv Block + """ + super(ConvTransBlock, self).__init__() + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.head_dim = head_dim + self.window_size = window_size + self.drop_path = drop_path + self.type = type + self.input_resolution = input_resolution + + assert self.type in ['W', 'SW'] + if self.input_resolution <= self.window_size: + self.type = 'W' + + self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, + self.type, self.input_resolution) + self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + + self.conv_block = nn.Sequential( + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) + ) + + def forward(self, x): + conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) + conv_x = self.conv_block(conv_x) + conv_x + trans_x = Rearrange('b c h w -> b h w c')(trans_x) + trans_x = self.trans_block(trans_x) + trans_x = Rearrange('b h w c -> b c h w')(trans_x) + res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) + x = x + res + + return x + + +class SCUNet(nn.Module): + # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): + def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): + super(SCUNet, self).__init__() + if config is None: + config = [2, 2, 2, 2, 2, 2, 2] + self.config = config + self.dim = dim + self.head_dim = 32 + self.window_size = 8 + + # drop path rate for each layer + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] + + self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] + + begin = 0 + self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[0])] + \ + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] + + begin += config[0] + self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[1])] + \ + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] + + begin += config[1] + self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 4) + for i in range(config[2])] + \ + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] + + begin += config[2] + self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 8) + for i in range(config[3])] + + begin += config[3] + self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 
else 'SW', input_resolution // 4) + for i in range(config[4])] + + begin += config[4] + self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[5])] + + begin += config[5] + self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[6])] + + self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] + + self.m_head = nn.Sequential(*self.m_head) + self.m_down1 = nn.Sequential(*self.m_down1) + self.m_down2 = nn.Sequential(*self.m_down2) + self.m_down3 = nn.Sequential(*self.m_down3) + self.m_body = nn.Sequential(*self.m_body) + self.m_up3 = nn.Sequential(*self.m_up3) + self.m_up2 = nn.Sequential(*self.m_up2) + self.m_up1 = nn.Sequential(*self.m_up1) + self.m_tail = nn.Sequential(*self.m_tail) + # self.apply(self._init_weights) + + def forward(self, x0): + + h, w = x0.size()[-2:] + paddingBottom = int(np.ceil(h / 64) * 64 - h) + paddingRight = int(np.ceil(w / 64) * 64 - w) + x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) + + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[..., :h, :w] + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) \ No newline at end of file diff --git a/extensions-builtin/SwinIR/preload.py b/extensions-builtin/SwinIR/preload.py new file mode 100644 index 0000000000000000000000000000000000000000..567e44bcaa1d40ca5dabed7744c94c5b2c68e87f --- /dev/null +++ b/extensions-builtin/SwinIR/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR')) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e8783bca153954afd086536a6dee854ec5e17ba9 --- /dev/null +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -0,0 +1,178 @@ +import contextlib +import os + +import numpy as np +import torch +from PIL import Image +from basicsr.utils.download_util import load_file_from_url +from tqdm import tqdm + +from modules import modelloader, devices, script_callbacks, shared +from modules.shared import cmd_opts, opts, state +from swinir_model_arch import SwinIR as net +from swinir_model_arch_v2 import Swin2SR as net2 +from modules.upscaler import Upscaler, UpscalerData + + +device_swinir = devices.get_device_for('swinir') + + +class UpscalerSwinIR(Upscaler): + def __init__(self, dirname): + self.name = "SwinIR" + self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ + "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ + "-L_x4_GAN.pth " + self.model_name = "SwinIR 4x" + self.user_path = dirname + super().__init__() + scalers = [] + model_files = self.find_models(ext_filter=[".pt", ".pth"]) + for model 
in model_files: + if "http" in model: + name = self.model_name + else: + name = modelloader.friendly_name(model) + model_data = UpscalerData(name, model, self) + scalers.append(model_data) + self.scalers = scalers + + def do_upscale(self, img, model_file): + model = self.load_model(model_file) + if model is None: + return img + model = model.to(device_swinir, dtype=devices.dtype) + img = upscale(img, model) + try: + torch.cuda.empty_cache() + except: + pass + return img + + def load_model(self, path, scale=4): + if "http" in path: + dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None + if filename.endswith(".v2.pth"): + model = net2( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="1conv", + ) + params = None + else: + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + params = "params_ema" + + pretrained_model = torch.load(filename) + if params is not None: + model.load_state_dict(pretrained_model[params], strict=True) + else: + model.load_state_dict(pretrained_model, strict=True) + return model + + +def upscale( + img, + model, + tile=None, + tile_overlap=None, + window_size=8, + scale=4, +): + tile = tile or opts.SWIN_tile + tile_overlap = tile_overlap or opts.SWIN_tile_overlap + + + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + with torch.no_grad(), devices.autocast(): + _, _, h_old, w_old = img.size() + h_pad = (h_old // window_size + 1) * window_size - h_old + w_pad = (w_old // window_size + 1) * window_size - w_old + img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] + img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] + output = inference(img, model, tile, tile_overlap, window_size, scale) + output = output[..., : h_old * scale, : w_old * scale] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + if output.ndim == 3: + output = np.transpose( + output[[2, 1, 0], :, :], (1, 2, 0) + ) # CHW-RGB to HCW-BGR + output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 + return Image.fromarray(output, "RGB") + + +def inference(img, model, tile, tile_overlap, window_size, scale): + # test the image tile by tile + b, c, h, w = img.size() + tile = min(tile, h, w) + assert tile % window_size == 0, "tile size should be a multiple of window_size" + sf = scale + + stride = tile - tile_overlap + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) + + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: + for h_idx in h_idx_list: + if state.interrupted or state.skipped: + break + + for w_idx in w_idx_list: + if state.interrupted or state.skipped: + 
break + + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) + output = E.div_(W) + + return output + + +def on_ui_settings(): + import gradio as gr + + shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..863f42db6f50e5eac70931b8c0e6443f831a6018 --- /dev/null +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -0,0 +1,867 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * 
(self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
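For reference, an illustrative recomputation of the SW-MSA mask that `calculate_mask` above builds, here for an arbitrary 8x8 map with window_size=4 and shift_size=2; the window-partition step is inlined with view/permute so the snippet runs on its own.

import torch

H = W = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:              # label each of the 3x3 shifted regions
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# partition the label map into windows and compare labels pairwise
mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
assert attn_mask.shape == (4, 16, 16)   # one additive (N, N) mask per window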
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
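A small illustrative sketch (arbitrary shapes) of the 2x2 rearrangement performed by `PatchMerging.forward` above before its Linear(4C -> 2C) reduction: resolution is halved while channels are first quadrupled, then reduced.

import torch

B, H, W, C = 1, 8, 8, 16
x = torch.rand(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]   # top-left of every 2x2 block
x1 = x[:, 1::2, 0::2, :]   # bottom-left
x2 = x[:, 0::2, 1::2, :]   # top-right
x3 = x[:, 1::2, 1::2, :]   # bottom-right
merged = torch.cat([x0, x1, x2, x3], dim=-1)
assert merged.shape == (B, H // 2, W // 2, 4 * C)
# norm + Linear(4*C, 2*C) then map the 4*C features down to 2*C, as above.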
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. 
or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if 
self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0e28ae6eefa2f4bc6260b14760907c54ce633876 --- /dev/null +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -0,0 +1,1017 @@ +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ +# Written by Conde and Choi et al. 
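Before the SwinV2 variant, a short illustrative sketch of the padding rule from `SwinIR.check_image_size` in the file above: height and width are reflect-padded up to the next multiple of the window size (the input shape here is arbitrary).

import torch
import torch.nn.functional as F

window_size = 8
x = torch.rand(1, 3, 70, 101)
_, _, h, w = x.shape
pad_h = (window_size - h % window_size) % window_size   # 2
pad_w = (window_size - w % window_size) % window_size   # 3
x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
assert x.shape[-2] % window_size == 0 and x.shape[-1] % window_size == 0
# forward() later crops the output back to H*upscale x W*upscale.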
+# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
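A standalone sanity check (illustrative; the two helpers are restated so the snippet runs on its own) that `window_partition` and `window_reverse` above are exact inverses when H and W are multiples of the window size.

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.rand(2, 16, 16, 4)
wins = window_partition(x, 8)                 # 2 images x 4 windows each
assert wins.shape == (8, 8, 8, 4)
assert torch.equal(window_reverse(wins, 8, 16, 16), x)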
+ """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + #assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. 
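An illustrative recomputation of the log-spaced relative-coordinate table that `WindowAttention` above feeds to its continuous position-bias MLP (`cpb_mlp`); the window size of 8 and the variable names are assumptions made for this sketch.

import math
import torch

window_size = (8, 8)
rel_h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32)
rel_w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32)
table = torch.stack(torch.meshgrid([rel_h, rel_w])).permute(1, 2, 0).contiguous().unsqueeze(0)
table[:, :, :, 0] /= window_size[0] - 1      # offsets scaled to [-1, 1]
table[:, :, :, 1] /= window_size[1] - 1
table *= 8                                   # stretched to [-8, 8]
table = torch.sign(table) * torch.log2(table.abs() + 1.0) / math.log2(8)
# One (2*Wh-1) x (2*Ww-1) grid of log-compressed 2-D offsets per window geometry;
# cpb_mlp maps each 2-D offset to a per-head bias at attention time.
assert table.shape == (1, 15, 15, 2)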
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. 
+ input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. 
Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(Swin2SR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, 
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_hf: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x_before = self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = 
self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] + + else: + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js new file mode 100644 index 0000000000000000000000000000000000000000..4a85c8ebf25110e911a6a1021fae6a014aa11000 --- /dev/null +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -0,0 +1,110 @@ +// Stable Diffusion WebUI - Bracket checker +// Version 1.0 +// By Hingashi no Florin/Bwin4L +// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. +// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. + +function checkBrackets(evt, textArea, counterElt) { + errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n'; + errorStringSquare = '[...] 
- Different number of opening and closing square brackets detected.\n'; + errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n'; + + openBracketRegExp = /\(/g; + closeBracketRegExp = /\)/g; + + openSquareBracketRegExp = /\[/g; + closeSquareBracketRegExp = /\]/g; + + openCurlyBracketRegExp = /\{/g; + closeCurlyBracketRegExp = /\}/g; + + totalOpenBracketMatches = 0; + totalCloseBracketMatches = 0; + totalOpenSquareBracketMatches = 0; + totalCloseSquareBracketMatches = 0; + totalOpenCurlyBracketMatches = 0; + totalCloseCurlyBracketMatches = 0; + + openBracketMatches = textArea.value.match(openBracketRegExp); + if(openBracketMatches) { + totalOpenBracketMatches = openBracketMatches.length; + } + + closeBracketMatches = textArea.value.match(closeBracketRegExp); + if(closeBracketMatches) { + totalCloseBracketMatches = closeBracketMatches.length; + } + + openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp); + if(openSquareBracketMatches) { + totalOpenSquareBracketMatches = openSquareBracketMatches.length; + } + + closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp); + if(closeSquareBracketMatches) { + totalCloseSquareBracketMatches = closeSquareBracketMatches.length; + } + + openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp); + if(openCurlyBracketMatches) { + totalOpenCurlyBracketMatches = openCurlyBracketMatches.length; + } + + closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp); + if(closeCurlyBracketMatches) { + totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length; + } + + if(totalOpenBracketMatches != totalCloseBracketMatches) { + if(!counterElt.title.includes(errorStringParen)) { + counterElt.title += errorStringParen; + } + } else { + counterElt.title = counterElt.title.replace(errorStringParen, ''); + } + + if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) { + if(!counterElt.title.includes(errorStringSquare)) { + counterElt.title += errorStringSquare; + } + } else { + counterElt.title = counterElt.title.replace(errorStringSquare, ''); + } + + if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) { + if(!counterElt.title.includes(errorStringCurly)) { + counterElt.title += errorStringCurly; + } + } else { + counterElt.title = counterElt.title.replace(errorStringCurly, ''); + } + + if(counterElt.title != '') { + counterElt.classList.add('error'); + } else { + counterElt.classList.remove('error'); + } +} + +function setupBracketChecking(id_prompt, id_counter){ + var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); + var counter = gradioApp().getElementById(id_counter) + textarea.addEventListener("input", function(evt){ + checkBrackets(evt, textarea, counter) + }); +} + +var shadowRootLoaded = setInterval(function() { + var shadowRoot = document.querySelector('gradio-app').shadowRoot; + if(! 
shadowRoot) return false; + + var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); + if(shadowTextArea.length < 1) return false; + + clearInterval(shadowRootLoaded); + + setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') + setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') + setupBracketChecking('img2img_prompt', 'img2img_token_counter') + setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') +}, 1000); diff --git a/extensions/deforum/.github/FUNDING.yml b/extensions/deforum/.github/FUNDING.yml new file mode 100644 index 0000000000000000000000000000000000000000..e707ddb03c159bc5fc649349fe8303dbf43ddedc --- /dev/null +++ b/extensions/deforum/.github/FUNDING.yml @@ -0,0 +1,13 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: deforum +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/extensions/deforum/.github/ISSUE_TEMPLATE/bug_report.yml b/extensions/deforum/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..c86f978adbda56d77e669f5541b3f86300733f72 --- /dev/null +++ b/extensions/deforum/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,97 @@ +name: Bug Report +description: Create a bug report for the Deforum extension +title: "[Bug]: " +labels: ["bug-report"] + +body: + - type: checkboxes + attributes: + label: Have you read the latest version of the FAQ? + description: Please visit the page called FAQ & Troubleshooting on the Deforum wiki in this repository and see if your problem has already been described there. + options: + - label: I have visited the FAQ page right now and my issue is not present there + required: true + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the bug you encountered (including the closed issues). + options: + - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui + required: true + - type: checkboxes + attributes: + label: Are you using the latest version of the Deforum extension? + description: Please, check if your Deforum is based on the latest repo commit (git log) or update it through the 'Extensions' tab and check if the issue still persists. Otherwise, check this box. + options: + - label: I have Deforum updated to the latest version and I still have the issue. + required: true + - type: markdown + attributes: + value: | + *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible** + - type: textarea + id: what-did + attributes: + label: What happened?
+ description: Tell us what happened in a very clear and simple way + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce the problem + description: Please provide us with precise step by step information on how to reproduce the bug + value: | + 1. Go to .... + 2. Press .... + 3. ... + validations: + required: true + - type: textarea + id: what-should + attributes: + label: What should have happened? + description: Tell what you think the normal behavior should be + - type: textarea + id: commits + attributes: + label: WebUI and Deforum extension Commit IDs + description: Which commit of the webui/deforum extension are you running on? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or if you can't launch the webui at all, enter your cmd/terminal, CD into the main webui folder to get the webui commit id, and cd into the extensions/deforum folder to get the deforum commit id, both using the command 'git rev-parse HEAD'.) + value: | + webui commit id - + deforum exten commit id - + validations: + required: true + - type: dropdown + id: where + attributes: + label: On which platform are you launching the webui with the extension? + multiple: true + options: + - Local PC setup (Windows) + - Local PC setup (Linux) + - Local PC setup (Mac) + - Google Colab (The Last Ben's) + - Google Colab (Other) + - Cloud server (Linux) + - Other (please specify in "additional information") + - type: textarea + id: customsettings + attributes: + label: Webui core settings + description: Send here a link to your ui-config.json file in the core 'stable-diffusion-webui' folder (ideally, upload it to GitHub gists). Friendly reminder - if you have 'With img2img, do exactly the amount of steps the slider specified' checked, your issue will be discarded immediately. 😉 + validations: + required: true + - type: textarea + id: logs + attributes: + label: Console logs + description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to GitHub gists or similar service. + render: Shell + validations: + required: true + - type: textarea + id: misc + attributes: + label: Additional information + description: Please provide us with any relevant additional info or context. diff --git a/extensions/deforum/.github/ISSUE_TEMPLATE/config.yml b/extensions/deforum/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..b15e609d9f1315070a1bcc0d82bd9d3dfc7b81f0 --- /dev/null +++ b/extensions/deforum/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Deforum Github discussions + url: https://github.com/deforum-art/deforum-for-automatic1111-webui/discussions + about: Please ask and answer questions here. 
+ - name: Deforum Discord + url: https://discord.gg/deforum + about: Or here :) \ No newline at end of file diff --git a/extensions/deforum/.github/ISSUE_TEMPLATE/feature_request.yml b/extensions/deforum/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000000000000000000000000000000000000..58a6332ef1bec7637bbbad898eddd06a0a2e1cf4 --- /dev/null +++ b/extensions/deforum/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,40 @@ +name: Feature request +description: Suggest an idea for the Deforum extension +title: "[Feature Request]: " +labels: ["enhancement"] + +body: + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. + options: + - label: I have searched the existing issues and checked the recent builds/commits + required: true + - type: markdown + attributes: + value: | + *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* + - type: textarea + id: feature + attributes: + label: What would your feature do ? + description: Tell us about your feature in a very clear and simple way, and what problem it would solve + validations: + required: true + - type: textarea + id: workflow + attributes: + label: Proposed workflow + description: Please provide us with step by step information on how you'd like the feature to be accessed and used + value: | + 1. Go to .... + 2. Press .... + 3. ... + validations: + required: true + - type: textarea + id: misc + attributes: + label: Additional information + description: Add any other context or screenshots about the feature request here. \ No newline at end of file diff --git a/extensions/deforum/.gitignore b/extensions/deforum/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5cc8a723f7ba813d56294d38c6859dc1ce195e42 --- /dev/null +++ b/extensions/deforum/.gitignore @@ -0,0 +1,10 @@ +# Unnecessary compiled python files. +__pycache__ +*.pyc +*.pyo + +# Output Images +outputs + +# Log files for colab-convert +cc-outputs.log diff --git a/extensions/deforum/CONTRIBUTING.md b/extensions/deforum/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..c5df29a13e839422129b9e5e1919cafac4a651e5 --- /dev/null +++ b/extensions/deforum/CONTRIBUTING.md @@ -0,0 +1,7 @@ +# Contributing + +As a part of the Deforum team I (kabachuha) want this script extension to remain a part of the Deforum project. + +Thus, if you want to submit a feature request or bugfix, unless it only relates to automatic1111's porting issues, consider making a PR first to the parent repository notebook https://github.com/deforum/stable-diffusion. + +Also, you may want to inform the dev team about your work via Discord https://discord.gg/deforum to ensure that no one else is working on the same stuff.
diff --git a/extensions/deforum/LICENSE b/extensions/deforum/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1e315f0d1cf6154f77fbd35af88d0b8c8c5d933e --- /dev/null +++ b/extensions/deforum/LICENSE @@ -0,0 +1,2496 @@ +deforum-stable-diffusion: +MIT License + +Copyright (c) 2022 deforum and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +k-diffusion: +MIT License + +Copyright (c) 2022 Katherine Crowson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +clip: +MIT License + +Copyright (c) 2021 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +MiDaS: +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +pytorch3d-lite: +BSD License + +For PyTorch3D software + +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Meta nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +taming-transformers: +Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +OR OTHER DEALINGS IN THE SOFTWARE./ + +stable diffusion: +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. 
+- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. 
You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. 
+ +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). + +transformers: +Copyright 2018- The Hugging Face team. All rights reserved. + +transformers, FILM Interpolation: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +adabins: + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. 
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. 
+ + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. 
This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. 
+ + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. 
+ + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. 
+ + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. 
+ + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. 
If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. 
Interpretation of Sections 15 and 16.
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program.  It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+    This is free software, and you are welcome to redistribute it
+    under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License.  Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+  The GNU General Public License does not permit incorporating your program
+into proprietary programs.  If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library.  If this is what you want to do, use the GNU Lesser General
+Public License instead of this License.  But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
+ +Clipseg: +MIT License + +This license does not apply to the model weights. + +RIFE: +MIT License + +Copyright (c) 2021 hzwer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +gifski: +GNU AFFERO GENERAL PUBLIC LICENSE ++ +pngquant.c + +### pngquant.c +© 1989, 1991 by Jef Poskanzer. + +Permission to use, copy, modify, and distribute this software and its documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation. This software is provided "as is" without express or implied warranty. + +© 1997-2002 by Greg Roelofs; based on an idea by Stefan Schneider. + +All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +### GNU AFFERO GENERAL PUBLIC LICENSE + +Version 3, 19 November 2007 + +© 2007 Free Software Foundation, Inc. + + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +### Preamble + +The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. 
+ +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains +free software for all its users. + +When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + +A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + +The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + +An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing +under this license. + +The precise terms and conditions for copying, distribution and +modification follow. + +### TERMS AND CONDITIONS + +#### 0. Definitions. + +"This License" refers to version 3 of the GNU Affero General Public +License. + +"Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of +an exact copy. The resulting work is called a "modified version" of +the earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. + +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. 
+ +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user +through a computer network, with no transfer of a copy, is not +conveying. + +An interactive user interface displays "Appropriate Legal Notices" to +the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +#### 1. Source Code. + +The "source code" for a work means the preferred form of the work for +making modifications to it. "Object code" means any non-source form of +a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users can +regenerate automatically from other parts of the Corresponding Source. + +The Corresponding Source for a work in source code form is that same +work. + +#### 2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. 
+ +You may make, run and propagate covered works that you do not convey, +without conditions so long as your license otherwise remains in force. +You may convey covered works to others for the sole purpose of having +them make modifications exclusively for you, or provide you with +facilities for running those works, provided that you comply with the +terms of this License in conveying all material for which you do not +control copyright. Those thus making or running the covered works for +you must do so exclusively on your behalf, under your direction and +control, on terms that prohibit them from making any copies of your +copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under the +conditions stated below. Sublicensing is not allowed; section 10 makes +it unnecessary. + +#### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such +circumvention is effected by exercising rights under this License with +respect to the covered work, and you disclaim any intention to limit +operation or modification of the work as a means of enforcing, against +the work's users, your or third parties' legal rights to forbid +circumvention of technological measures. + +#### 4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +#### 5. Conveying Modified Source Versions. + +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these +conditions: + +- a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. +- b) The work must carry prominent notices stating that it is + released under this License and any conditions added under + section 7. This requirement modifies the requirement in section 4 + to "keep intact all notices". +- c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. 
+- d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +#### 6. Conveying Non-Source Forms. + +You may convey a covered work in object code form under the terms of +sections 4 and 5, provided that you also convey the machine-readable +Corresponding Source under the terms of this License, in one of these +ways: + +- a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. +- b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the Corresponding + Source from a network server at no charge. +- c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. +- d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. +- e) Convey the object code using peer-to-peer transmission, + provided you inform other peers where the object code and + Corresponding Source of the work are being offered to the general + public at no charge under subsection 6d. + +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. 
+ +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. + +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +#### 7. Additional Terms. + +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. 
(Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders +of that material) supplement the terms of this License with terms: + +- a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or +- b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or +- c) Prohibiting misrepresentation of the origin of that material, + or requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or +- d) Limiting the use for publicity purposes of names of licensors + or authors of the material; or +- e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or +- f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; the +above requirements apply either way. + +#### 8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your license +from a particular copyright holder is reinstated (a) provisionally, +unless and until the copyright holder explicitly and finally +terminates your license, and (b) permanently, if the copyright holder +fails to notify you of the violation by some reasonable means prior to +60 days after the cessation. 
+ +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +#### 9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or run +a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +#### 10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +#### 11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims owned +or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. 
+ +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within the +scope of its coverage, prohibits the exercise of, or is conditioned on +the non-exercise of one or more of the rights that are specifically +granted under this License. You may not convey a covered work if you +are a party to an arrangement with a third party that is in the +business of distributing software, under which you make payment to the +third party based on the extent of your activity of conveying the +work, and under which the third party grants, to any of the parties +who would receive the covered work from you, a discriminatory patent +license (a) in connection with copies of the covered work conveyed by +you (or copies made from those copies), or (b) primarily for and in +connection with specific products or compilations that contain the +covered work, unless you entered into that arrangement, or that patent +license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +#### 12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. 
If you cannot convey a +covered work so as to satisfy simultaneously your obligations under +this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree to +terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + +#### 13. Remote Network Interaction; Use with the GNU General Public License. + +Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your +version supports such interaction) an opportunity to receive the +Corresponding Source of your version by providing access to the +Corresponding Source from a network server at no charge, through some +standard or customary means of facilitating copying of software. This +Corresponding Source shall include the Corresponding Source for any +work covered by version 3 of the GNU General Public License that is +incorporated pursuant to the following paragraph. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + +#### 14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions +of the GNU Affero General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever +published by the Free Software Foundation. + +If the Program specifies that a proxy can decide which future versions +of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +#### 15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. 
SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + +#### 16. Limitation of Liability. + +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR +CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES +ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT +NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR +LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM +TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER +PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +#### 17. Interpretation of Sections 15 and 16. + +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + +REALESRGAN/ REALESRGAN-NCNN-VULKAN: +MIT License (MIT) + +Copyright (c) 2021 Xintao Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------ +The following is the License of realsr-ncnn-vulkan + +The MIT License (MIT) + +Copyright (c) 2019 nihui + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/extensions/deforum/README.md b/extensions/deforum/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f94beee68a7475e89fe4f1eca62b8567320ed473 --- /dev/null +++ b/extensions/deforum/README.md @@ -0,0 +1,81 @@ + +# Deforum Stable Diffusion — official extension for AUTOMATIC1111's webui + +

+ Last Commit + GitHub issues + GitHub stars + GitHub forks + +

+ +## Before Starting + +**Important note about versions updating:**
+As auto's webui gets updated multiple times a day, every day, things tend to break with regard to extension compatibility. +Therefore, we recommend keeping two folders: +1. A "Stable" folder that you don't update regularly, with versions that you know *work* together (we will provide info on this soon). +2. An "Experimental" folder in which you can add 'git pull' to your webui-user.bat, update Deforum every day, etc. Keep it wild - but be prepared for bugs. + + +## Getting Started + +1. Install [AUTOMATIC1111's webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/).
If the repo link doesn't work, please use the alternate official download source: [https://gitgud.io/AUTOMATIC1111/stable-diffusion-webui](https://gitgud.io/AUTOMATIC1111/stable-diffusion-webui). To change your existing webui's installation origin, execute `git remote set-url origin https://gitgud.io/AUTOMATIC1111/stable-diffusion-webui` in the webui starting folder. + +2. There are two ways: either clone the repo into the `extensions` directory via the git command line, launched from within the `stable-diffusion-webui` folder: + +```sh +git clone https://github.com/deforum-art/deforum-for-automatic1111-webui extensions/deforum +``` + +Or download this repository, locate the `extensions` folder within your WebUI installation, create a folder named `deforum` and put the contents of the downloaded directory inside it. Then restart WebUI. **Warning: the extension folder has to be named 'deforum' or 'deforum-for-automatic1111-webui', otherwise it will fail to locate the 3D modules, as the PATH addition is hardcoded** + +3. Open the webui and find the Deforum tab at the top of the page. + +4. Enter the animation settings. Refer to [this general guide](https://docs.google.com/document/d/1pEobUknMFMkn8F5TMsv8qRzamXX_75BShMMXV8IFslI/edit) and [this guide to math keyframing functions in Deforum](https://docs.google.com/document/d/1pfW1PwbDIuW0cv-dnuyYj1UzPqe23BlSLTJsqazffXM/edit?usp=sharing). However, **in this version, prompt weights less than zero don't work like they do in the original Deforum!** Split the positive and the negative prompt in the JSON section using the --neg argument, like this: "apple:\`where(cos(t)>=0, cos(t), 0)\`, snow --neg strawberry:\`where(cos(t)<0, -cos(t), 0)\`" + +5. To view animation frames as they're being made, without waiting for the animation to complete, go to the 'Settings' tab and set the value of the setting shown below **above zero**. Warning: it may slow down the generation process. If the 'Do exactly the amount of steps the slider specifies' checkbox is selected in that tab, unselect it, as it won't allow you to use Deforum schedules and you will get abrupt frame changes without transitions. Then click 'Apply settings' at the top of the page. Now return to the 'Deforum' tab. + +![adsdasunknown](https://user-images.githubusercontent.com/14872007/196064311-1b79866a-e55b-438a-84a7-004ff30829ad.png) + + +6. Run the script and see if you got it working, or got anything at all. **In 3D mode, a large delay is expected at first** as the script loads the depth models. With the default settings, the whole thing should consume 6.4 GB of VRAM at 3D mode peaks, and no more than 3.8 GB of VRAM in 3D mode if you launch the webui with the '--lowvram' command line argument. + +7. After the generation process is complete, click the button with the self-describing name to show the video or gif result right in the GUI! + +8. Join our Discord, where you can post generated stuff, ask questions and more: https://discord.gg/deforum.
+* There's also the 'Issues' tab in the repo, for well... reporting issues ;) + +9. Profit! + +## Known issues + +* This port is not fully backward-compatible with the notebook and the local version both due to the changes in how AUTOMATIC1111's webui handles Stable Diffusion models and the changes in this script to get it to work in the new environment. *Expect* that you may not get exactly the same result or that the thing may break down because of the older settings. + +## Screenshots + +https://user-images.githubusercontent.com/121192995/215522284-d6fbedd5-09e2-4d2c-bd10-f9bbb4a20f82.mp4 + +Main extension tab: + +![maintab](https://user-images.githubusercontent.com/121192995/215362176-4e5599c1-9cb6-4bf9-964d-0ff882661993.png) + +Keyframes tab: + +![keyframes](https://user-images.githubusercontent.com/121192995/215362228-c239c43a-d565-4862-b490-d18b19eaaaa5.png) + +Math evaluation: + +![math-eval](https://user-images.githubusercontent.com/121192995/215362467-481127a4-247a-4b0d-924a-d10719aa4c01.png) + + +## Benchmarks + +3D mode without additional WebUI flags + +![image](https://user-images.githubusercontent.com/14872007/196294447-7817f138-ec4b-4001-885f-454f8667100d.png) + +3D mode when WebUI is launched with '--lowvram' + +![image](https://user-images.githubusercontent.com/14872007/196294517-125fbb27-c06d-4c4b-bcbc-7c743103eff6.png) + diff --git a/extensions/deforum/install.py b/extensions/deforum/install.py new file mode 100644 index 0000000000000000000000000000000000000000..b9166e71c44972d8582836239636d0f483a51ff5 --- /dev/null +++ b/extensions/deforum/install.py @@ -0,0 +1,14 @@ +import launch +import os +import sys + +req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") + +with open(req_file) as file: + for lib in file: + lib = lib.strip() + if not launch.is_installed(lib): + if lib == 'rich': + launch.run(f'"{sys.executable}" -m pip install {lib}', desc=f"Installing Deforum requirement: {lib}", errdesc=f"Couldn't install {lib}") + else: + launch.run_pip(f"install {lib}", f"Deforum requirement: {lib}") diff --git a/extensions/deforum/javascript/deforum-hints.js b/extensions/deforum/javascript/deforum-hints.js new file mode 100644 index 0000000000000000000000000000000000000000..bc50ffc016ee93cd88050b7e4d0fbd50f3c96718 --- /dev/null +++ b/extensions/deforum/javascript/deforum-hints.js @@ -0,0 +1,191 @@ +// mouseover tooltips for various UI elements + +deforum_titles = { + //Run + "Override settings": "specify a custom settings file and ignore settings displayed in the interface", + "Custom settings file": "the path to a custom settings file", + "Width": "The width of the output images, in pixels (must be a multiple of 64)", + "Height": "The height of the output images, in pixels (must be a multiple of 64)", + "Restore faces": "Restore low quality faces using GFPGAN neural network", + "Tiling": "Produce an image that can be tiled.", + "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", + "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", + "Sampler": "Which algorithm to use to produce the image", + "Enable extras": "enable additional seed settings", + "Subseed": "Seed of a different picture to be mixed into the generation.", + "Subseed strength": "How strong of a variation to produce. At 0, there will be no effect. 
At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).", + "Resize seed from width": "Normally, changing the resolution will completely change an image, even when using the same seed. If you generated an image with a particular seed and then changed the resolution, put the original resolution here to get an image that more closely resemles the original", + "Resize seed from height": "Normally, changing the resolution will completely change an image, even when using the same seed. If you generated an image with a particular seed and then changed the resolution, put the original resolution here to get an image that more closely resemles the original", + "Steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results", + //"ddim_eta": ""; + //"n_batch": "", + //"make_grid": "", + //"grid_rows": "", + //"save_settings": "", + //"save_samples": "", + "Batch name": "output images will be placed in a folder with this name, inside of the img2img output folder", + "Pix2Pix img CFG schedule": "*Only in use with pix2pix checkpoints!*", + "Filename format": "specify the format of the filename for output images", + "Seed behavior": "defines the seed behavior that is used for animations", + "iter": "the seed value will increment by 1 for each subsequent frame of the animation", + "fixed": "the seed will remain fixed across all frames of animation", + "random": "a random seed will be used on each frame of the animation", + "schedule": "specify your own seed schedule (found on the Keyframes page)", + + //Keyframes + "Animation mode": "selects the type of animation", + "2D": "only 2D motion parameters will be used, but this mode uses the least amount of VRAM. You can optionally enable flip_2d_perspective to enable some psuedo-3d animation parameters while in 2D mode.", + "3D": "enables all 3D motion parameters.", + "Video Input": "will ignore all motion parameters and attempt to reference a video loaded into the runtime, specified by the video_init_path. Max_frames is ignored during video_input mode, and instead, follows the number of frames pulled from the video’s length. Resume_from_timestring is NOT available with Video_Input mode.", + "Max frames": "the maximum number of output images to be created", + "Border": "controls handling method of pixels to be generated when the image is smaller than the frame.", + "wrap": "pulls pixels from the opposite edge of the image", + "replicate": "repeats the edge of the pixels, and extends them. Animations with quick motion may yield lines where this border function was attempting to populate pixels into the empty space created.", + "Angle": "2D operator to rotate canvas clockwise/anticlockwise in degrees per frame", + "Zoom": "2D operator that scales the canvas size, multiplicatively. 
[static = 1.0]", + "Translation X": "2D & 3D operator to move canvas left/right in pixels per frame", + "Translation Y": "2D & 3D operator to move canvas up/down in pixels per frame", + "Translation Z": "3D operator to move canvas towards/away from view [speed set by FOV]", + "Rotation 3D X": "3D operator to tilt canvas up/down in degrees per frame", + "Rotation 3D Y": "3D operator to pan canvas left/right in degrees per frame", + "Rotation 3D Z": "3D operator to roll canvas clockwise/anticlockwise", + "Enable perspective flip": "enables 2D mode functions to simulate faux 3D movement", + "Perspective flip theta": "the roll effect angle", + "Perspective flip phi": "the tilt effect angle", + "Perspective flip gamma": "the pan effect angle", + "Perspective flip fv": "the 2D vanishing point of perspective (recommended range 30-160)", + "Noise schedule": "amount of graininess to add per frame for diffusion diversity", + "Strength schedule": "amount of presence of previous frame to influence next frame, also controls steps in the following formula [steps - (strength_schedule * steps)]", + "Sampler schedule": "controls which sampler to use at a specific scheduled frame", + "Contrast schedule": "adjusts the overall contrast per frame [default neutral at 1.0]", + "CFG scale schedule": "how closely the image should conform to the prompt. Lower values produce more creative results. (recommended range 5-15)", + "FOV schedule": "adjusts the scale at which the canvas is moved in 3D by the translation_z value. [maximum range -180 to +180, with 0 being undefined. Values closer to 180 will make the image have less depth, while values closer to 0 will allow more depth]", + //"near_schedule": "", + //"far_schedule": "", + "Seed schedule": "allows you to specify seeds at a specific schedule, if seed_behavior is set to schedule.", + "Color coherence": "The color coherence will attempt to sample the overall pixel color information, and trend those values analyzed in the first frame to be applied to future frames.", + // "None": "Disable color coherence", + "Match Frame 0 HSV": "HSV is a good method for balancing presence of vibrant colors, but may produce unrealistic results - (ie.blue apples)", + "Match Frame 0 LAB": "LAB is a more linear approach to mimic human perception of color space - a good default setting for most users.", + "Match Frame 0 RGB": "RGB is good for enforcing unbiased amounts of color in each red, green and blue channel - some images may yield colorized artifacts if sampling is too low.", + "Cadence": "A setting of 1 will cause every frame to receive diffusion in the sequence of image outputs. A setting of 2 will only diffuse on every other frame, yet motion will still be in effect. The output of images during the cadence sequence will be automatically blended, additively and saved to the specified drive. This may improve the illusion of coherence in some workflows as the content and context of an image will not change or diffuse during frames that were skipped. Higher values of 4-8 cadence will skip over a larger amount of frames and only diffuse the “Nth” frame as set by the diffusion_cadence value. This may produce more continuity in an animation, at the cost of little opportunity to add more diffused content. In extreme examples, motion within a frame will fail to produce diverse prompt context, and the space will be filled with lines or approximations of content - resulting in unexpected animation patterns and artifacts. 
Video Input & Interpolation modes are not affected by diffusion_cadence.", + "Noise type": "Selects the type of noise being added to each frame", + "uniform": "Uniform noise covers the entire frame. It somewhat flattens and sharpens the video over time, but may be good for a cartoonish look. This is the old default setting.", + "perlin": "Perlin noise is a more natural looking noise. It is heterogeneous and less sharp than uniform noise, so new details are more likely to appear in a coherent way. This is the new default setting.", + "Perlin W": "The width of the Perlin sample. Lower values will make larger noise regions. Think of it as inverse brush stroke width. The greater this setting, the smaller the details it will affect.", + "Perlin H": "The height of the Perlin sample. Lower values will make larger noise regions. Think of it as inverse brush stroke width. The greater this setting, the smaller the details it will affect.", + "Perlin octaves": "The number of Perlin noise octaves, that is the count of P-noise iterations. Higher values will make the noise softer and more smoke-like, whereas lower values will make it look more organic and spotty. It is limited to 8 octaves, as the resulting gain would run out of bounds.", + "Perlin persistence": "How much noise from each octave is added on each iteration. Higher values will make it straighter and sharper, while lower values will make it rounder and smoother. It is limited to 1.0, as the resulting gain would fill the frame completely with noise.", + "Use depth warping": "enables instructions to warp an image dynamically in 3D mode only.", + "MiDaS weight": "sets a midpoint at which a depthmap is to be drawn: range [-1 to +1]", + "Padding mode": "instructs the handling of pixels outside the field of view as they come into the scene.", + //"border": "Border will attempt to use the edges of the canvas as the pixels to be drawn", //duplicate name as another property + "reflection": "reflection will attempt to approximate the image and tile/repeat pixels", + "zeros": "zeros will not add any new pixel information", + "sampling_mode": "choose from Bicubic, Bilinear or Nearest modes. (Recommended: Bicubic)", + "Save depth maps": "will output a greyscale depth map image alongside the output images.", + + // Prompts + "Prompts": "prompts for your animation in a JSON format. Use --neg words to add 'words' as a negative prompt", + "Prompts positive": "positive prompt to be appended to *all* prompts", + "Prompts negative": "negative prompt to be appended to *all* prompts. DON'T use --neg here!", + + //Init + "Use init": "Diffuse the first frame based on an image, similar to img2img.", + "Strength": "Controls the strength of the diffusion on the init image. 0 = disabled", + "Strength 0 no init": "Set the strength to 0 automatically when no init image is used", + "Init image": "the path to your init image", + "Use mask": "Use a grayscale image as a mask on your init image. Whiter areas of the mask are areas that change more.", + "Use alpha as mask": "use the alpha channel of the init image as the mask", + "Mask file": "the path to your mask image", + "Invert mask": "Inverts the colors of the mask", + "Mask brightness adjust": "adjust the brightness of the mask. Should be a positive number, with 1.0 meaning no adjustment.", + "Mask contrast adjust": "adjust the contrast of the mask. 
Should be a positive number, with 1.0 meaning no adjustment.", + "overlay mask": "Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding", + "Mask overlay blur": "Blur edges of final overlay mask, if used. Minimum = 0 (no blur)", + "Video init path": "the directory \/ URL at which your video file is located for Video Input mode only", + "Extract nth frame": "during the run sequence, only frames specified by this value will be extracted, saved, and diffused upon. A value of 1 indicates that every frame is to be accounted for. Values of 2 will use every other frame for the sequence. Higher values will skip that number of frames respectively.", + "Extract from frame":"start extracting the input video only from this frame number", + "Extract to frame": "stop the extraction of the video at this frame number. -1 for no limits", + "Overwrite extracted frames": "when enabled, will re-extract video frames each run. When using video_input mode, the run will be instructed to write video frames to the drive. If you’ve already populated the frames needed, uncheck this box to skip past redundant extraction, and immediately start the render. If you have not extracted frames, you must run at least once with this box checked to write the necessary frames.", + "Use mask video": "video_input mode only, enables the extraction and use of a separate video file intended for use as a mask. White areas of the extracted video frames will not be affected by diffusion, while black areas will be fully affected. Lighter/darker areas are affected dynamically.", + "Video mask path": "the directory in which your mask video is located.", + "Interpolate key frames": "selects whether to ignore prompt schedule or _x_frames.", + "Interpolate x frames": "the number of frames to transition through between prompts (when interpolate_key_frames = true, then the numbers in front of the animation prompts will dynamically guide the images based on their value. If set to false, it will ignore the prompt numbers and force the interpolate_x_frames value regardless of prompt number)", + "Resume from timestring": "instructs the run to start from a specified point", + "Resume timestring": "the required timestamp to reference when resuming. Currently only available in 2D & 3D mode; the timestamp is saved as the settings .txt file name as well as the images produced during your previous run. The format follows: yyyymmddhhmmss - a timestamp of when the run was started to diffuse.", + + //Video Output + "Skip video for run all": "when checked, do not output a video", + "Make GIF": "create a GIF in addition to the .mp4 file. Supports up to 30 fps, will self-disable at higher fps values", + "Upscale":"upscale the images of the next run once it's finished + make a video out of them", + "Upscale model":"model of the upscaler to use. 'realesr-animevideov3' is much faster but yields smoother, less detailed results. The other models only do x4", + "Upscale factor":"how many times to upscale, actual options depend on the chosen upscale model", + "FPS": "The frames per second that the video will run at", + "Output format": "select the type of video file to output", + "PIL gif": "create an animated GIF", + "FFMPEG mp4": "create an MP4 video file", + "FFmpeg location": "the path to where ffmpeg is located. Leave at the default 'ffmpeg' if ffmpeg is in your PATH!", + "FFmpeg crf": "controls quality where lower is better, less compressed. 
values: 0 to 51, default 17", + "FFmpeg preset": "controls how good the compression is, and the operation speed. If you're not in a rush, keep it at 'veryslow'", + "Add soundtrack": "when this box is checked, and FFMPEG mp4 is selected as the output format, an audio file will be multiplexed with the video.", + "Soundtrack path": "the path\/ URL to an audio file to accompany the video", + "Use manual settings": "when this is unchecked, the video will automatically be created in the same output folder as the images. Check this box to specify different settings for the creation of the video, specified by the following options", + "Render steps": "render each step of diffusion as a separate frame", + "Max video frames": "the maximum number of frames to include in the video, when use_manual_settings is checked", + //"path_name_modifier": "", + "Image path": "the location of images to create the video from, when use_manual_settings is checked", + "MP4 path": "the output location of the mp4 file, when use_manual_settings is checked", + "Engine": "choose the frame interpolation engine and version", + "Interp X":"how many times to interpolate the source video. e.g. a source video fps of 12 and a value of x2 will yield a 24fps interpolated video", + "Slow-Mo X":"how many times to slow down the video. *Naturally affects the output fps as well", + "Keep Imgs": "delete or keep the raw affected (interpolated/ upscaled depending on the UI section) png imgs", + "Interpolate an existing video":"This feature allows you to interpolate any video with a dedicated button. The video can be completely unrelated to Deforum", + "In Frame Count": "uploaded video total frame count", + "In FPS":"uploaded video FPS", + "Interpolated Vid FPS":"calculated output-interpolated video FPS", + "In Res":"uploaded video resolution", + "Out Res":"output video resolution", + + // Looper Args + // "use_looper": "", + "Enable guided images mode": "check this box to enable guided images mode", + "Images to use for keyframe guidance": "images you iterate over; you can use local or web paths (no single backslashes!)", + "Image strength schedule": "how much the image should look like the previous one and the new image frame init. 
strength schedule might be better if this is higher, around 0.75 during the keyframes you want to switch on", + "Blend factor max": "blendFactor = blendFactorMax - blendFactorSlope * cos((frame % tweening_frames_schedule) / (tweening_frames_schedule / 2))", + "Blend factor slope": "blendFactor = blendFactorMax - blendFactorSlope * cos((frame % tweening_frames_schedule) / (tweening_frames_schedule / 2))", + "Tweening frames schedule": "the number of frames over which to blend between the current imagined image and the input frame image", + "Color correction factor": "how close to get to the colors of the input frame image / how much each frame during a tweening step uses the new image's colors" +} + + +onUiUpdate(function(){ + gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){ + let tooltip = deforum_titles[span.textContent]; + + if(!tooltip){ + tooltip = deforum_titles[span.value]; + } + + if(!tooltip){ + for (const c of span.classList) { + if (c in deforum_titles) { + tooltip = deforum_titles[c]; + break; + } + } + } + + if(tooltip){ + span.title = tooltip; + } + }) + + gradioApp().querySelectorAll('select').forEach(function(select){ + if (select.onchange != null) return; + + select.onchange = function(){ + select.title = deforum_titles[select.value] || ""; + } + }) +}) \ No newline at end of file diff --git a/extensions/deforum/javascript/deforum.js b/extensions/deforum/javascript/deforum.js new file mode 100644 index 0000000000000000000000000000000000000000..889e76f39c937d9eb3b602aa5510d8ac637f5e3c --- /dev/null +++ b/extensions/deforum/javascript/deforum.js @@ -0,0 +1,21 @@ +function submit_deforum(){ + // alert('Hello, Deforum!') + rememberGallerySelection('deforum_gallery') + showSubmitButtons('deforum', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('deforum_gallery_container'), gradioApp().getElementById('deforum_gallery'), function(){ + showSubmitButtons('deforum', true) + }) + + var res = create_submit_args(arguments) + + res[0] = id + // res[1] = get_tab_index('deforum') + + return res +} + +onUiUpdate(function(){ + check_gallery('deforum_gallery') +}) diff --git a/extensions/deforum/requirements.txt b/extensions/deforum/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c9d5f4ea9c75368bd75aa3804a7b9aea47254b2 --- /dev/null +++ b/extensions/deforum/requirements.txt @@ -0,0 +1,7 @@ +numexpr +matplotlib +pandas +av +pims +imageio_ffmpeg +rich \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum.py b/extensions/deforum/scripts/deforum.py new file mode 100644 index 0000000000000000000000000000000000000000..50c188c2475f58572540f615d957f8d28d3f019d --- /dev/null +++ b/extensions/deforum/scripts/deforum.py @@ -0,0 +1,318 @@ +# Detach 'deforum_helpers' from 'scripts' to prevent the "No module named 'scripts.deforum_helpers'" error, +# which can cause Deforum's tab not to show up in some cases when the environment has been broken by webui package updates +import sys, os, shutil + +basedirs = [os.getcwd()] +if 'google.colab' in sys.modules: + basedirs.append('/content/gdrive/MyDrive/sd/stable-diffusion-webui') # hardcoded, as TheLastBen's colab seems to be the primary source + +for basedir in basedirs: + deforum_paths_to_ensure = [basedir + '/extensions/deforum-for-automatic1111-webui/scripts', basedir + '/extensions/sd-webui-controlnet', basedir + '/extensions/deforum/scripts', basedir + '/scripts/deforum_helpers/src', basedir + '/extensions/deforum/scripts/deforum_helpers/src', basedir 
+'/extensions/deforum-for-automatic1111-webui/scripts/deforum_helpers/src',basedir] + + for deforum_scripts_path_fix in deforum_paths_to_ensure: + if not deforum_scripts_path_fix in sys.path: + sys.path.extend([deforum_scripts_path_fix]) + +# Main deforum stuff +import deforum_helpers.args as deforum_args +import deforum_helpers.settings as deforum_settings +from deforum_helpers.save_images import dump_frames_cache, reset_frames_cache +from deforum_helpers.frame_interpolation import process_video_interpolation + +import modules.scripts as wscripts +from modules import script_callbacks +import gradio as gr +import json + +from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images +from PIL import Image +from deforum_helpers.video_audio_utilities import ffmpeg_stitch_video, make_gifski_gif +from deforum_helpers.upscaling import make_upscale_v2 +import gc +import torch +from webui import wrap_gradio_gpu_call +import modules.shared as shared +from modules.shared import opts, cmd_opts, state +from modules.ui import create_output_panel, plaintext_to_html, wrap_gradio_call +from types import SimpleNamespace + +def run_deforum(*args, **kwargs): + args_dict = {deforum_args.component_names[i]: args[i+2] for i in range(0, len(deforum_args.component_names))} + p = StableDiffusionProcessingImg2Img( + sd_model=shared.sd_model, + outpath_samples = opts.outdir_samples or opts.outdir_img2img_samples, + outpath_grids = opts.outdir_grids or opts.outdir_img2img_grids, + #we'll setup the rest later + ) + + print("\033[4;33mDeforum extension for auto1111 webui, v2.2b\033[0m") + args_dict['self'] = None + args_dict['p'] = p + + root, args, anim_args, video_args, parseq_args, loop_args, controlnet_args = deforum_args.process_args(args_dict) + root.clipseg_model = None + root.initial_clipskip = opts.data["CLIP_stop_at_last_layers"] + root.basedirs = basedirs + + for basedir in basedirs: + sys.path.extend([ + basedir + '/scripts/deforum_helpers/src', + basedir + '/extensions/deforum/scripts/deforum_helpers/src', + basedir + '/extensions/deforum-for-automatic1111-webui/scripts/deforum_helpers/src', + ]) + + # clean up unused memory + reset_frames_cache(root) + gc.collect() + torch.cuda.empty_cache() + + from deforum_helpers.render import render_animation + from deforum_helpers.render_modes import render_input_video, render_animation_with_video_mask, render_interpolation + + tqdm_backup = shared.total_tqdm + shared.total_tqdm = deforum_settings.DeforumTQDM(args, anim_args, parseq_args) + try: + # dispatch to appropriate renderer + if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D': + if anim_args.use_mask_video: + render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root.animation_prompts, root) # allow mask video without an input video + else: + render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root.animation_prompts, root) + elif anim_args.animation_mode == 'Video Input': + render_input_video(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root.animation_prompts, root)#TODO: prettify code + elif anim_args.animation_mode == 'Interpolation': + render_interpolation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root.animation_prompts, root) + else: + print('Other modes are not available yet!') + finally: + shared.total_tqdm = tqdm_backup + opts.data["CLIP_stop_at_last_layers"] = root.initial_clipskip + + if 
video_args.store_frames_in_ram: + dump_frames_cache(root) + + from base64 import b64encode + + real_audio_track = None + if video_args.add_soundtrack != 'None': + real_audio_track = anim_args.video_init_path if video_args.add_soundtrack == 'Init Video' else video_args.soundtrack_path + + # Delete folder with duplicated imgs from OS temp folder + shutil.rmtree(root.tmp_deforum_run_duplicated_folder, ignore_errors=True) + + # Decide whether or not we need to try and frame interpolate laters + need_to_frame_interpolate = False + if video_args.frame_interpolation_engine != "None" and not video_args.skip_video_for_run_all and not video_args.store_frames_in_ram: + need_to_frame_interpolate = True + + if video_args.skip_video_for_run_all: + print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it') + else: + import subprocess + + path_name_modifier = video_args.path_name_modifier + if video_args.render_steps: # render steps from a single image + fname = f"{path_name_modifier}_%05d.png" + all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))] + newest_dir = max(all_step_dirs, key=os.path.getmtime) + image_path = os.path.join(newest_dir, fname) + print(f"Reading images from {image_path}") + mp4_path = os.path.join(newest_dir, f"{args.timestring}_{path_name_modifier}.mp4") + max_video_frames = args.steps + else: # render images for a video + image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png") + mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4") + max_video_frames = anim_args.max_frames + + exclude_keys = deforum_settings.get_keys_to_exclude('video') + video_settings_filename = os.path.join(args.outdir, f"{args.timestring}_video-settings.txt") + with open(video_settings_filename, "w+", encoding="utf-8") as f: + s = {} + for key, value in dict(video_args.__dict__).items(): + if key not in exclude_keys: + s[key] = value + json.dump(s, f, ensure_ascii=False, indent=4) + + # Stitch video using ffmpeg! + try: + ffmpeg_stitch_video(ffmpeg_location=video_args.ffmpeg_location, fps=video_args.fps, outmp4_path=mp4_path, stitch_from_frame=0, stitch_to_frame=max_video_frames, imgs_path=image_path, add_soundtrack=video_args.add_soundtrack, audio_path=real_audio_track, crf=video_args.ffmpeg_crf, preset=video_args.ffmpeg_preset) + mp4 = open(mp4_path,'rb').read() + data_url = "data:video/mp4;base64," + b64encode(mp4).decode() + deforum_args.i1_store = f'

Deforum v0.5-webui-beta

' + except Exception as e: + if need_to_frame_interpolate: + print(f"FFMPEG DID NOT STITCH ANY VIDEO. However, you requested to frame interpolate - so we will continue to frame interpolation, but you'll be left only with the interpolated frames and not a video, since ffmpeg couldn't run. Original ffmpeg error: {e}") + else: + print(f"** FFMPEG DID NOT STITCH ANY VIDEO ** Error: {e}") + pass + + if root.initial_info is None: + root.initial_info = "An error has occured and nothing has been generated!" + root.initial_info += "\nPlease, report the bug to https://github.com/deforum-art/deforum-for-automatic1111-webui/issues" + import numpy as np + a = np.random.rand(args.W, args.H, 3)*255 + root.first_frame = Image.fromarray(a.astype('uint8')).convert('RGB') + root.initial_seed = 6934 + # FRAME INTERPOLATION TIME + if need_to_frame_interpolate: + print(f"Got a request to *frame interpolate* using {video_args.frame_interpolation_engine}") + process_video_interpolation(frame_interpolation_engine=video_args.frame_interpolation_engine, frame_interpolation_x_amount=video_args.frame_interpolation_x_amount,frame_interpolation_slow_mo_enabled=video_args.frame_interpolation_slow_mo_enabled, frame_interpolation_slow_mo_amount=video_args.frame_interpolation_slow_mo_amount, orig_vid_fps=video_args.fps, deforum_models_path=root.models_path, real_audio_track=real_audio_track, raw_output_imgs_path=args.outdir, img_batch_id=args.timestring, ffmpeg_location=video_args.ffmpeg_location, ffmpeg_crf=video_args.ffmpeg_crf, ffmpeg_preset=video_args.ffmpeg_preset, keep_interp_imgs=video_args.frame_interpolation_keep_imgs, orig_vid_name=None, resolution=None) + + if video_args.make_gif and not video_args.skip_video_for_run_all and not video_args.store_frames_in_ram: + make_gifski_gif(imgs_raw_path = args.outdir, imgs_batch_id = args.timestring, fps = video_args.fps, models_folder = root.models_path, current_user_os = root.current_user_os) + + # Upscale video once generation is done: + if video_args.r_upscale_video and not video_args.skip_video_for_run_all and not video_args.store_frames_in_ram: + + # out mp4 path is defined in make_upscale func + make_upscale_v2(upscale_factor = video_args.r_upscale_factor, upscale_model = video_args.r_upscale_model, keep_imgs = video_args.r_upscale_keep_imgs, imgs_raw_path = args.outdir, imgs_batch_id = args.timestring, fps = video_args.fps, deforum_models_path = root.models_path, current_user_os = root.current_user_os, ffmpeg_location=video_args.ffmpeg_location, stitch_from_frame=0, stitch_to_frame=max_video_frames, ffmpeg_crf=video_args.ffmpeg_crf, ffmpeg_preset=video_args.ffmpeg_preset, add_soundtrack = video_args.add_soundtrack ,audio_path=real_audio_track) + + root.initial_info += "\n The animation is stored in " + args.outdir + root.initial_info += "\n Timestring = " + args.timestring + '\n' + root.initial_info += "Only the first frame is shown in webui not to clutter the memory" + reset_frames_cache(root) # cleanup the RAM in any case + processed = Processed(p, [root.first_frame], root.initial_seed, root.initial_info) + + if processed is None: + processed = process_images(p) + + shared.total_tqdm.clear() + + generation_info_js = processed.js() + if opts.samples_log_stdout: + print(generation_info_js) + + if opts.do_not_show_images: + processed.images = [] + + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html('') + +def on_ui_tabs(): + with gr.Blocks(analytics_enabled=False) as deforum_interface: + components = {} + dummy_component 
= gr.Label(visible=False) + with gr.Row(elem_id='deforum_progress_row').style(equal_height=False): + with gr.Column(scale=1, variant='panel'): + components = deforum_args.setup_deforum_setting_dictionary(None, True, True) + + with gr.Column(scale=1): + with gr.Row(): + btn = gr.Button("Click here after the generation to show the video") + components['btn'] = btn + close_btn = gr.Button("Close the video", visible=False) + with gr.Row(): + i1 = gr.HTML(deforum_args.i1_store, elem_id='deforum_header') + components['i1'] = i1 + # Show video + def show_vid(): + return { + i1: gr.update(value=deforum_args.i1_store, visible=True), + close_btn: gr.update(visible=True), + btn: gr.update(value="Update the video", visible=True), + } + + btn.click( + show_vid, + [], + [i1, close_btn, btn], + ) + # Close video + def close_vid(): + return { + i1: gr.update(value=deforum_args.i1_store_backup, visible=True), + close_btn: gr.update(visible=False), + btn: gr.update(value="Click here after the generation to show the video", visible=True), + } + + close_btn.click( + close_vid, + [], + [i1, close_btn, btn], + ) + id_part = 'deforum' + with gr.Row(elem_id=f"{id_part}_generate_box"): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip", visible=False) + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", visible=True) + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: state.interrupt(), + inputs=[], + outputs=[], + ) + + deforum_gallery, generation_info, html_info, html_log = create_output_panel("deforum", opts.outdir_img2img_samples) + + gr.HTML("

* Paths can be relative to webui folder OR full - absolute

") + with gr.Row(): + settings_path = gr.Textbox("deforum_settings.txt", elem_id='deforum_settings_path', label="General Settings File") + #reuse_latest_settings_btn = gr.Button('Reuse Latest', elem_id='deforum_reuse_latest_settings_btn')#TODO + with gr.Row(): + save_settings_btn = gr.Button('Save Settings', elem_id='deforum_save_settings_btn') + load_settings_btn = gr.Button('Load Settings', elem_id='deforum_load_settings_btn') + with gr.Row(): + video_settings_path = gr.Textbox("deforum_video-settings.txt", elem_id='deforum_video_settings_path', label="Video Settings File") + #reuse_latest_video_settings_btn = gr.Button('Reuse Latest', elem_id='deforum_reuse_latest_video_settings_btn')#TODO + with gr.Row(): + save_video_settings_btn = gr.Button('Save Video Settings', elem_id='deforum_save_video_settings_btn') + load_video_settings_btn = gr.Button('Load Video Settings', elem_id='deforum_load_video_settings_btn') + + # components['prompts'].visible = False#hide prompts for the time being + #TODO clean up the code + components['save_sample_per_step'].visible = False + components['show_sample_per_step'].visible = False + components['display_samples'].visible = False + + component_list = [components[name] for name in deforum_args.component_names] + + submit.click( + fn=wrap_gradio_gpu_call(run_deforum, extra_outputs=[None, '', '']), + _js="submit_deforum", + inputs=[dummy_component, dummy_component] + component_list, + outputs=[ + deforum_gallery, + generation_info, + html_info, + html_log, + ], + ) + + settings_component_list = [components[name] for name in deforum_args.settings_component_names] + video_settings_component_list = [components[name] for name in deforum_args.video_args_names] + stuff = gr.HTML("") # wrap gradio call garbage + stuff.visible = False + + save_settings_btn.click( + fn=wrap_gradio_call(deforum_settings.save_settings), + inputs=[settings_path] + settings_component_list, + outputs=[stuff], + ) + + load_settings_btn.click( + fn=wrap_gradio_call(deforum_settings.load_settings), + inputs=[settings_path]+ settings_component_list, + outputs=settings_component_list + [stuff], + ) + + save_video_settings_btn.click( + fn=wrap_gradio_call(deforum_settings.save_video_settings), + inputs=[video_settings_path] + video_settings_component_list, + outputs=[stuff], + ) + + load_video_settings_btn.click( + fn=wrap_gradio_call(deforum_settings.load_video_settings), + inputs=[video_settings_path] + video_settings_component_list, + outputs=video_settings_component_list + [stuff], + ) + + + return [(deforum_interface, "Deforum", "deforum_interface")] + +script_callbacks.on_ui_tabs(on_ui_tabs) diff --git a/extensions/deforum/scripts/deforum_helpers/__init__.py b/extensions/deforum/scripts/deforum_helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe08549e9c92f0e8a7bc96cead3277dedd12583c --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/__init__.py @@ -0,0 +1,7 @@ +""" +from .save_images import save_samples, get_output_folder +from .depth import DepthModel +from .prompt import sanitize +from .animation import construct_RotationMatrixHomogenous, getRotationMatrixManual, getPoints_for_PerspectiveTranformEstimation, warpMatrix, anim_frame_warp_2d, anim_frame_warp_3d +from .generate import add_noise, load_img, load_mask_latent, prepare_mask +""" \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/animation.py b/extensions/deforum/scripts/deforum_helpers/animation.py new file mode 100644 index 
0000000000000000000000000000000000000000..ccb6dbd206b09ada3627e89398655964fdd72845 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/animation.py @@ -0,0 +1,258 @@ +import numpy as np +import cv2 +from functools import reduce +import math +import py3d_tools as p3d +import torch +from einops import rearrange +from .prompt import check_is_number + +# Webui +from modules.shared import state + +def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: + sample = ((sample.astype(float) / 255.0) * 2) - 1 + sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) + sample = torch.from_numpy(sample) + return sample + +def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray: + sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32) + sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) + sample_int8 = (sample_f32 * 255) + return sample_int8.astype(type) + +def construct_RotationMatrixHomogenous(rotation_angles): + assert(type(rotation_angles)==list and len(rotation_angles)==3) + RH = np.eye(4,4) + cv2.Rodrigues(np.array(rotation_angles), RH[0:3, 0:3]) + return RH + +# https://en.wikipedia.org/wiki/Rotation_matrix +def getRotationMatrixManual(rotation_angles): + + rotation_angles = [np.deg2rad(x) for x in rotation_angles] + + phi = rotation_angles[0] # around x + gamma = rotation_angles[1] # around y + theta = rotation_angles[2] # around z + + # X rotation + Rphi = np.eye(4,4) + sp = np.sin(phi) + cp = np.cos(phi) + Rphi[1,1] = cp + Rphi[2,2] = Rphi[1,1] + Rphi[1,2] = -sp + Rphi[2,1] = sp + + # Y rotation + Rgamma = np.eye(4,4) + sg = np.sin(gamma) + cg = np.cos(gamma) + Rgamma[0,0] = cg + Rgamma[2,2] = Rgamma[0,0] + Rgamma[0,2] = sg + Rgamma[2,0] = -sg + + # Z rotation (in-image-plane) + Rtheta = np.eye(4,4) + st = np.sin(theta) + ct = np.cos(theta) + Rtheta[0,0] = ct + Rtheta[1,1] = Rtheta[0,0] + Rtheta[0,1] = -st + Rtheta[1,0] = st + + R = reduce(lambda x,y : np.matmul(x,y), [Rphi, Rgamma, Rtheta]) + + return R + +def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength): + + ptsIn2D = ptsIn[0,:] + ptsOut2D = ptsOut[0,:] + ptsOut2Dlist = [] + ptsIn2Dlist = [] + + for i in range(0,4): + ptsOut2Dlist.append([ptsOut2D[i,0], ptsOut2D[i,1]]) + ptsIn2Dlist.append([ptsIn2D[i,0], ptsIn2D[i,1]]) + + pin = np.array(ptsIn2Dlist) + [W/2.,H/2.] + pout = (np.array(ptsOut2Dlist) + [1.,1.]) * (0.5*sidelength) + pin = pin.astype(np.float32) + pout = pout.astype(np.float32) + + return pin, pout + + +def warpMatrix(W, H, theta, phi, gamma, scale, fV): + + # M is to be estimated + M = np.eye(4, 4) + + fVhalf = np.deg2rad(fV/2.) + d = np.sqrt(W*W+H*H) + sideLength = scale*d/np.cos(fVhalf) + h = d/(2.0*np.sin(fVhalf)) + n = h-(d/2.0) + f = h+(d/2.0) + + # Translation along Z-axis by -h + T = np.eye(4,4) + T[2,3] = -h + + # Rotation matrices around x,y,z + R = getRotationMatrixManual([phi, gamma, theta]) + + + # Projection Matrix + P = np.eye(4,4) + P[0,0] = 1.0/np.tan(fVhalf) + P[1,1] = P[0,0] + P[2,2] = -(f+n)/(f-n) + P[2,3] = -(2.0*f*n)/(f-n) + P[3,2] = -1.0 + + # pythonic matrix multiplication + F = reduce(lambda x,y : np.matmul(x,y), [P, T, R]) + + # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way. + # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3); + ptsIn = np.array([[ + [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.] 
+ ]]) + ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype)) + ptsOut = cv2.perspectiveTransform(ptsIn, F) + + ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength) + + # check float32 otherwise OpenCV throws an error + assert(ptsInPt2f.dtype == np.float32) + assert(ptsOutPt2f.dtype == np.float32) + M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f) + + return M33, sideLength + +def get_flip_perspective_matrix(W, H, keys, frame_idx): + perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx] + perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx] + perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx] + perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx] + M,sl = warpMatrix(W, H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv); + post_trans_mat = np.float32([[1, 0, (W-sl)/2], [0, 1, (H-sl)/2]]) + post_trans_mat = np.vstack([post_trans_mat, [0,0,1]]) + bM = np.matmul(M, post_trans_mat) + return bM + +def flip_3d_perspective(anim_args, prev_img_cv2, keys, frame_idx): + W, H = (prev_img_cv2.shape[1], prev_img_cv2.shape[0]) + return cv2.warpPerspective( + prev_img_cv2, + get_flip_perspective_matrix(W, H, keys, frame_idx), + (W, H), + borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE + ) + +def anim_frame_warp(prev_img_cv2, args, anim_args, keys, frame_idx, depth_model=None, depth=None, device='cuda', half_precision = False): + + if anim_args.use_depth_warping: + if depth is None and depth_model is not None: + depth = depth_model.predict(prev_img_cv2, anim_args, half_precision) + else: + depth = None + + if anim_args.animation_mode == '2D': + prev_img = anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx) + else: # '3D' + prev_img = anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx) + + return prev_img, depth + +def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx): + angle = keys.angle_series[frame_idx] + zoom = keys.zoom_series[frame_idx] + translation_x = keys.translation_x_series[frame_idx] + translation_y = keys.translation_y_series[frame_idx] + + center = (args.W // 2, args.H // 2) + trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) + rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) + trans_mat = np.vstack([trans_mat, [0,0,1]]) + rot_mat = np.vstack([rot_mat, [0,0,1]]) + if anim_args.enable_perspective_flip: + bM = get_flip_perspective_matrix(args.W, args.H, keys, frame_idx) + rot_mat = np.matmul(bM, rot_mat, trans_mat) + else: + rot_mat = np.matmul(rot_mat, trans_mat) + return cv2.warpPerspective( + prev_img_cv2, + rot_mat, + (prev_img_cv2.shape[1], prev_img_cv2.shape[0]), + borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE + ) + +def anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx): + TRANSLATION_SCALE = 1.0/200.0 # matches Disco + translate_xyz = [ + -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE, + keys.translation_y_series[frame_idx] * TRANSLATION_SCALE, + -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE + ] + rotate_xyz = [ + math.radians(keys.rotation_3d_x_series[frame_idx]), + math.radians(keys.rotation_3d_y_series[frame_idx]), + math.radians(keys.rotation_3d_z_series[frame_idx]) + ] + if anim_args.enable_perspective_flip: + prev_img_cv2 = flip_3d_perspective(anim_args, prev_img_cv2, keys, frame_idx) + rot_mat = 
p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0) + result = transform_image_3d(device if not device.type.startswith('mps') else torch.device('cpu'), prev_img_cv2, depth, rot_mat, translate_xyz, anim_args, keys, frame_idx) + torch.cuda.empty_cache() + return result + +def transform_image_3d(device, prev_img_cv2, depth_tensor, rot_mat, translate, anim_args, keys, frame_idx): + # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + aspect_ratio = float(w)/float(h) + near = keys.near_series[frame_idx] + far = keys.far_series[frame_idx] + fov_deg = keys.fov_series[frame_idx] + persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device) + persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device) + + # range of [-1,1] is important to torch grid_sample's padding handling + y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device)) + if depth_tensor is None: + z = torch.ones_like(x) + else: + z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device) + xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1) + + xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + + offset_xy = xyz_new_cam_xy - xyz_old_cam_xy + # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation. + identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0) + # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs. 
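+    # In plain terms: affine_grid on the identity matrix yields the untouched [-1, 1] sampling
+    # lattice, and subtracting offset_xy (each point's newly projected position minus its old
+    # one) makes grid_sample read every output pixel from where that point sat in the previous
+    # frame, which is how the 3D camera motion gets applied to the 2D image.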
+ coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False) + offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0) + + image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device) + new_image = torch.nn.functional.grid_sample( + image_tensor.add(1/512 - 0.0001).unsqueeze(0), + offset_coords_2d, + mode=anim_args.sampling_mode, + padding_mode=anim_args.padding_mode, + align_corners=False + ) + + # convert back to cv2 style numpy array + result = rearrange( + new_image.squeeze().clamp(0,255), + 'c h w -> h w c' + ).cpu().numpy().astype(prev_img_cv2.dtype) + return result diff --git a/extensions/deforum/scripts/deforum_helpers/animation_key_frames.py b/extensions/deforum/scripts/deforum_helpers/animation_key_frames.py new file mode 100644 index 0000000000000000000000000000000000000000..824a0f6d8ea36d77a30cba71c0265ad44d4b0eec --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/animation_key_frames.py @@ -0,0 +1,106 @@ +import re +import numpy as np +import numexpr +import pandas as pd +from .prompt import check_is_number + +class DeformAnimKeys(): + def __init__(self, anim_args): + self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames) + self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames) + self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames) + self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames) + self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames) + self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames) + self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames) + self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames) + self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames) + self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames) + self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames) + self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames) + self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames) + self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames) + self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames) + self.cfg_scale_schedule_series = get_inbetweens(parse_key_frames(anim_args.cfg_scale_schedule), anim_args.max_frames) + self.pix2pix_img_cfg_scale_series = get_inbetweens(parse_key_frames(anim_args.pix2pix_img_cfg_scale_schedule), anim_args.max_frames) + self.subseed_schedule_series = get_inbetweens(parse_key_frames(anim_args.subseed_schedule), anim_args.max_frames) + self.subseed_strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.subseed_strength_schedule), anim_args.max_frames) + self.checkpoint_schedule_series = get_inbetweens(parse_key_frames(anim_args.checkpoint_schedule), anim_args.max_frames, is_single_string = 
True) + self.steps_schedule_series = get_inbetweens(parse_key_frames(anim_args.steps_schedule), anim_args.max_frames) + self.seed_schedule_series = get_inbetweens(parse_key_frames(anim_args.seed_schedule), anim_args.max_frames) + self.sampler_schedule_series = get_inbetweens(parse_key_frames(anim_args.sampler_schedule), anim_args.max_frames, is_single_string = True) + self.clipskip_schedule_series = get_inbetweens(parse_key_frames(anim_args.clipskip_schedule), anim_args.max_frames) + self.mask_schedule_series = get_inbetweens(parse_key_frames(anim_args.mask_schedule), anim_args.max_frames, is_single_string = True) + self.noise_mask_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_mask_schedule), anim_args.max_frames, is_single_string = True) + self.kernel_schedule_series = get_inbetweens(parse_key_frames(anim_args.kernel_schedule), anim_args.max_frames) + self.sigma_schedule_series = get_inbetweens(parse_key_frames(anim_args.sigma_schedule), anim_args.max_frames) + self.amount_schedule_series = get_inbetweens(parse_key_frames(anim_args.amount_schedule), anim_args.max_frames) + self.threshold_schedule_series = get_inbetweens(parse_key_frames(anim_args.threshold_schedule), anim_args.max_frames) + self.fov_series = get_inbetweens(parse_key_frames(anim_args.fov_schedule), anim_args.max_frames) + self.near_series = get_inbetweens(parse_key_frames(anim_args.near_schedule), anim_args.max_frames) + self.far_series = get_inbetweens(parse_key_frames(anim_args.far_schedule), anim_args.max_frames) + self.hybrid_comp_alpha_schedule_series = get_inbetweens(parse_key_frames(anim_args.hybrid_comp_alpha_schedule), anim_args.max_frames) + self.hybrid_comp_mask_blend_alpha_schedule_series = get_inbetweens(parse_key_frames(anim_args.hybrid_comp_mask_blend_alpha_schedule), anim_args.max_frames) + self.hybrid_comp_mask_contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.hybrid_comp_mask_contrast_schedule), anim_args.max_frames) + self.hybrid_comp_mask_auto_contrast_cutoff_high_schedule_series = get_inbetweens(parse_key_frames(anim_args.hybrid_comp_mask_auto_contrast_cutoff_high_schedule), anim_args.max_frames) + self.hybrid_comp_mask_auto_contrast_cutoff_low_schedule_series = get_inbetweens(parse_key_frames(anim_args.hybrid_comp_mask_auto_contrast_cutoff_low_schedule), anim_args.max_frames) + +class LooperAnimKeys(): + def __init__(self, loop_args, anim_args): + self.use_looper = loop_args.use_looper + self.imagesToKeyframe = loop_args.init_images + self.image_strength_schedule_series = get_inbetweens(parse_key_frames(loop_args.image_strength_schedule), anim_args.max_frames) + self.blendFactorMax_series = get_inbetweens(parse_key_frames(loop_args.blendFactorMax), anim_args.max_frames) + self.blendFactorSlope_series = get_inbetweens(parse_key_frames(loop_args.blendFactorSlope), anim_args.max_frames) + self.tweening_frames_schedule_series = get_inbetweens(parse_key_frames(loop_args.tweening_frames_schedule), anim_args.max_frames) + self.color_correction_factor_series = get_inbetweens(parse_key_frames(loop_args.color_correction_factor), anim_args.max_frames) + +def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear', is_single_string = False): + key_frame_series = pd.Series([np.nan for a in range(max_frames)]) + for i in range(0, max_frames): + if i in key_frames: + value = key_frames[i] + value_is_number = check_is_number(value) + # if it's only a number, leave the rest for the default interpolation + if value_is_number: + t = i + key_frame_series[i] = 
value + if not value_is_number: + t = i + if is_single_string: + if value.find("'") > -1: + value = value.replace("'","") + if value.find('"') > -1: + value = value.replace('"',"") + key_frame_series[i] = numexpr.evaluate(value) if not is_single_string else value # workaround for values formatted like 0:("I am test") //used for sampler schedules + key_frame_series = key_frame_series.astype(float) if not is_single_string else key_frame_series # as string + + if interp_method == 'Cubic' and len(key_frames.items()) <= 3: + interp_method = 'Quadratic' + if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: + interp_method = 'Linear' + + key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] + key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] + key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both') + if integer: + return key_frame_series.astype(int) + return key_frame_series + +def parse_key_frames(string, prompt_parser=None): + # because math functions (i.e. sin(t)) can utilize brackets + # it extracts the value in form of some stuff + # which has previously been enclosed with brackets and + # with a comma or end of line existing after the closing one + pattern = r'((?P[0-9]+):[\s]*\((?P[\S\s]*?)\)([,][\s]?|[\s]?$))' + frames = dict() + for match_object in re.finditer(pattern, string): + frame = int(match_object.groupdict()['frame']) + param = match_object.groupdict()['param'] + if prompt_parser: + frames[frame] = prompt_parser(param) + else: + frames[frame] = param + if frames == {} and len(string) != 0: + raise RuntimeError('Key Frame string not correctly formatted') + return frames \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/args.py b/extensions/deforum/scripts/deforum_helpers/args.py new file mode 100644 index 0000000000000000000000000000000000000000..cd86df83a1247111560a23920971e4d702bf4f7b --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/args.py @@ -0,0 +1,1214 @@ +from modules.shared import cmd_opts +from modules.processing import get_fixed_seed +from modules.ui_components import FormRow +import modules.shared as sh +import modules.paths as ph +import os +from .frame_interpolation import set_interp_out_fps, gradio_f_interp_get_fps_and_fcount, process_interp_vid_upload_logic +from .upscaling import process_upscale_vid_upload_logic, process_ncnn_upscale_vid_upload_logic +from .video_audio_utilities import find_ffmpeg_binary, ffmpeg_stitch_video, direct_stitch_vid_from_frames, get_quick_vid_info, extract_number +from .gradio_funcs import * +from .general_utils import get_os +from .deforum_controlnet import controlnet_component_names, setup_controlnet_ui +import tempfile + +def Root(): + device = sh.device + models_path = ph.models_path + '/Deforum' + half_precision = not cmd_opts.no_half + mask_preset_names = ['everywhere','init_mask','video_mask'] + p = None + frames_cache = [] + initial_seed = None + initial_info = None + first_frame = None + outpath_samples = "" + animation_prompts = None + color_corrections = None + initial_clipskip = None + current_user_os = get_os() + tmp_deforum_run_duplicated_folder = os.path.join(tempfile.gettempdir(), 'tmp_run_deforum') + return locals() + +def DeforumAnimArgs(): + + #@markdown ####**Animation:** + animation_mode = '2D' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} + max_frames = 120 #@param {type:"number"} + border = 'replicate' #@param ['wrap', 
'replicate'] {type:'string'} + #@markdown ####**Motion Parameters:** + angle = "0:(0)"#@param {type:"string"} + zoom = "0:(1.0025+0.002*sin(1.25*3.14*t/30))"#@param {type:"string"} + translation_x = "0:(0)"#@param {type:"string"} + translation_y = "0:(0)"#@param {type:"string"} + translation_z = "0:(1.75)"#@param {type:"string"} + rotation_3d_x = "0:(0)"#@param {type:"string"} + rotation_3d_y = "0:(0)"#@param {type:"string"} + rotation_3d_z = "0:(0)"#@param {type:"string"} + enable_perspective_flip = False #@param {type:"boolean"} + perspective_flip_theta = "0:(0)"#@param {type:"string"} + perspective_flip_phi = "0:(0)"#@param {type:"string"} + perspective_flip_gamma = "0:(0)"#@param {type:"string"} + perspective_flip_fv = "0:(53)"#@param {type:"string"} + noise_schedule = "0: (0.065)"#@param {type:"string"} + strength_schedule = "0: (0.65)"#@param {type:"string"} + contrast_schedule = "0: (1.0)"#@param {type:"string"} + cfg_scale_schedule = "0: (7)" + enable_steps_scheduling = False#@param {type:"boolean"} + steps_schedule = "0: (25)"#@param {type:"string"} + fov_schedule = "0: (70)" + near_schedule = "0: (200)" + far_schedule = "0: (10000)" + seed_schedule = "0:(5), 1:(-1), 219:(-1), 220:(5)" + pix2pix_img_cfg_scale = "1.5" + pix2pix_img_cfg_scale_schedule = "0:(1.5)" + enable_subseed_scheduling = False + subseed_schedule = "0:(1)" + subseed_strength_schedule = "0:(0)" + + # Sampler Scheduling + enable_sampler_scheduling = False #@param {type:"boolean"} + sampler_schedule = '0: ("Euler a")' + + # Composable mask scheduling + use_noise_mask = False + mask_schedule = '0: ("!({everywhere}^({init_mask}|{video_mask}) ) ")' + noise_mask_schedule = '0: ("!({everywhere}^({init_mask}|{video_mask}) ) ")' + # Checkpoint Scheduling + enable_checkpoint_scheduling = False#@param {type:"boolean"} + checkpoint_schedule = '0: ("model1.ckpt"), 100: ("model2.ckpt")' + + # CLIP skip Scheduling + enable_clipskip_scheduling = False #@param {type:"boolean"} + clipskip_schedule = '0: (2)' + + # Anti-blur + kernel_schedule = "0: (5)" + sigma_schedule = "0: (1.0)" + amount_schedule = "0: (0.35)" + threshold_schedule = "0: (0.0)" + # Hybrid video + hybrid_comp_alpha_schedule = "0:(1)" #@param {type:"string"} + hybrid_comp_mask_blend_alpha_schedule = "0:(0.5)" #@param {type:"string"} + hybrid_comp_mask_contrast_schedule = "0:(1)" #@param {type:"string"} + hybrid_comp_mask_auto_contrast_cutoff_high_schedule = "0:(100)" #@param {type:"string"} + hybrid_comp_mask_auto_contrast_cutoff_low_schedule = "0:(0)" #@param {type:"string"} + + #@markdown ####**Coherence:** + color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB', 'Video Input'] {type:'string'} + color_coherence_video_every_N_frames = 1 #@param {type:"integer"} + color_force_grayscale = False #@param {type:"boolean"} + diffusion_cadence = '2' #@param ['1','2','3','4','5','6','7','8'] {type:'string'} + + #@markdown ####**Noise settings:** + noise_type = 'perlin' #@param ['uniform', 'perlin'] {type:'string'} + # Perlin params + perlin_w = 8 #@param {type:"number"} + perlin_h = 8 #@param {type:"number"} + perlin_octaves = 4 #@param {type:"number"} + perlin_persistence = 0.5 #@param {type:"number"} + + #@markdown ####**3D Depth Warping:** + use_depth_warping = True #@param {type:"boolean"} + midas_weight = 0.2 #@param {type:"number"} + + padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'} + sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} + 
save_depth_maps = False #@param {type:"boolean"} + + #@markdown ####**Video Input:** + video_init_path ='https://github.com/hithereai/d/releases/download/m/vid.mp4' #@param {type:"string"} + extract_nth_frame = 1#@param {type:"number"} + extract_from_frame = 0 #@param {type:"number"} + extract_to_frame = -1 #@param {type:"number"} minus 1 for unlimited frames + overwrite_extracted_frames = True #@param {type:"boolean"} + use_mask_video = False #@param {type:"boolean"} + video_mask_path ='/content/video_in.mp4'#@param {type:"string"} + + #@markdown ####**Hybrid Video for 2D/3D Animation Mode:** + hybrid_generate_inputframes = False #@param {type:"boolean"} + hybrid_generate_human_masks = "None" #@param ['None','PNGs','Video', 'Both'] + hybrid_use_first_frame_as_init_image = True #@param {type:"boolean"} + hybrid_motion = "None" #@param ['None','Optical Flow','Perspective','Affine'] + hybrid_motion_use_prev_img = False #@param {type:"boolean"} + hybrid_flow_method = "Farneback" #@param ['DIS Medium','Farneback'] + hybrid_composite = False #@param {type:"boolean"} + hybrid_comp_mask_type = "None" #@param ['None', 'Depth', 'Video Depth', 'Blend', 'Difference'] + hybrid_comp_mask_inverse = False #@param {type:"boolean"} + hybrid_comp_mask_equalize = "None" #@param ['None','Before','After','Both'] + hybrid_comp_mask_auto_contrast = False #@param {type:"boolean"} + hybrid_comp_save_extra_frames = False #@param {type:"boolean"} + + #@markdown ####**Resume Animation:** + resume_from_timestring = False #@param {type:"boolean"} + resume_timestring = "20220829210106" #@param {type:"string"} + + return locals() + +# def DeforumPrompts(): + # return + +def DeforumAnimPrompts(): + return r"""{ + "0": "tiny cute swamp bunny, highly detailed, intricate, ultra hd, sharp photo, crepuscular rays, in focus, by tomasz alen kopera", + "30": "anthropomorphic clean cat, surrounded by fractals, epic angle and pose, symmetrical, 3d, depth of field, ruan jia and fenghua zhong", + "60": "a beautiful coconut --neg photo, realistic", + "90": "a beautiful durian, trending on Artstation" +} + """ + +def DeforumArgs(): + #@markdown **Image Settings** + W = 512 #@param + H = 512 #@param + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + #@markdonw **Webui stuff** + tiling = False + restore_faces = False + seed_enable_extras = False + subseed = -1 + subseed_strength = 0 + seed_resize_from_w = 0 + seed_resize_from_h = 0 + + #@markdown **Sampling Settings** + seed = -1 #@param + sampler = 'euler_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"] + steps = 25 #@param + scale = 7 #@param + ddim_eta = 0.0 #@param + dynamic_threshold = None + static_threshold = None + + #@markdown **Save & Display Settings** + save_samples = True #@param {type:"boolean"} + save_settings = True #@param {type:"boolean"} + display_samples = True #@param {type:"boolean"} + save_sample_per_step = False #@param {type:"boolean"} + show_sample_per_step = False #@param {type:"boolean"} + + #@markdown **Prompt Settings** + prompt_weighting = False #@param {type:"boolean"} + normalize_prompt_weights = True #@param {type:"boolean"} + log_weighted_subprompts = False #@param {type:"boolean"} + + #@markdown **Batch Settings** + n_batch = 1 #@param + batch_name = "Deforum" #@param {type:"string"} + filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] + seed_behavior = "iter" #@param 
["iter","fixed","random","ladder","alternate","schedule"] + seed_iter_N = 1 #@param {type:'integer'} + # make_grid = False #@param {type:"boolean"} + # grid_rows = 2 #@param + outdir = ""#get_output_folder(output_path, batch_name) + + #@markdown **Init Settings** + use_init = False #@param {type:"boolean"} + strength = 0.0 #@param {type:"number"} + strength_0_no_init = True # Set the strength to 0 automatically when no init image is used + init_image = "https://github.com/hithereai/d/releases/download/m/kaba.png" #@param {type:"string"} + # Whiter areas of the mask are areas that change more + use_mask = False #@param {type:"boolean"} + use_alpha_as_mask = False # use the alpha channel of the init image as the mask + mask_file = "https://github.com/hithereai/d/releases/download/m/mask.jpg" #@param {type:"string"} + invert_mask = False #@param {type:"boolean"} + # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. + mask_contrast_adjust = 1.0 #@param {type:"number"} + mask_brightness_adjust = 1.0 #@param {type:"number"} + # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding + overlay_mask = True # {type:"boolean"} + # Blur edges of final overlay mask, if used. Minimum = 0 (no blur) + mask_overlay_blur = 4 # {type:"number"} + + fill = 1 #MASKARGSEXPANSION Todo : Rename and convert to same formatting as used in img2img masked content + full_res_mask = True + full_res_mask_padding = 4 + reroll_blank_frames = 'reroll' # reroll, interrupt, or ignore + + n_samples = 1 # doesnt do anything + precision = 'autocast' + C = 4 + f = 8 + + prompt = "" + timestring = "" + init_latent = None + init_sample = None + init_c = None + mask_image = None + noise_mask = None + seed_internal = 0 + + return locals() + +def keyframeExamples(): + return '''{ + "0": "https://user-images.githubusercontent.com/121192995/215279228-1673df8a-f919-4380-b04c-19379b2041ff.png", + "50": "https://user-images.githubusercontent.com/121192995/215279281-7989fd6f-4b9b-4d90-9887-b7960edd59f8.png", + "100": "https://user-images.githubusercontent.com/121192995/215279284-afc14543-d220-4142-bbf4-503776ca2b8b.png", + "150": "https://user-images.githubusercontent.com/121192995/215279286-23378635-85b3-4457-b248-23e62c048049.jpg", + "200": "https://user-images.githubusercontent.com/121192995/215279228-1673df8a-f919-4380-b04c-19379b2041ff.png" +}''' + +def LoopArgs(): + use_looper = False + init_images = keyframeExamples() + image_strength_schedule = "0:(0.75)" + blendFactorMax = "0:(0.35)" + blendFactorSlope = "0:(0.25)" + tweening_frames_schedule = "0:(20)" + color_correction_factor = "0:(0.075)" + return locals() + +def ParseqArgs(): + parseq_manifest = None + parseq_use_deltas = True + return locals() + +def DeforumOutputArgs(): + skip_video_for_run_all = False #@param {type: 'boolean'} + fps = 15 #@param {type:"number"} + make_gif = False + image_path = "C:/SD/20230124234916_%05d.png" #@param {type:"string"} + mp4_path = "testvidmanualsettings.mp4" #@param {type:"string"} + ffmpeg_location = find_ffmpeg_binary() + ffmpeg_crf = '17' + ffmpeg_preset = 'slow' + add_soundtrack = 'None' #@param ["File","Init Video"] + soundtrack_path = "https://freetestdata.com/wp-content/uploads/2021/09/Free_Test_Data_1MB_MP3.mp3" + # End-Run upscaling + r_upscale_video = False + r_upscale_factor = 'x2' # ['2x', 'x3', 'x4'] + # **model below** - 'realesr-animevideov3' (default of realesrgan engine, does 2-4x), the rest do only 4x: 'realesrgan-x4plus', 'realesrgan-x4plus-anime' + 
r_upscale_model = 'realesr-animevideov3' + r_upscale_keep_imgs = True + + render_steps = False #@param {type: 'boolean'} + path_name_modifier = "x0_pred" #@param ["x0_pred","x"] + # max_video_frames = 200 #@param {type:"string"} + store_frames_in_ram = False #@param {type: 'boolean'} + #@markdown **Interpolate Video Settings** + # todo: change them to support FILM interpolation as well + frame_interpolation_engine = "None" #@param ["None", "RIFE v4.6", "FILM"] + frame_interpolation_x_amount = 2 # [2 to 1000 depends on the engine] + frame_interpolation_slow_mo_enabled = False + frame_interpolation_slow_mo_amount = 2 #@param [2 to 10] + frame_interpolation_keep_imgs = False #@param {type: 'boolean'} + return locals() + +import gradio as gr +import os +import time +from types import SimpleNamespace + +i1_store_backup = "

Deforum extension for auto1111 — version 2.2b

" +i1_store = i1_store_backup + +mask_fill_choices=['fill', 'original', 'latent noise', 'latent nothing'] + +def setup_deforum_setting_dictionary(self, is_img2img, is_extension = True): + d = SimpleNamespace(**DeforumArgs()) #default args + da = SimpleNamespace(**DeforumAnimArgs()) #default anim args + dp = SimpleNamespace(**ParseqArgs()) #default parseq ars + dv = SimpleNamespace(**DeforumOutputArgs()) #default video args + dr = SimpleNamespace(**Root()) # ROOT args + dloopArgs = SimpleNamespace(**LoopArgs()) + if not is_extension: + with gr.Row(): + btn = gr.Button("Click here after the generation to show the video") + with gr.Row(): + i1 = gr.HTML(i1_store, elem_id='deforum_header') + else: + btn = i1 = gr.HTML("") + + # MAIN (TOP) EXTENSION INFO ACCORD + with gr.Accordion("Info, Links and Help", open=False, elem_id='main_top_info_accord'): + gr.HTML("""Made by deforum.github.io, port for AUTOMATIC1111's webui maintained by kabachuha""") + gr.HTML("""FOR HELP CLICK HERE +
  • The code for this extension: here.
  • Join the official Deforum Discord to share your creations and suggestions.
  • Official Deforum Wiki: here.
  • Anime-inclined great guide (by FizzleDorf) with lots of examples: here.
  • For advanced keyframing with Math functions, see here.
  • Alternatively, use sd-parseq as a UI to define your animation schedules (see the Parseq section in the Keyframes tab).
  • framesync.xyz is also a good option; it makes compact math formulae for Deforum keyframes by selecting various waveforms.
  • The other site allows for making keyframes using interactive splines and Bezier curves (select Disco output format).
  • If you want to use a Width/Height that is not a multiple of 64, please change noise_type to 'Uniform' in Keyframes --> Noise.
  • + + If you liked this extension, please give it a star on GitHub! 😊""") + if not is_extension: + def show_vid(): + return { + i1: gr.update(value=i1_store, visible=True) + } + + btn.click( + show_vid, + [], + [i1] + ) + + with gr.Blocks(): + # RUN TAB + with gr.Tab('Run'): + from modules.sd_samplers import samplers_for_img2img + with gr.Row(variant='compact'): + sampler = gr.Dropdown(label="Sampler", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="value", elem_id="sampler", interactive=True) + steps = gr.Slider(label="Steps", minimum=0, maximum=200, step=1, value=d.steps, interactive=True) + with gr.Row(variant='compact'): + W = gr.Slider(label="Width", minimum=64, maximum=2048, step=64, value=d.W, interactive=True) + H = gr.Slider(label="Height", minimum=64, maximum=2048, step=64, value=d.H, interactive=True) + with gr.Row(variables='compact'): + seed = gr.Number(label="Seed", value=d.seed, interactive=True, precision=0) + batch_name = gr.Textbox(label="Batch name", lines=1, interactive=True, value = d.batch_name) + with gr.Accordion('Restore Faces, Tiling & more', open=False) as run_more_settings_accord: + with gr.Row(variant='compact'): + restore_faces = gr.Checkbox(label='Restore Faces', value=d.restore_faces) + tiling = gr.Checkbox(label='Tiling', value=False) + ddim_eta = gr.Number(label="DDIM Eta", value=d.ddim_eta, interactive=True) + with gr.Row() as pix2pix_img_cfg_scale_row: + pix2pix_img_cfg_scale_schedule = gr.Textbox(label="Pix2Pix img CFG schedule", value=da.pix2pix_img_cfg_scale_schedule, interactive=True) + # RUN FROM SETTING FILE ACCORD + with gr.Accordion('Resume & Run from file', open=False): + with gr.Tab('Run from Settings file'): + with gr.Row(variant='compact'): + override_settings_with_file = gr.Checkbox(label="Override settings", value=False, interactive=True, elem_id='override_settings') + custom_settings_file = gr.Textbox(label="Custom settings file", lines=1, interactive=True, elem_id='custom_settings_file') + # RESUME ANIMATION ACCORD + with gr.Tab('Resume Animation'): + with gr.Row(variant='compact'): + resume_from_timestring = gr.Checkbox(label="Resume from timestring", value=da.resume_from_timestring, interactive=True) + resume_timestring = gr.Textbox(label="Resume timestring", lines=1, value = da.resume_timestring, interactive=True) + # KEYFRAMES TAB + with gr.Tab('Keyframes'): #TODO make a some sort of the original dictionary parsing + with gr.Row(variant='compact'): + with gr.Column(scale=2): + animation_mode = gr.Radio(['2D', '3D', 'Interpolation', 'Video Input'], label="Animation mode", value=da.animation_mode, elem_id="animation_mode") + with gr.Column(scale=1, min_width=180): + border = gr.Radio(['replicate', 'wrap'], label="Border", value=da.border, elem_id="border") + with gr.Row(variant='compact'): + diffusion_cadence = gr.Slider(label="Cadence", minimum=1, maximum=50, step=1, value=da.diffusion_cadence, interactive=True) + max_frames = gr.Number(label="Max frames", lines=1, value = da.max_frames, interactive=True, precision=0) + # GUIDED IMAGES ACCORD + with gr.Accordion('Guided Images', open=False, elem_id='guided_images_accord') as guided_images_accord: + # GUIDED IMAGES INFO ACCORD + with gr.Accordion('*READ ME before you use this mode!*', open=False): + gr.HTML("""You can use this as a guided image tool or as a looper depending on your settings in the keyframe images field. + Set the keyframes and the images that you want to show up. 
+ Note: the number of frames between each keyframe should be greater than the tweening frames.""") + # In later versions this should also be available in the strength schedule, but for now you need to set it here. + gr.HTML("""Prerequisites and Important Info:
    • This mode works ONLY with the 2D/3D animation modes. Interpolation and Video Input modes aren't supported.
    • Set the Init tab's strength slider greater than 0. Recommended value: (.65 - .80).
    • Set 'seed_behavior' to 'schedule' under the Seed Scheduling section below.
    """) + gr.HTML("""Looping recommendations:
    • seed_schedule should start and end on the same seed. Example: seed_schedule could use 0:(5), 1:(-1), 219:(-1), 220:(5)
    • The 1st and last keyframe images should match.
    • Set your total number of keyframes to be 21 more than the last inserted keyframe image. Example: with the default args, use 221 total keyframes.
    • Prompts are stored in JSON format. If you get an error, check it in a validator, like here
    + """) + with gr.Row(): + use_looper = gr.Checkbox(label="Enable guided images mode", value=dloopArgs.use_looper, interactive=True) + with gr.Row(): + init_images = gr.Textbox(label="Images to use for keyframe guidance", lines=9, value = keyframeExamples(), interactive=True) + # GUIDED IMAGES SCHEDULES ACCORD + with gr.Accordion('Guided images schedules', open=False): + with gr.Row(): + image_strength_schedule = gr.Textbox(label="Image strength schedule", lines=1, value = dloopArgs.image_strength_schedule, interactive=True) + with gr.Row(): + blendFactorMax = gr.Textbox(label="Blend factor max", lines=1, value = dloopArgs.blendFactorMax, interactive=True) + with gr.Row(): + blendFactorSlope = gr.Textbox(label="Blend factor slope", lines=1, value = dloopArgs.blendFactorSlope, interactive=True) + with gr.Row(): + tweening_frames_schedule = gr.Textbox(label="Tweening frames schedule", lines=1, value = dloopArgs.tweening_frames_schedule, interactive=True) + with gr.Row(): + color_correction_factor = gr.Textbox(label="Color correction factor", lines=1, value = dloopArgs.color_correction_factor, interactive=True) + # EXTA SCHEDULES TABS + with gr.Tabs(elem_id='extra_schedules'): + with gr.TabItem('Strength'): + strength_schedule = gr.Textbox(label="Strength schedule", lines=1, value = da.strength_schedule, interactive=True) + with gr.TabItem('CFG'): + cfg_scale_schedule = gr.Textbox(label="CFG scale schedule", lines=1, value = da.cfg_scale_schedule, interactive=True) + with gr.TabItem('Seed') as a3: + with gr.Row(): + seed_behavior = gr.Radio(['iter', 'fixed', 'random', 'ladder', 'alternate', 'schedule'], label="Seed behavior", value=d.seed_behavior, elem_id="seed_behavior") + with gr.Row() as seed_iter_N_row: + seed_iter_N = gr.Number(label="Seed iter N", value=d.seed_iter_N, interactive=True, precision=0) + with gr.Row(visible=False) as seed_schedule_row: + seed_schedule = gr.Textbox(label="Seed schedule", lines=1, value = da.seed_schedule, interactive=True) + with gr.TabItem('SubSeed', open=False) as subseed_sch_tab: + enable_subseed_scheduling = gr.Checkbox(label="Enable Subseed scheduling", value=da.enable_subseed_scheduling, interactive=True) + subseed_schedule = gr.Textbox(label="Subseed schedule", lines=1, value = da.subseed_schedule, interactive=True) + subseed_strength_schedule = gr.Textbox(label="Subseed strength schedule", lines=1, value = da.subseed_strength_schedule, interactive=True) + with gr.Row(variant='compact'): + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) + # Steps Scheduling + with gr.TabItem('Step') as a13: + with gr.Row(): + enable_steps_scheduling = gr.Checkbox(label="Enable steps scheduling", value=da.enable_steps_scheduling, interactive=True) + with gr.Row(): + steps_schedule = gr.Textbox(label="Steps schedule", lines=1, value = da.steps_schedule, interactive=True) + # Sampler Scheduling + with gr.TabItem('Sampler') as a14: + with gr.Row(): + enable_sampler_scheduling = gr.Checkbox(label="Enable sampler scheduling", value=da.enable_sampler_scheduling, interactive=True) + with gr.Row(): + sampler_schedule = gr.Textbox(label="Sampler schedule", lines=1, value = da.sampler_schedule, interactive=True) + # Checkpoint Scheduling + with gr.TabItem('Checkpoint') as a15: + with gr.Row(): + enable_checkpoint_scheduling = gr.Checkbox(label="Enable checkpoint scheduling", 
value=da.enable_checkpoint_scheduling, interactive=True) + with gr.Row(): + checkpoint_schedule = gr.Textbox(label="Checkpoint schedule", lines=1, value = da.checkpoint_schedule, interactive=True) + with gr.TabItem('CLIP Skip', open=False) as a16: + with gr.Row(): + enable_clipskip_scheduling = gr.Checkbox(label="Enable CLIP skip scheduling", value=da.enable_clipskip_scheduling, interactive=True) + with gr.Row(): + clipskip_schedule = gr.Textbox(label="CLIP skip schedule", lines=1, value = da.clipskip_schedule, interactive=True) + # MOTION INNER TAB + with gr.Tab('Motion') as motion_tab: + with gr.Column(visible=True) as only_2d_motion_column: + with gr.Row(variant='compact'): + angle = gr.Textbox(label="Angle", lines=1, value = da.angle, interactive=True) + with gr.Row(variant='compact'): + zoom = gr.Textbox(label="Zoom", lines=1, value = da.zoom, interactive=True) + with gr.Column(visible=True) as both_anim_mode_motion_params_column: + with gr.Row(variant='compact'): + translation_x = gr.Textbox(label="Translation X", lines=1, value = da.translation_x, interactive=True) + with gr.Row(variant='compact'): + translation_y = gr.Textbox(label="Translation Y", lines=1, value = da.translation_y, interactive=True) + with gr.Column(visible=False) as only_3d_motion_column: + with gr.Row(variant='compact'): + translation_z = gr.Textbox(label="Translation Z", lines=1, value = da.translation_z, interactive=True) + with gr.Row(variant='compact'): + rotation_3d_x = gr.Textbox(label="Rotation 3D X", lines=1, value = da.rotation_3d_x, interactive=True) + with gr.Row(variant='compact'): + rotation_3d_y = gr.Textbox(label="Rotation 3D Y", lines=1, value = da.rotation_3d_y, interactive=True) + with gr.Row(variant='compact'): + rotation_3d_z = gr.Textbox(label="Rotation 3D Z", lines=1, value = da.rotation_3d_z, interactive=True) + # 3D DEPTH & FOV ACCORD + with gr.Accordion('Depth Warping & FOV', visible=False, open=False) as depth_3d_warping_accord: + with gr.Tab('Depth Warping'): + with gr.Row(variant='compact'): + use_depth_warping = gr.Checkbox(label="Use depth warping", value=da.use_depth_warping, interactive=True) + midas_weight = gr.Number(label="MiDaS weight", value=da.midas_weight, interactive=True) + with gr.Row(variant='compact'): + padding_mode = gr.Radio(['border', 'reflection', 'zeros'], label="Padding mode", value=da.padding_mode, elem_id="padding_mode") + sampling_mode = gr.Radio(['bicubic', 'bilinear', 'nearest'], label="Sampling mode", value=da.sampling_mode, elem_id="sampling_mode") + with gr.Tab('Field Of View', visible=False, open=False) as fov_accord: + with gr.Row(variant='compact'): + fov_schedule = gr.Textbox(label="FOV schedule", lines=1, value = da.fov_schedule, interactive=True) + with gr.Row(): + near_schedule = gr.Textbox(label="Near schedule", lines=1, value = da.near_schedule, interactive=True) + with gr.Row(): + far_schedule = gr.Textbox(label="Far schedule", lines=1, value = da.far_schedule, interactive=True) + # PERSPECTIVE FLIP ACCORD + with gr.Accordion('Perspective Flip', open=False) as perspective_flip_accord: + with gr.Row(): + enable_perspective_flip = gr.Checkbox(label="Enable perspective flip", value=da.enable_perspective_flip, interactive=True) + with gr.Row(): + perspective_flip_theta = gr.Textbox(label="Perspective flip theta", lines=1, value = da.perspective_flip_theta, interactive=True) + with gr.Row(): + perspective_flip_phi = gr.Textbox(label="Perspective flip phi", lines=1, value = da.perspective_flip_phi, interactive=True) + with gr.Row(): + 
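+ # Remaining perspective flip schedules (gamma and fv), entered as Deforum keyframe strings like the theta/phi fields above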
perspective_flip_gamma = gr.Textbox(label="Perspective flip gamma", lines=1, value = da.perspective_flip_gamma, interactive=True) + with gr.Row(): + perspective_flip_fv = gr.Textbox(label="Perspective flip fv", lines=1, value = da.perspective_flip_fv, interactive=True) + # NOISE INNER TAB + with gr.Tab('Noise', open=True) as a8: + with gr.Row(): + noise_type = gr.Radio(['uniform', 'perlin'], label="Noise type", value=da.noise_type, elem_id="noise_type") + with gr.Row(): + noise_schedule = gr.Textbox(label="Noise schedule", lines=1, value = da.noise_schedule, interactive=True) + with gr.Row() as perlin_row: + with gr.Column(min_width=220): + perlin_octaves = gr.Slider(label="Perlin octaves", minimum=1, maximum=7, value=da.perlin_octaves, step=1, interactive=True) + with gr.Column(min_width=220): + perlin_persistence = gr.Slider(label="Perlin persistence", minimum=0, maximum=1, value=da.perlin_persistence, step=0.02, interactive=True) + # COHERENCE INNER TAB + with gr.Tab('Coherence', open=False) as coherence_accord: + with gr.Row(equal_height=True): + # Future TODO: remove 'match frame 0' prefix (after we manage the deprecated-names settings import), then convert from Dropdown to Radio! + color_coherence = gr.Dropdown(label="Color coherence", choices=['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB', 'Video Input'], value=da.color_coherence, type="value", elem_id="color_coherence", interactive=True) + with gr.Column() as force_grayscale_column: + color_force_grayscale = gr.Checkbox(label="Color force Grayscale", value=da.color_force_grayscale, interactive=True) + with gr.Row(visible=False) as color_coherence_video_every_N_frames_row: + color_coherence_video_every_N_frames = gr.Number(label="Color coherence video every N frames", value=1, interactive=True) + with gr.Row(): + contrast_schedule = gr.Textbox(label="Contrast schedule", lines=1, value = da.contrast_schedule, interactive=True) + with gr.Row(): + # what to do with blank frames (they may result from glitches or the NSFW filter being turned on): reroll with +1 seed, interrupt the animation generation, or do nothing + reroll_blank_frames = gr.Radio(['reroll', 'interrupt', 'ignore'], label="Reroll blank frames", value=d.reroll_blank_frames, elem_id="reroll_blank_frames") + # ANTI BLUR INNER TAB + with gr.Tab('Anti Blur', open=False, elem_id='anti_blur_accord') as anti_blur_tab: + with gr.Row(variant='compact'): + kernel_schedule = gr.Textbox(label="Kernel schedule", lines=1, value = da.kernel_schedule, interactive=True) + with gr.Row(variant='compact'): + sigma_schedule = gr.Textbox(label="Sigma schedule", lines=1, value = da.sigma_schedule, interactive=True) + with gr.Row(variant='compact'): + amount_schedule = gr.Textbox(label="Amount schedule", lines=1, value = da.amount_schedule, interactive=True) + with gr.Row(variant='compact'): + threshold_schedule = gr.Textbox(label="Threshold schedule", lines=1, value = da.threshold_schedule, interactive=True) + # PROMPTS TAB + with gr.Tab('Prompts'): + # PROMPTS INFO ACCORD + with gr.Accordion(label='*Important* notes on Prompts', elem_id='prompts_info_accord', open=False, visible=True) as prompts_info_accord: + gr.HTML(""" +
    • Please always keep values in math functions above 0.
    • There is *no* Batch mode like in vanilla Deforum. Please use the txt2img tab for that.
    • For negative prompts, write your positive prompt, then --neg ugly, text, asymmetric, or any other negative tokens of your choice. OR:
    • Use the negative_prompts field to automatically append all words as a negative prompt. *Don't* add --neg in the negative_prompts field!
    • Prompts are stored in JSON format. If you get an error, check it in a JSON Validator
    + """) + with gr.Row(): + animation_prompts = gr.Textbox(label="Prompts", lines=8, interactive=True, value = DeforumAnimPrompts()) + with gr.Row(): + animation_prompts_positive = gr.Textbox(label="Prompts positive", lines=1, interactive=True, value = "") + with gr.Row(): + animation_prompts_negative = gr.Textbox(label="Prompts negative", lines=1, interactive=True, value = "") + # COMPOSABLE MASK SCHEDULING ACCORD + with gr.Accordion('Composable Mask scheduling', open=False): + gr.HTML(""" +
    • To enable, check use_mask in the Init tab
    • Supports boolean operations: (! - negation, & - and, | - or, ^ - xor, \ - difference, () - nested operations)
    • Default variables: in \{\}, like \{init_mask\}, \{video_mask\}, \{everywhere\}
    • Masks from files: in [], like [mask1.png]
    • Description-based: word masks in <>, like <apple>, <hair>
    + """) + with gr.Row(): + mask_schedule = gr.Textbox(label="Mask schedule", lines=1, value = da.mask_schedule, interactive=True) + with gr.Row(): + use_noise_mask = gr.Checkbox(label="Use noise mask", value=da.use_noise_mask, interactive=True) + with gr.Row(): + noise_mask_schedule = gr.Textbox(label="Noise mask schedule", lines=1, value = da.noise_mask_schedule, interactive=True) + # INIT MAIN TAB + with gr.Tab('Init'): + # IMAGE INIT INNER-TAB + with gr.Tab('Image Init'): + with gr.Row(): + with gr.Column(min_width=150): + use_init = gr.Checkbox(label="Use init", value=d.use_init, interactive=True, visible=True) + with gr.Column(min_width=150): + strength_0_no_init = gr.Checkbox(label="Strength 0 no init", value=True, interactive=True) + with gr.Column(min_width=170): + strength = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=0, interactive=True) + with gr.Row(): + init_image = gr.Textbox(label="Init image", lines=1, interactive=True, value = d.init_image) + # VIDEO INIT INNER-TAB + with gr.Tab('Video Init'): + with gr.Row(): + video_init_path = gr.Textbox(label="Video init path", lines=1, value = da.video_init_path, interactive=True) + with gr.Row(): + extract_from_frame = gr.Number(label="Extract from frame", value=da.extract_from_frame, interactive=True, precision=0) + extract_to_frame = gr.Number(label="Extract to frame", value=da.extract_to_frame, interactive=True, precision=0) + extract_nth_frame = gr.Number(label="Extract nth frame", value=da.extract_nth_frame, interactive=True, precision=0) + overwrite_extracted_frames = gr.Checkbox(label="Overwrite extracted frames", value=False, interactive=True) + use_mask_video = gr.Checkbox(label="Use mask video", value=False, interactive=True) + with gr.Row(): + video_mask_path = gr.Textbox(label="Video mask path", lines=1, value = da.video_mask_path, interactive=True) + # MASK INIT INNER-TAB + with gr.Tab('Mask Init'): + with gr.Row(): + use_mask = gr.Checkbox(label="Use mask", value=d.use_mask, interactive=True) + use_alpha_as_mask = gr.Checkbox(label="Use alpha as mask", value=d.use_alpha_as_mask, interactive=True) + invert_mask = gr.Checkbox(label="Invert mask", value=d.invert_mask, interactive=True) + overlay_mask = gr.Checkbox(label="Overlay mask", value=d.overlay_mask, interactive=True) + with gr.Row(): + mask_file = gr.Textbox(label="Mask file", lines=1, interactive=True, value = d.mask_file) + with gr.Row(): + mask_overlay_blur = gr.Slider(label="Mask overlay blur", minimum=0, maximum=64, step=1, value=d.mask_overlay_blur, interactive=True) + with gr.Row(): + choice = mask_fill_choices[d.fill] + fill = gr.Radio(label='Mask fill', choices=mask_fill_choices, value=choice, type="index") + with gr.Row(): + full_res_mask = gr.Checkbox(label="Full res mask", value=d.full_res_mask, interactive=True) + full_res_mask_padding = gr.Slider(minimum=0, maximum=512, step=1, label="Full res mask padding", value=d.full_res_mask_padding, interactive=True) + # PARSEQ ACCORD + with gr.Accordion('Parseq', open=False): + gr.HTML(""" + Use an sd-parseq manifest for your animation (leave blank to ignore).

    Note that parseq overrides:
    • Run: seed, subseed, subseed strength.
    • Keyframes: generation settings (noise, strength, contrast, scale).
    • Keyframes: motion parameters for 2D and 3D (angle, zoom, translation, rotation, perspective flip).

    Parseq does not override:
    • Run: Sampler, Width, Height, tiling, resize seed.
    • Keyframes: animation settings (animation mode, max frames, border)
    • Keyframes: coherence (color coherence & cadence)
    • Keyframes: depth warping
    • Output settings: all settings (including fps and max frames)
    + """) + with gr.Row(): + parseq_manifest = gr.Textbox(label="Parseq Manifest (JSON or URL)", lines=4, value = dp.parseq_manifest, interactive=True) + with gr.Row(): + parseq_use_deltas = gr.Checkbox(label="Use delta values for movement parameters", value=dp.parseq_use_deltas, interactive=True) + def show_hybrid_html_msg(choice): + if choice not in ['2D','3D']: + return gr.update(visible=True) + else: + return gr.update(visible=False) + def change_hybrid_tab_status(choice): + if choice in ['2D','3D']: + return gr.update(visible=True) + else: + return gr.update(visible=False) + # CONTROLNET TAB + with gr.Tab('ControlNet'): + gr.HTML(""" + Requires the ControlNet extension to be installed.

    *Work In Progress*. All params below are going to be keyframable at some point. If you want to speed up the integration, join Deforum's development. 😉
    Due to the way the base ControlNet extension works, it needs its models to be located at 'extensions/deforum-for-automatic1111-webui/models'. So copy, symlink or move them there until a more elegant solution is found. Also, as of now, it requires use_init to be checked for the first run. The ControlNet extension version used in the dev process is a24089a62e70a7fae44b7bf35b51fd584dd55e25; if it still breaks even with all the other options above, upgrade/downgrade your CN version to this one.
    + """) + controlnet_dict = setup_controlnet_ui() + # HYBRID VIDEO TAB + with gr.Tab('Hybrid Video'): + # this html only shows when not in 2d/3d mode + hybrid_msg_html = gr.HTML(value='Please, change animation mode to 2D or 3D to enable Hybrid Mode',visible=False, elem_id='hybrid_msg_html') + # HYBRID INFO ACCORD + with gr.Accordion("Info & Help", open=False): + hybrid_html = "

    Hybrid Video Compositing in 2D/3D Mode by reallybigname
    " + hybrid_html += "
    • Composite video with previous frame init image in 2D or 3D animation_mode (not for Video Input mode)
    " + hybrid_html += "
    • Uses your Init settings for video_init_path, extract_nth_frame, overwrite_extracted_frames
    " + hybrid_html += "
    • In the Keyframes tab, you can also set color_coherence = 'Video Input'
    " + hybrid_html += "
    • color_coherence_video_every_N_frames lets you only match every N frames
    " + hybrid_html += "
    • Color coherence may be used with hybrid composite off, to just use video color.
    " + hybrid_html += "
    • Hybrid motion may be used with hybrid composite off, to just use video motion.
    " + hybrid_html += "Hybrid Video Schedules" + hybrid_html += "
    • The alpha schedule controls overall alpha for video mix, whether using a composite mask or not.
    " + hybrid_html += "
    • The hybrid_comp_mask_blend_alpha_schedule only affects the 'Blend' hybrid_comp_mask_type.
    " + hybrid_html += "
    • Mask contrast schedule is from 0-255. Normal is 1. Affects all masks.
    " + hybrid_html += "
    • Autocontrast low/high cutoff schedules are 0-100. Low 0 / High 100 is the full range.
      (hybrid_comp_mask_auto_contrast must be enabled)
    " + hybrid_html += "Click Here for more info/ a Guide." + gr.HTML(hybrid_html) + # HYBRID SETTINGS ACCORD + with gr.Accordion("Hybrid Settings", open=True) as hybrid_settings_accord: + with gr.Row(variant='compact'): + with gr.Column(min_width=340): + with gr.Row(variant='compact'): + hybrid_generate_inputframes = gr.Checkbox(label="Generate inputframes", value=False, interactive=True) + hybrid_composite = gr.Checkbox(label="Hybrid composite", value=False, interactive=True) + with gr.Column(min_width=340) as hybrid_2nd_column: + with gr.Row(variant='compact'): + hybrid_use_first_frame_as_init_image = gr.Checkbox(label="First frame as init image", value=da.hybrid_use_first_frame_as_init_image, interactive=True, visible=False) + hybrid_motion_use_prev_img = gr.Checkbox(label="Motion use prev img", value=False, interactive=True, visible=False) + with gr.Row() as hybrid_flow_row: + with gr.Column(variant='compact'): + with gr.Row(variant='compact'): + hybrid_motion = gr.Radio(['None', 'Optical Flow', 'Perspective', 'Affine'], label="Hybrid motion", value=da.hybrid_motion, elem_id="hybrid_motion") + with gr.Column(variant='compact'): + with gr.Row(variant='compact'): + with gr.Column(scale=1): + hybrid_flow_method = gr.Radio(['DIS Medium', 'Farneback'], label="Flow method", value=da.hybrid_flow_method, elem_id="hybrid_flow_method", visible=False) + hybrid_comp_mask_type = gr.Radio(['None', 'Depth', 'Video Depth', 'Blend', 'Difference'], label="Comp mask type", value=da.hybrid_comp_mask_type, elem_id="hybrid_comp_mask_type", visible=False) + with gr.Row(visible=False, variant='compact') as hybrid_comp_mask_row: + hybrid_comp_mask_equalize = gr.Radio(['None', 'Before', 'After', 'Both'], label="Comp mask equalize", value=da.hybrid_comp_mask_equalize, elem_id="hybrid_comp_mask_equalize") + with gr.Column(variant='compact'): + hybrid_comp_mask_auto_contrast = gr.Checkbox(label="Comp mask auto contrast", value=False, interactive=True) + hybrid_comp_mask_inverse = gr.Checkbox(label="Comp mask inverse", value=False, interactive=True) + with gr.Row(variant='compact'): + hybrid_comp_save_extra_frames = gr.Checkbox(label="Comp save extra frames", value=False, interactive=True) + # HYBRID SCHEDULES ACCORD + with gr.Accordion("Hybrid Schedules", open=False, visible=False) as hybrid_sch_accord: + with gr.Row(variant='compact') as hybrid_comp_alpha_schedule_row: + hybrid_comp_alpha_schedule = gr.Textbox(label="Comp alpha schedule", lines=1, value = da.hybrid_comp_alpha_schedule, interactive=True) + with gr.Row(variant='compact', visible=False) as hybrid_comp_mask_blend_alpha_schedule_row: + hybrid_comp_mask_blend_alpha_schedule = gr.Textbox(label="Comp mask blend alpha schedule", lines=1, value = da.hybrid_comp_mask_blend_alpha_schedule, interactive=True, elem_id="hybridelemtest") + with gr.Row(variant='compact', visible=False) as hybrid_comp_mask_contrast_schedule_row: + hybrid_comp_mask_contrast_schedule = gr.Textbox(label="Comp mask contrast schedule", lines=1, value = da.hybrid_comp_mask_contrast_schedule, interactive=True) + with gr.Row(variant='compact', visible=False) as hybrid_comp_mask_auto_contrast_cutoff_high_schedule_row : + hybrid_comp_mask_auto_contrast_cutoff_high_schedule = gr.Textbox(label="Comp mask auto contrast cutoff high schedule", lines=1, value = da.hybrid_comp_mask_auto_contrast_cutoff_high_schedule, interactive=True) + with gr.Row(variant='compact', visible=False) as hybrid_comp_mask_auto_contrast_cutoff_low_schedule_row: + hybrid_comp_mask_auto_contrast_cutoff_low_schedule = 
gr.Textbox(label="Comp mask auto contrast cutoff low schedule", lines=1, value = da.hybrid_comp_mask_auto_contrast_cutoff_low_schedule, interactive=True) + # HUMANS MASKING ACCORD + with gr.Accordion("Humans Masking", open=False, visible=False) as humans_masking_accord: + with gr.Row(variant='compact'): + hybrid_generate_human_masks = gr.Radio(['None', 'PNGs', 'Video', 'Both'], label="Generate human masks", value=da.hybrid_generate_human_masks, elem_id="hybrid_generate_human_masks") + # OUTPUT TAB + with gr.Tab('Output'): + # VID OUTPUT ACCORD + with gr.Accordion('Video Output Settings', open=True): + with gr.Row(variant='compact') as fps_out_format_row: + fps = gr.Slider(label="FPS", value=dv.fps, minimum=1, maximum=240, step=1) + # NOT VISIBLE AS OF 11-02-23 moving to ffmpeg-only! + output_format = gr.Dropdown(visible=False, label="Output format", choices=['FFMPEG mp4'], value='FFMPEG mp4', type="value", elem_id="output_format", interactive=True) + with gr.Column(variant='compact'): + with gr.Row(variant='compact') as soundtrack_row: + add_soundtrack = gr.Radio(['None', 'File', 'Init Video'], label="Add soundtrack", value=dv.add_soundtrack) + soundtrack_path = gr.Textbox(label="Soundtrack path", lines=1, interactive=True, value = dv.soundtrack_path) + with gr.Row(variant='compact'): + skip_video_for_run_all = gr.Checkbox(label="Skip video for run all", value=dv.skip_video_for_run_all, interactive=True) + store_frames_in_ram = gr.Checkbox(label="Store frames in ram", value=dv.store_frames_in_ram, interactive=True) + save_depth_maps = gr.Checkbox(label="Save depth maps", value=da.save_depth_maps, interactive=True) + # the following param only shows for windows and linux users! + make_gif = gr.Checkbox(label="Make GIF", value=dv.make_gif, interactive=True) + with gr.Row(equal_height=True, variant='compact', visible=(True if dr.current_user_os in ["Windows", "Linux", "Mac"] else False)) as r_upscale_row: + r_upscale_video = gr.Checkbox(label="Upscale", value=dv.r_upscale_video, interactive=True) + r_upscale_model = gr.Dropdown(label="Upscale model", choices=['realesr-animevideov3', 'realesrgan-x4plus', 'realesrgan-x4plus-anime'], interactive=True, value = dv.r_upscale_model, type="value") + r_upscale_factor = gr.Dropdown(choices=['x2', 'x3', 'x4'], label="Upscale factor", interactive=True, value=dv.r_upscale_factor, type="value") + r_upscale_keep_imgs = gr.Checkbox(label="Keep Imgs", value=dv.r_upscale_keep_imgs, interactive=True) + with gr.Accordion('FFmpeg settings', visible=True, open=False) as ffmpeg_quality_accordion: + with gr.Row(equal_height=True, variant='compact', visible=True) as ffmpeg_set_row: + ffmpeg_crf = gr.Slider(minimum=0, maximum=51, step=1, label="CRF", value=dv.ffmpeg_crf, interactive=True) + ffmpeg_preset = gr.Dropdown(label="Preset", choices=['veryslow', 'slower', 'slow', 'medium', 'fast', 'faster', 'veryfast', 'superfast', 'ultrafast'], interactive=True, value = dv.ffmpeg_preset, type="value") + with gr.Row(equal_height=True, variant='compact', visible=True) as ffmpeg_location_row: + ffmpeg_location = gr.Textbox(label="Location", lines=1, interactive=True, value = dv.ffmpeg_location) + # FRAME INTERPOLATION TAB + with gr.Tab('Frame Interoplation') as frame_interp_tab: + with gr.Accordion('Important notes and Help', open=False): + gr.HTML(""" + Use RIFE / FILM Frame Interpolation to smooth out, slow-mo (or both) any video.

    Supported engines:
    • RIFE v4.6 and FILM.

    Important notes:
    • Frame Interpolation will *not* run if any of the following are enabled: 'Store frames in ram' / 'Skip video for run all'.
    • Audio (if provided) will *not* be transferred to the interpolated video if Slow-Mo is enabled.
    • 'add_soundtrack' and 'soundtrack_path' aren't honoured in "Interpolate an existing video" mode. The original video's audio will be used instead, with the same slow-mo rules above.
    + """) + with gr.Column(variant='compact'): + with gr.Row(variant='compact'): + # Interpolation Engine + frame_interpolation_engine = gr.Dropdown(label="Engine", choices=['None','RIFE v4.6','FILM'], value=dv.frame_interpolation_engine, type="value", elem_id="frame_interpolation_engine", interactive=True) + frame_interpolation_slow_mo_enabled = gr.Checkbox(label="Slow Mo", elem_id="frame_interpolation_slow_mo_enabled", value=dv.frame_interpolation_slow_mo_enabled, interactive=True, visible=False) + # If this is set to True, we keep all of the interpolated frames in a folder. Default is False - means we delete them at the end of the run + frame_interpolation_keep_imgs = gr.Checkbox(label="Keep Imgs", elem_id="frame_interpolation_keep_imgs", value=dv.frame_interpolation_keep_imgs, interactive=True, visible=False) + with gr.Row(variant='compact', visible=False) as frame_interp_amounts_row: + with gr.Column(min_width=180) as frame_interp_x_amount_column: + # How many times to interpolate (interp X) + frame_interpolation_x_amount = gr.Slider(minimum=2, maximum=10, step=1, label="Interp X", value=dv.frame_interpolation_x_amount, interactive=True) + with gr.Column(min_width=180, visible=False) as frame_interp_slow_mo_amount_column: + # Interp Slow-Mo (setting final output fps, not really doing anything direclty with RIFE/FILM) + frame_interpolation_slow_mo_amount = gr.Slider(minimum=2, maximum=10, step=1, label="Slow-Mo X", value=dv.frame_interpolation_x_amount, interactive=True) + # TODO: move these from here when done + def hide_slow_mo(choice): + return gr.update(visible=True) if choice else gr.update(visible=False) + def hide_interp_by_interp_status(choice): + return gr.update(visible=False) if choice == 'None' else gr.update(visible=True) + def change_interp_x_max_limit(engine_name, current_value): + if engine_name == 'FILM': + return gr.update(maximum=300) + elif current_value > 10: + return gr.update(maximum=10, value=2) + return gr.update(maximum=10) + frame_interpolation_slow_mo_enabled.change(fn=hide_slow_mo,inputs=frame_interpolation_slow_mo_enabled,outputs=frame_interp_slow_mo_amount_column) + interp_hide_list = [frame_interpolation_slow_mo_enabled,frame_interpolation_keep_imgs,frame_interp_amounts_row] + for output in interp_hide_list: + frame_interpolation_engine.change(fn=hide_interp_by_interp_status,inputs=frame_interpolation_engine,outputs=output) + frame_interpolation_engine.change(fn=change_interp_x_max_limit,inputs=[frame_interpolation_engine,frame_interpolation_x_amount],outputs=frame_interpolation_x_amount) + with gr.Row(visible=False) as interp_existing_video_row: + # Intrpolate any existing video from the connected PC + with gr.Accordion('Interpolate an existing video', open=False) as interp_existing_video_accord: + # A drag-n-drop UI box to which the user uploads a *single* (at this stage) video + vid_to_interpolate_chosen_file = gr.File(label="Video to Interpolate", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_interpolate_chosen_file") + with gr.Row(variant='compact'): + # Non interactive textbox showing uploaded input vid total Frame Count + in_vid_frame_count_window = gr.Textbox(label="In Frame Count", lines=1, interactive=False, value='---') + # Non interactive textbox showing uploaded input vid FPS + in_vid_fps_ui_window = gr.Textbox(label="In FPS", lines=1, interactive=False, value='---') + # Non interactive textbox showing expected output interpolated video FPS + out_interp_vid_estimated_fps = gr.Textbox(label="Interpolated Vid 
FPS", value='---') + # This is the actual button that's pressed to initiate the interpolation: + interpolate_button = gr.Button(value="*Interpolate uploaded video*") + # Show a text about CLI outputs: + gr.HTML("* check your CLI for outputs") + # make the functin call when the interpolation button is clicked + interpolate_button.click(upload_vid_to_interpolate,inputs=[vid_to_interpolate_chosen_file, frame_interpolation_engine, frame_interpolation_x_amount, frame_interpolation_slow_mo_enabled, frame_interpolation_slow_mo_amount, frame_interpolation_keep_imgs, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, in_vid_fps_ui_window]) + [change_fn.change(set_interp_out_fps, inputs=[frame_interpolation_x_amount, frame_interpolation_slow_mo_enabled, frame_interpolation_slow_mo_amount, in_vid_fps_ui_window], outputs=out_interp_vid_estimated_fps) for change_fn in [frame_interpolation_x_amount, frame_interpolation_slow_mo_amount, frame_interpolation_slow_mo_enabled]] + # Populate the above FPS and FCount values as soon as a video is uploaded to the FileUploadBox (vid_to_interpolate_chosen_file) + vid_to_interpolate_chosen_file.change(gradio_f_interp_get_fps_and_fcount,inputs=[vid_to_interpolate_chosen_file, frame_interpolation_x_amount, frame_interpolation_slow_mo_enabled, frame_interpolation_slow_mo_amount],outputs=[in_vid_fps_ui_window,in_vid_frame_count_window, out_interp_vid_estimated_fps]) + #TODO: move this from here + interp_hide_list = [frame_interpolation_slow_mo_enabled,frame_interpolation_keep_imgs,frame_interp_amounts_row,interp_existing_video_row] + for output in interp_hide_list: + frame_interpolation_engine.change(fn=hide_interp_by_interp_status,inputs=frame_interpolation_engine,outputs=output) + # TODO: add upscalers parameters to the settings and make them a part of the pipeline + # VIDEO UPSCALE TAB + with gr.Tab('Video Upscaling'): + vid_to_upscale_chosen_file = gr.File(label="Video to Upscale", interactive=True, file_count="single", file_types=["video"], elem_id="vid_to_upscale_chosen_file") + with gr.Column(): + # NCNN UPSCALE TAB + with gr.Tab('Upscale V2') as ncnn_upscale_tab: + with gr.Row(variant='compact') as ncnn_upload_vid_stats_row: + # Non interactive textbox showing uploaded input vid total Frame Count + ncnn_upscale_in_vid_frame_count_window = gr.Textbox(label="In Frame Count", lines=1, interactive=False, value='---') + # Non interactive textbox showing uploaded input vid FPS + ncnn_upscale_in_vid_fps_ui_window = gr.Textbox(label="In FPS", lines=1, interactive=False, value='---') + # Non interactive textbox showing uploaded input resolution + ncnn_upscale_in_vid_res = gr.Textbox(label="In Res", lines=1, interactive=False, value='---') + # Non interactive textbox showing expected output resolution + ncnn_upscale_out_vid_res = gr.Textbox(label="Out Res", value='---') + with gr.Column(): + with gr.Row(variant='compact', visible=(True if dr.current_user_os in ["Windows", "Linux", "Mac"] else False)) as ncnn_actual_upscale_row: + ncnn_upscale_model = gr.Dropdown(label="Upscale model", choices=['realesr-animevideov3', 'realesrgan-x4plus', 'realesrgan-x4plus-anime'], interactive=True, value = "realesr-animevideov3", type="value") + ncnn_upscale_factor = gr.Dropdown(choices=['x2', 'x3', 'x4'], label="Upscale factor", interactive=True, value="x2", type="value") + ncnn_upscale_keep_imgs = gr.Checkbox(label="Keep Imgs", value=True, interactive=True) # fix value + ncnn_upscale_btn = gr.Button(value="*Upscale uploaded video*") + 
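+ # On click, pass the uploaded video, its FPS/resolution readouts, the chosen model and factor, and the FFmpeg settings to the NCNN upscaling handler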
ncnn_upscale_btn.click(ncnn_upload_vid_to_upscale,inputs=[vid_to_upscale_chosen_file, ncnn_upscale_in_vid_fps_ui_window, ncnn_upscale_in_vid_res, ncnn_upscale_out_vid_res, ncnn_upscale_model, ncnn_upscale_factor, ncnn_upscale_keep_imgs, ffmpeg_location, ffmpeg_crf, ffmpeg_preset]) + with gr.Tab('Upscale V1'): + with gr.Column(): + selected_tab = gr.State(value=0) + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + upscaling_resize_w = gr.Slider(label="Width", minimum=1, maximum=7680, step=1, value=512, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(label="Height", minimum=1, maximum=7680, step=1, value=512, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in sh.sd_upscalers], value=sh.sd_upscalers[3].name) + extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in sh.sd_upscalers], value=sh.sd_upscalers[0].name) + with FormRow(): + with gr.Column(scale=3): + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + with gr.Column(scale=1, min_width=80): + upscale_keep_imgs = gr.Checkbox(label="Keep Imgs", elem_id="upscale_keep_imgs", value=True, interactive=True) + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + # This is the actual button that's pressed to initiate the Upscaling: + upscale_btn = gr.Button(value="*Upscale uploaded video*") + # Show a text about CLI outputs: + gr.HTML("* check your CLI for outputs") + # make the function call when the UPSCALE button is clicked + upscale_btn.click(upload_vid_to_upscale,inputs=[vid_to_upscale_chosen_file, selected_tab, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_keep_imgs, ffmpeg_location, ffmpeg_crf, ffmpeg_preset]) + # STITCH FRAMES TO VID TAB + with gr.Tab('Frames to Video') as stitch_imgs_to_vid_row: + with gr.Row(visible=False): + path_name_modifier = gr.Dropdown(label="Path name modifier", choices=['x0_pred', 'x'], value=dv.path_name_modifier, type="value", elem_id="path_name_modifier", interactive=True, visible=False) + gr.HTML(""" +

    Important Notes:
    • Enter a path relative to the webui folder, or a full absolute path, and make sure it ends with something like '20230124234916_%05d.png' - just replace 20230124234916 with your batch ID. The %05d is important, don't forget it!
    + """) + with gr.Row(variant='compact'): + image_path = gr.Textbox(label="Image path", lines=1, interactive=True, value = dv.image_path) + with gr.Row(visible=False): + mp4_path = gr.Textbox(label="MP4 path", lines=1, interactive=True, value = dv.mp4_path) + # not visible as of 06-02-23 since render_steps is disabled as well and they work together. Need to fix both. + with gr.Row(visible=False): + # rend_step Never worked - set to visible false 28-1-23 # MOVE OUT FROM HERE! + render_steps = gr.Checkbox(label="Render steps", value=dv.render_steps, interactive=True, visible=False) + ffmpeg_stitch_imgs_but = gr.Button(value="*Stitch frames to video*") + ffmpeg_stitch_imgs_but.click(direct_stitch_vid_from_frames,inputs=[image_path, fps, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, add_soundtrack, soundtrack_path]) + # **OLD + NON ACTIVES AREA** + with gr.Accordion(visible=False, label='INVISIBLE') as not_in_use_accordion: + # NOT VISIBLE AS OF 09-02-23 + mask_contrast_adjust = gr.Slider(label="Mask contrast adjust", minimum=0, maximum=1, step=0.01, value=d.mask_contrast_adjust, interactive=True) + mask_brightness_adjust = gr.Slider(label="Mask brightness adjust", minimum=0, maximum=1, step=0.01, value=d.mask_brightness_adjust, interactive=True) + from_img2img_instead_of_link = gr.Checkbox(label="from_img2img_instead_of_link", value=False, interactive=False, visible=False) + # INVISIBLE AS OF 08-02 (with static value of 8 for both W and H). Was in Perlin section before Perlin Octaves/Persistence + with gr.Column(min_width=200, visible=False): + perlin_w = gr.Slider(label="Perlin W", minimum=0.1, maximum=16, step=0.1, value=da.perlin_w, interactive=True) + perlin_h = gr.Slider(label="Perlin H", minimum=0.1, maximum=16, step=0.1, value=da.perlin_h, interactive=True) + with gr.Row(visible=False): + filename_format = gr.Textbox(label="Filename format", lines=1, interactive=True, value = d.filename_format, visible=False) + with gr.Row(visible=False): + save_settings = gr.Checkbox(label="save_settings", value=d.save_settings, interactive=True) + with gr.Row(visible=False): + save_samples = gr.Checkbox(label="save_samples", value=d.save_samples, interactive=True) + display_samples = gr.Checkbox(label="display_samples", value=False, interactive=False) + # NOT VISIBLE 11-02-23 htai + with gr.Accordion('Subseed controls & More', open=False, visible=False): + # Not visible until fixed, 06-02-23 + # NOT VISIBLE as of 11-02 - we have sch now. 
will delete the actual params in a later date + with gr.Row(variant='compact', visible=False): + seed_enable_extras = gr.Checkbox(label="Enable subseed controls", value=False) + n_batch = gr.Number(label="N Batch", value=d.n_batch, interactive=True, precision=0, visible=False) + with gr.Row(visible=False): + save_sample_per_step = gr.Checkbox(label="Save sample per step", value=d.save_sample_per_step, interactive=True) + show_sample_per_step = gr.Checkbox(label="Show sample per step", value=d.show_sample_per_step, interactive=True) + # Gradio's Change functions - hiding and renaming elements based on other elements + fps.change(fn=change_gif_button_visibility, inputs=fps, outputs=make_gif) + r_upscale_model.change(fn=update_r_upscale_factor, inputs=r_upscale_model, outputs=r_upscale_factor) + ncnn_upscale_model.change(fn=update_r_upscale_factor, inputs=ncnn_upscale_model, outputs=ncnn_upscale_factor) + ncnn_upscale_model.change(update_upscale_out_res_by_model_name, inputs=[ncnn_upscale_in_vid_res, ncnn_upscale_model], outputs=ncnn_upscale_out_vid_res) + ncnn_upscale_factor.change(update_upscale_out_res, inputs=[ncnn_upscale_in_vid_res, ncnn_upscale_factor], outputs=ncnn_upscale_out_vid_res) + vid_to_upscale_chosen_file.change(vid_upscale_gradio_update_stats,inputs=[vid_to_upscale_chosen_file, ncnn_upscale_factor],outputs=[ncnn_upscale_in_vid_fps_ui_window, ncnn_upscale_in_vid_frame_count_window, ncnn_upscale_in_vid_res, ncnn_upscale_out_vid_res]) + animation_mode.change(fn=change_max_frames_visibility, inputs=animation_mode, outputs=max_frames) + animation_mode.change(fn=change_diffusion_cadence_visibility, inputs=animation_mode, outputs=diffusion_cadence) + animation_mode.change(fn=disble_3d_related_stuff, inputs=animation_mode, outputs=depth_3d_warping_accord) + animation_mode.change(fn=disble_3d_related_stuff, inputs=animation_mode, outputs=fov_accord) + animation_mode.change(fn=disble_3d_related_stuff, inputs=animation_mode, outputs=only_3d_motion_column) + animation_mode.change(fn=enable_2d_related_stuff, inputs=animation_mode, outputs=only_2d_motion_column) + animation_mode.change(fn=disable_by_interpolation, inputs=animation_mode, outputs=force_grayscale_column) + animation_mode.change(fn=disable_pers_flip_accord, inputs=animation_mode, outputs=perspective_flip_accord) + animation_mode.change(fn=disable_pers_flip_accord, inputs=animation_mode, outputs=both_anim_mode_motion_params_column) + #Hybrid related: + animation_mode.change(fn=show_hybrid_html_msg, inputs=animation_mode, outputs=hybrid_msg_html) + animation_mode.change(fn=change_hybrid_tab_status, inputs=animation_mode, outputs=hybrid_sch_accord) + animation_mode.change(fn=change_hybrid_tab_status, inputs=animation_mode, outputs=hybrid_settings_accord) + animation_mode.change(fn=change_hybrid_tab_status, inputs=animation_mode, outputs=humans_masking_accord) + hybrid_comp_mask_type.change(fn=change_comp_mask_x_visibility, inputs=hybrid_comp_mask_type, outputs=hybrid_comp_mask_row) + hybrid_motion.change(fn=disable_by_non_optical_flow, inputs=hybrid_motion, outputs=hybrid_flow_method) + hybrid_motion.change(fn=disable_by_comp_mask, inputs=hybrid_motion, outputs=hybrid_motion_use_prev_img) + hybrid_composite.change(fn=disable_by_hybrid_composite_dynamic, inputs=[hybrid_composite, hybrid_comp_mask_type], outputs=hybrid_comp_mask_row) + hybrid_composite_outputs = [humans_masking_accord, hybrid_sch_accord, hybrid_comp_mask_type, hybrid_use_first_frame_as_init_image] + for output in hybrid_composite_outputs: + 
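+ # Toggle visibility of every hybrid-composite-dependent UI group whenever the 'Hybrid composite' checkbox changes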
hybrid_composite.change(fn=disable_by_hybrid_composite, inputs=hybrid_composite, outputs=output) + hybrid_comp_mask_type_outputs = [hybrid_comp_mask_blend_alpha_schedule_row, hybrid_comp_mask_contrast_schedule_row, hybrid_comp_mask_auto_contrast_cutoff_high_schedule_row, hybrid_comp_mask_auto_contrast_cutoff_low_schedule_row] + for output in hybrid_comp_mask_type_outputs: + hybrid_comp_mask_type.change(fn=disable_by_comp_mask, inputs=hybrid_comp_mask_type, outputs=output) + # End of hybrid related + seed_behavior.change(fn=change_seed_iter_visibility, inputs=seed_behavior, outputs=seed_iter_N_row) + seed_behavior.change(fn=change_seed_schedule_visibility, inputs=seed_behavior, outputs=seed_schedule_row) + color_coherence.change(fn=change_color_coherence_video_every_N_frames_visibility, inputs=color_coherence, outputs=color_coherence_video_every_N_frames_row) + noise_type.change(fn=change_perlin_visibility, inputs=noise_type, outputs=perlin_row) + skip_video_for_run_all_outputs = [fps_out_format_row, soundtrack_row, ffmpeg_quality_accordion, store_frames_in_ram, make_gif, r_upscale_row] + for output in skip_video_for_run_all_outputs: + skip_video_for_run_all.change(fn=change_visibility_from_skip_video, inputs=skip_video_for_run_all, outputs=output) + # END OF UI TABS + stuff = locals() + stuff = {**stuff, **controlnet_dict} + stuff.pop('controlnet_dict') + return stuff + +### SETTINGS STORAGE UPDATE! 2023-01-27 +### To Reduce The Number Of Settings Overrides, +### They Are Being Passed As Dictionaries +### It Would Have Been Also Nice To Retrieve Them +### From Functions Like Deforumoutputargs(), +### But Over Time There Was Some Cross-Polination, +### So They Are Now Hardcoded As 'List'-Strings Below +### If you're adding a new setting, add it to one of the lists +### besides writing it in the setup functions above + +anim_args_names = str(r'''animation_mode, max_frames, border, + angle, zoom, translation_x, translation_y, translation_z, + rotation_3d_x, rotation_3d_y, rotation_3d_z, + enable_perspective_flip, + perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, perspective_flip_fv, + noise_schedule, strength_schedule, contrast_schedule, cfg_scale_schedule, pix2pix_img_cfg_scale_schedule, + enable_subseed_scheduling, subseed_schedule, subseed_strength_schedule, + enable_steps_scheduling, steps_schedule, + fov_schedule, near_schedule, far_schedule, + seed_schedule, + enable_sampler_scheduling, sampler_schedule, + mask_schedule, use_noise_mask, noise_mask_schedule, + enable_checkpoint_scheduling, checkpoint_schedule, + enable_clipskip_scheduling, clipskip_schedule, + kernel_schedule, sigma_schedule, amount_schedule, threshold_schedule, + color_coherence, color_coherence_video_every_N_frames, color_force_grayscale, + diffusion_cadence, + noise_type, perlin_w, perlin_h, perlin_octaves, perlin_persistence, + use_depth_warping, midas_weight, + padding_mode, sampling_mode, save_depth_maps, + video_init_path, extract_nth_frame, extract_from_frame, extract_to_frame, overwrite_extracted_frames, + use_mask_video, video_mask_path, + resume_from_timestring, resume_timestring''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') +hybrid_args_names = str(r'''hybrid_generate_inputframes, hybrid_generate_human_masks, hybrid_use_first_frame_as_init_image, + hybrid_motion, hybrid_motion_use_prev_img, hybrid_flow_method, hybrid_composite, hybrid_comp_mask_type, hybrid_comp_mask_inverse, + hybrid_comp_mask_equalize, hybrid_comp_mask_auto_contrast, 
hybrid_comp_save_extra_frames, + hybrid_comp_alpha_schedule, hybrid_comp_mask_blend_alpha_schedule, hybrid_comp_mask_contrast_schedule, + hybrid_comp_mask_auto_contrast_cutoff_high_schedule, hybrid_comp_mask_auto_contrast_cutoff_low_schedule''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') +args_names = str(r'''W, H, tiling, restore_faces, + seed, sampler, + seed_enable_extras, seed_resize_from_w, seed_resize_from_h, + steps, ddim_eta, + n_batch, + save_settings, save_samples, display_samples, + save_sample_per_step, show_sample_per_step, + batch_name, filename_format, + seed_behavior, seed_iter_N, + use_init, from_img2img_instead_of_link, strength_0_no_init, strength, init_image, + use_mask, use_alpha_as_mask, invert_mask, overlay_mask, + mask_file, mask_contrast_adjust, mask_brightness_adjust, mask_overlay_blur, + fill, full_res_mask, full_res_mask_padding, + reroll_blank_frames''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') +video_args_names = str(r'''skip_video_for_run_all, + fps, make_gif, output_format, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, + add_soundtrack, soundtrack_path, + r_upscale_video, r_upscale_model, r_upscale_factor, r_upscale_keep_imgs, + render_steps, + path_name_modifier, image_path, mp4_path, store_frames_in_ram, + frame_interpolation_engine, frame_interpolation_x_amount, frame_interpolation_slow_mo_enabled, frame_interpolation_slow_mo_amount, + frame_interpolation_keep_imgs''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') +parseq_args_names = str(r'''parseq_manifest, parseq_use_deltas''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') +loop_args_names = str(r'''use_looper, init_images, image_strength_schedule, blendFactorMax, blendFactorSlope, + tweening_frames_schedule, color_correction_factor''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') + +component_names = ['override_settings_with_file', 'custom_settings_file'] + anim_args_names +['animation_prompts', 'animation_prompts_positive', 'animation_prompts_negative'] + args_names + video_args_names + parseq_args_names + hybrid_args_names + loop_args_names + controlnet_component_names() +settings_component_names = [name for name in component_names if name not in video_args_names] + +def setup_deforum_setting_ui(self, is_img2img, is_extension = True): + ds = setup_deforum_setting_dictionary(self, is_img2img, is_extension) + return [ds[name] for name in (['btn'] + component_names)] + +def pack_anim_args(args_dict): + return {name: args_dict[name] for name in (anim_args_names + hybrid_args_names)} + +def pack_args(args_dict): + args_dict = {name: args_dict[name] for name in args_names} + args_dict['precision'] = 'autocast' + args_dict['scale'] = 7 + args_dict['subseed'] = -1 + args_dict['subseed_strength'] = 0 + args_dict['C'] = 4 + args_dict['f'] = 8 + args_dict['timestring'] = "" + args_dict['init_latent'] = None + args_dict['init_sample'] = None + args_dict['init_c'] = None + args_dict['noise_mask'] = None + args_dict['seed_internal'] = 0 + return args_dict + +def pack_video_args(args_dict): + return {name: args_dict[name] for name in video_args_names} + +def pack_parseq_args(args_dict): + return {name: args_dict[name] for name in parseq_args_names} + +def pack_loop_args(args_dict): + return {name: args_dict[name] for name in loop_args_names} + +def pack_controlnet_args(args_dict): + return {name: args_dict[name] for name in controlnet_component_names()} + +def process_args(args_dict_main): + 
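+ # Split the flat dict of gradio component values into grouped arg dicts (general, anim, video, parseq, loop, controlnet), merge the positive/negative prompt fields into each keyframe prompt, optionally override from a settings file, then build SimpleNamespaces and copy the core params onto the webui processing object p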
override_settings_with_file = args_dict_main['override_settings_with_file'] + custom_settings_file = args_dict_main['custom_settings_file'] + args_dict = pack_args(args_dict_main) + anim_args_dict = pack_anim_args(args_dict_main) + video_args_dict = pack_video_args(args_dict_main) + parseq_args_dict = pack_parseq_args(args_dict_main) + loop_args_dict = pack_loop_args(args_dict_main) + controlnet_args_dict = pack_controlnet_args(args_dict_main) + + import json + + root = SimpleNamespace(**Root()) + root.p = args_dict_main['p'] + p = root.p + root.animation_prompts = json.loads(args_dict_main['animation_prompts']) + positive_prompts = args_dict_main['animation_prompts_positive'] + negative_prompts = args_dict_main['animation_prompts_negative'] + # remove --neg from negative_prompts if recieved by mistake + negative_prompts = negative_prompts.replace('--neg', '') + for key in root.animation_prompts: + animationPromptCurr = root.animation_prompts[key] + root.animation_prompts[key] = f"{positive_prompts} {animationPromptCurr} {'' if '--neg' in animationPromptCurr else '--neg'} {negative_prompts}" + from deforum_helpers.settings import load_args + + if override_settings_with_file: + load_args(args_dict, anim_args_dict, parseq_args_dict, loop_args_dict, controlnet_args_dict, custom_settings_file, root) + + if not os.path.exists(root.models_path): + os.mkdir(root.models_path) + + args = SimpleNamespace(**args_dict) + anim_args = SimpleNamespace(**anim_args_dict) + video_args = SimpleNamespace(**video_args_dict) + parseq_args = SimpleNamespace(**parseq_args_dict) + loop_args = SimpleNamespace(**loop_args_dict) + controlnet_args = SimpleNamespace(**controlnet_args_dict) + + p.width, p.height = map(lambda x: x - x % 64, (args.W, args.H)) + p.steps = args.steps + p.seed = args.seed + p.sampler_name = args.sampler + p.batch_size = args.n_batch + p.tiling = args.tiling + p.restore_faces = args.restore_faces + p.seed_enable_extras = args.seed_enable_extras + p.subseed = args.subseed + p.subseed_strength = args.subseed_strength + p.seed_resize_from_w = args.seed_resize_from_w + p.seed_resize_from_h = args.seed_resize_from_h + p.fill = args.fill + p.ddim_eta = args.ddim_eta + + # TODO: Handle batch name dynamically? + current_arg_list = [args, anim_args, video_args, parseq_args] + args.outdir = os.path.join(p.outpath_samples, args.batch_name) + root.outpath_samples = args.outdir + args.outdir = os.path.join(os.getcwd(), args.outdir) + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + args.seed = get_fixed_seed(args.seed) + + args.timestring = time.strftime('%Y%m%d%H%M%S') + args.strength = max(0.0, min(1.0, args.strength)) + + if not args.use_init: + args.init_image = None + + if anim_args.animation_mode == 'None': + anim_args.max_frames = 1 + elif anim_args.animation_mode == 'Video Input': + args.use_init = True + + return root, args, anim_args, video_args, parseq_args, loop_args, controlnet_args + +def print_args(args): + print("ARGS: /n") + for key, value in args.__dict__.items(): + print(f"{key}: {value}") + +# Local gradio-to-frame-interoplation function. *Needs* to stay here since we do Root() and use gradio elements directly, to be changed in the future +def upload_vid_to_interpolate(file, engine, x_am, sl_enabled, sl_am, keep_imgs, f_location, f_crf, f_preset, in_vid_fps): + # print msg and do nothing if vid not uploaded or interp_x not provided + if not file or engine == 'None': + return print("Please upload a video and set a proper value for 'Interp X'. 
Can't interpolate x0 times :)") + + root_params = Root() + f_models_path = root_params['models_path'] + + process_interp_vid_upload_logic(file, engine, x_am, sl_enabled, sl_am, keep_imgs, f_location, f_crf, f_preset, in_vid_fps, f_models_path, file.orig_name) + +# Local gradio-to-upscalers function. *Needs* to stay here since we do Root() and use gradio elements directly, to be changed in the future +def upload_vid_to_upscale(vid_to_upscale_chosen_file, selected_tab, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_keep_imgs, ffmpeg_location, ffmpeg_crf, ffmpeg_preset): + # print msg and do nothing if vid not uploaded + if not vid_to_upscale_chosen_file: + return print("Please upload a video :)") + + process_upscale_vid_upload_logic(vid_to_upscale_chosen_file, selected_tab, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, vid_to_upscale_chosen_file.orig_name, upscale_keep_imgs, ffmpeg_location, ffmpeg_crf, ffmpeg_preset) + +def ncnn_upload_vid_to_upscale(vid_path, in_vid_fps, in_vid_res, out_vid_res, upscale_model, upscale_factor, keep_imgs, f_location, f_crf, f_preset): + if vid_path is None: + print("Please upload a video :)") + return + root_params = Root() + f_models_path = root_params['models_path'] + current_user = root_params['current_user_os'] + process_ncnn_upscale_vid_upload_logic(vid_path, in_vid_fps, in_vid_res, out_vid_res, f_models_path, upscale_model, upscale_factor, keep_imgs, f_location, f_crf, f_preset, current_user) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/blank_frame_reroll.py b/extensions/deforum/scripts/deforum_helpers/blank_frame_reroll.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c5954921ffada030cdf24104b211e630a05295 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/blank_frame_reroll.py @@ -0,0 +1,24 @@ +from .generate import generate +#WebUI +from modules.shared import opts, cmd_opts, state + +def blank_frame_reroll(image, args, root, frame_idx): + patience = 10 + print("Blank frame detected! If you don't have the NSFW filter enabled, this may be due to a glitch!") + if args.reroll_blank_frames == 'reroll': + while not image.getbbox(): + print("Rerolling with +1 seed...") + args.seed += 1 + image = generate(args, root, frame_idx) + patience -= 1 + if patience == 0: + print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! 
Interrupting...") + state.interrupted = True + state.current_image = image + return None + elif args.reroll_blank_frames == 'interrupt': + print("Interrupting to save your eyes...") + state.interrupted = True + state.current_image = image + return None + return image \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/colors.py b/extensions/deforum/scripts/deforum_helpers/colors.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec81e197ef2b918a352d04f57337b956137b0e6 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/colors.py @@ -0,0 +1,16 @@ +from skimage.exposure import match_histograms +import cv2 + +def maintain_colors(prev_img, color_match_sample, mode): + if mode == 'Match Frame 0 RGB': + return match_histograms(prev_img, color_match_sample, multichannel=True) + elif mode == 'Match Frame 0 HSV': + prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) + color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) + matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) + return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) + else: # Match Frame 0 LAB + prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) + color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) + matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) + return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/composable_masks.py b/extensions/deforum/scripts/deforum_helpers/composable_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a509a1a517df1fcca8c7e1dadaaa9be97f5769 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/composable_masks.py @@ -0,0 +1,198 @@ +# At the moment there are three types of masks: mask from variable, file mask and word mask +# Variable masks include init_mask for the predefined whole-video mask, frame_mask from video-masking system +# and human_mask for a model which better segments people in the background video +# They are put in {}-brackets +# Word masks are framed with <>-bracets, like: , +# File masks are put in []-brackes +# Empty strings are counted as the whole frame +# We want to put them all into a sequence of boolean operations + +# Example: +# \ +# (({human_mask} & [mask1.png]) ^ ) + +# Writing the parser for the boolean sequence +# using regex and PIL operations +import re +from .load_images import get_mask_from_file, check_mask_for_errors, blank_if_none +from .word_masking import get_word_mask +from torch import Tensor +import PIL +from PIL import Image, ImageChops + +# val_masks: name, PIL Image mask +# Returns an image in mode '1' (needed for bool ops), convert to 'L' in the sender function +def compose_mask(root, args, mask_seq, val_masks, frame_image, inner_idx:int = 0): + # Compose_mask recursively: go to inner brackets, then b-op it and go upstack + + # Step 1: + # recursive parenthesis pass + # regex is not powerful here + + seq = "" + inner_seq = "" + parentheses_counter = 0 + + for c in mask_seq: + if c == ')': + parentheses_counter = parentheses_counter - 1 + if parentheses_counter > 0: + inner_seq += c + if c == '(': + parentheses_counter = parentheses_counter + 1 + if parentheses_counter == 0: + if len(inner_seq) > 0: + inner_idx += 1 + seq += compose_mask(root, args, inner_seq, val_masks, frame_image, inner_idx) + inner_seq = "" + else: + seq += c + + if parentheses_counter != 0: + raise 
Exception(f'Mismatched parentheses in {mask_seq}!') + + mask_seq = seq + + # Step 2: + # Load the word masks and file masks as vars + + # File masks + pattern = r'\[(?P<inner>[\S\s]*?)\]' + + def parse(match_object): + nonlocal inner_idx + inner_idx += 1 + content = match_object.groupdict()['inner'] + val_masks[str(inner_idx)] = get_mask_from_file(content, args).convert('1') # TODO: add caching + return f"{{{inner_idx}}}" + + mask_seq = re.sub(pattern, parse, mask_seq) + + # Word masks + pattern = r'<(?P<inner>[\S\s]*?)>' + + def parse(match_object): + nonlocal inner_idx + inner_idx += 1 + content = match_object.groupdict()['inner'] + val_masks[str(inner_idx)] = get_word_mask(root, frame_image, content).convert('1') + return f"{{{inner_idx}}}" + + mask_seq = re.sub(pattern, parse, mask_seq) + + # Now that all inner parentheses are eliminated we're left with a linear string + + # Step 3: + # Boolean operations with masks + # Operators: invert !, and &, or |, xor ^, difference \ + + # Invert vars with '!' + pattern = r'![\S\s]*{(?P<inner>[\S\s]*?)}' + def parse(match_object): + nonlocal inner_idx + inner_idx += 1 + content = match_object.groupdict()['inner'] + savename = content + if content in root.mask_preset_names: + inner_idx += 1 + savename = str(inner_idx) + val_masks[savename] = ImageChops.invert(val_masks[content]) + return f"{{{savename}}}" + + mask_seq = re.sub(pattern, parse, mask_seq) + + # Multiply neighbouring vars with '&' + # Wait for replacements stall (like in Markov chains) + while True: + pattern = r'{(?P<inner1>[\S\s]*?)}[\s]*&[\s]*{(?P<inner2>[\S\s]*?)}' + def parse(match_object): + nonlocal inner_idx + inner_idx += 1 + content = match_object.groupdict()['inner1'] + content_second = match_object.groupdict()['inner2'] + savename = content + if content in root.mask_preset_names: + inner_idx += 1 + savename = str(inner_idx) + val_masks[savename] = ImageChops.logical_and(val_masks[content], val_masks[content_second]) + return f"{{{savename}}}" + + prev_mask_seq = mask_seq + mask_seq = re.sub(pattern, parse, mask_seq) + if mask_seq is prev_mask_seq: + break + + # Add neighbouring vars with '|' + while True: + pattern = r'{(?P<inner1>[\S\s]*?)}[\s]*?\|[\s]*?{(?P<inner2>[\S\s]*?)}' + def parse(match_object): + nonlocal inner_idx + inner_idx += 1 + content = match_object.groupdict()['inner1'] + content_second = match_object.groupdict()['inner2'] + savename = content + if content in root.mask_preset_names: + inner_idx += 1 + savename = str(inner_idx) + val_masks[savename] = ImageChops.logical_or(val_masks[content], val_masks[content_second]) + return f"{{{savename}}}" + + prev_mask_seq = mask_seq + mask_seq = re.sub(pattern, parse, mask_seq) + if mask_seq is prev_mask_seq: + break + + # Mutually exclude neighbouring vars with '^' + while True: + pattern = r'{(?P<inner1>[\S\s]*?)}[\s]*\^[\s]*{(?P<inner2>[\S\s]*?)}' + def parse(match_object): + nonlocal inner_idx + inner_idx += 1 + content = match_object.groupdict()['inner1'] + content_second = match_object.groupdict()['inner2'] + savename = content + if content in root.mask_preset_names: + inner_idx += 1 + savename = str(inner_idx) + val_masks[savename] = ImageChops.logical_xor(val_masks[content], val_masks[content_second]) + return f"{{{savename}}}" + + prev_mask_seq = mask_seq + mask_seq = re.sub(pattern, parse, mask_seq) + if mask_seq is prev_mask_seq: + break + + # Set-difference the regions with '\' + while True: + pattern = r'{(?P<inner1>[\S\s]*?)}[\s]*\\[\s]*{(?P<inner2>[\S\s]*?)}' + def parse(match_object): + content = match_object.groupdict()['inner1'] + content_second = match_object.groupdict()['inner2']
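+ # '\' acts as set difference: keep the pixels of inner1 that are not in inner2 (A AND NOT B)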
+ savename = content + if content in root.mask_preset_names: + nonlocal inner_idx + inner_idx += 1 + savename = str(inner_idx) + val_masks[savename] = ImageChops.logical_and(val_masks[content], ImageChops.invert(val_masks[content_second])) + return f"{{{savename}}}" + + prev_mask_seq = mask_seq + mask_seq = re.sub(pattern, parse, mask_seq) + if mask_seq is prev_mask_seq: + break + + # Step 4: + # Output + # Now we should have a single var left to return. If not, raise an error message + pattern = r'{(?P<inner>[\S\s]*?)}' + matches = re.findall(pattern, mask_seq) + + if len(matches) != 1: + raise Exception(f'Wrong composable mask expression format! Broken mask sequence: {mask_seq}') + + return f"{{{matches[0]}}}" + +def compose_mask_with_check(root, args, mask_seq, val_masks, frame_image): + for k, v in val_masks.items(): + val_masks[k] = blank_if_none(v, args.W, args.H, '1').convert('1') + return check_mask_for_errors(val_masks[compose_mask(root, args, mask_seq, val_masks, frame_image, 0)[1:-1]].convert('L')) diff --git a/extensions/deforum/scripts/deforum_helpers/deforum_controlnet.py b/extensions/deforum/scripts/deforum_helpers/deforum_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b72c8d4723a32721ce3c1242d6b8b33a7b21b2 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/deforum_controlnet.py @@ -0,0 +1,462 @@ +# This helper script is responsible for ControlNet/Deforum integration +# https://github.com/Mikubill/sd-webui-controlnet — controlnet repo + +import os, sys +import gradio as gr +import scripts +import modules.scripts as scrpts +from PIL import Image +import numpy as np +from modules.processing import process_images +from .rich import console +from rich.table import Table +from rich import box + +has_controlnet = None + +def find_controlnet(): + global has_controlnet + if has_controlnet is not None: + return has_controlnet + + try: + from scripts import controlnet + except Exception as e: + print(f'\033[33mFailed to import controlnet! The exact error is {e}. 
Deforum support for ControlNet will not be activated\033[0m') + has_controlnet = False + return False + has_controlnet = True + print(f"\033[0;32m*Deforum ControlNet support: enabled*\033[0m") + return True + +# The most parts below are plainly copied from controlnet.py +# TODO: come up with a cleaner way + +gradio_compat = True +try: + from distutils.version import LooseVersion + from importlib_metadata import version + if LooseVersion(version("gradio")) < LooseVersion("3.10"): + gradio_compat = False +except ImportError: + pass + +# svgsupports +svgsupport = False +try: + import io + import base64 + from svglib.svglib import svg2rlg + from reportlab.graphics import renderPM + svgsupport = True +except ImportError: + pass + +def ControlnetArgs(): + controlnet_enabled = False + controlnet_scribble_mode = False + controlnet_rgbbgr_mode = False + controlnet_lowvram = False + controlnet_module = "none" + controlnet_model = "None" + controlnet_weight = 1.0 + controlnet_guidance_strength = 1.0 + blendFactorMax = "0:(0.35)" + blendFactorSlope = "0:(0.25)" + tweening_frames_schedule = "0:(20)" + color_correction_factor = "0:(0.075)" + return locals() + +def setup_controlnet_ui_raw(): + # Already under an accordion + from scripts import controlnet + from scripts.controlnet import update_cn_models, cn_models, cn_models_names + + refresh_symbol = '\U0001f504' # 🔄 + switch_values_symbol = '\U000021C5' # ⇅ + model_dropdowns = [] + infotext_fields = [] + # Main part + class ToolButton(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool", **kwargs) + + def get_block_name(self): + return "button" + + from scripts.processor import canny, midas, midas_normal, leres, hed, mlsd, openpose, pidinet, simple_scribble, fake_scribble, uniformer + + preprocessor = { + "none": lambda x, *args, **kwargs: x, + "canny": canny, + "depth": midas, + "depth_leres": leres, + "hed": hed, + "mlsd": mlsd, + "normal_map": midas_normal, + "openpose": openpose, + # "openpose_hand": openpose_hand, + "pidinet": pidinet, + # "scribble": simple_scribble, + "fake_scribble": fake_scribble, + "segmentation": uniformer, + } + + # Copying the main ControlNet widgets while getting rid of static elements such as the scribble pad + with gr.Row(): + controlnet_enabled = gr.Checkbox(label='Enable', value=False) + controlnet_scribble_mode = gr.Checkbox(label='Scribble Mode (Invert colors)', value=False, visible=False) + controlnet_rgbbgr_mode = gr.Checkbox(label='RGB to BGR', value=False, visible=False) + controlnet_lowvram = gr.Checkbox(label='Low VRAM', value=False, visible=False) + + def refresh_all_models(*inputs): + update_cn_models() + + dd = inputs[0] + selected = dd if dd in cn_models else "None" + return gr.Dropdown.update(value=selected, choices=list(cn_models.keys())) + + with gr.Row(visible=False) as cn_mod_row: + controlnet_module = gr.Dropdown(list(preprocessor.keys()), label=f"Preprocessor", value="none") + controlnet_model = gr.Dropdown(list(cn_models.keys()), label=f"Model", value="None") + refresh_models = ToolButton(value=refresh_symbol) + refresh_models.click(refresh_all_models, controlnet_model, controlnet_model) + # ctrls += (refresh_models, ) + with gr.Row(visible=False) as cn_weight_row: + controlnet_weight = gr.Slider(label=f"Weight", value=1.0, minimum=0.0, maximum=2.0, step=.05) + controlnet_guidance_strength = gr.Slider(label="Guidance strength (T)", value=1.0, minimum=0.0, maximum=1.0, 
interactive=True) + # ctrls += (module, model, weight,) + # model_dropdowns.append(model) + + # advanced options + controlnet_advanced = gr.Column(visible=False) + with controlnet_advanced: + controlnet_processor_res = gr.Slider(label="Annotator resolution", value=64, minimum=64, maximum=2048, interactive=False) + controlnet_threshold_a = gr.Slider(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False) + controlnet_threshold_b = gr.Slider(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False) + + if gradio_compat: + controlnet_module.change(build_sliders, inputs=[controlnet_module], outputs=[controlnet_processor_res, controlnet_threshold_a, controlnet_threshold_b, controlnet_advanced]) + + infotext_fields.extend([ + (controlnet_module, f"ControlNet Preprocessor"), + (controlnet_model, f"ControlNet Model"), + (controlnet_weight, f"ControlNet Weight"), + ]) + + with gr.Row(visible=False) as cn_env_row: + controlnet_resize_mode = gr.Radio(choices=["Envelope (Outer Fit)", "Scale to Fit (Inner Fit)", "Just Resize"], value="Scale to Fit (Inner Fit)", label="Resize Mode") + + # Video input to be fed into ControlNet + #input_video_url = gr.Textbox(source='upload', type='numpy', tool='sketch') # TODO + controlnet_input_video_chosen_file = gr.File(label="ControlNet Video Input", interactive=True, file_count="single", file_types=["video"], elem_id="controlnet_input_video_chosen_file", visible=False) + controlnet_input_video_mask_chosen_file = gr.File(label="ControlNet Video Mask Input", interactive=True, file_count="single", file_types=["video"], elem_id="controlnet_input_video_mask_chosen_file", visible=False) + + cn_hide_output_list = [controlnet_scribble_mode,controlnet_rgbbgr_mode,controlnet_lowvram,cn_mod_row,cn_weight_row,cn_env_row,controlnet_input_video_chosen_file,controlnet_input_video_mask_chosen_file] + for cn_output in cn_hide_output_list: + controlnet_enabled.change(fn=hide_ui_by_cn_status, inputs=controlnet_enabled,outputs=cn_output) + + return locals() + + +def setup_controlnet_ui(): + if not find_controlnet(): + gr.HTML(""" + ControlNet not found. 
Please install it :) + """, elem_id='controlnet_not_found_html_msg') + return {} + + return setup_controlnet_ui_raw() + +def controlnet_component_names(): + if not find_controlnet(): + return [] + + controlnet_args_names = str(r'''controlnet_input_video_chosen_file, controlnet_input_video_mask_chosen_file, +controlnet_enabled, controlnet_scribble_mode, controlnet_rgbbgr_mode, controlnet_lowvram, +controlnet_module, controlnet_model, +controlnet_weight, controlnet_guidance_strength, +controlnet_processor_res, +controlnet_threshold_a, controlnet_threshold_b, controlnet_resize_mode''' + ).replace("\n", "").replace("\r", "").replace(" ", "").split(',') + + return controlnet_args_names + +def is_controlnet_enabled(controlnet_args): + return 'controlnet_enabled' in vars(controlnet_args) and controlnet_args.controlnet_enabled + +def process_txt2img_with_controlnet(p, args, anim_args, loop_args, controlnet_args, root, frame_idx = 1): + # TODO: use init image and mask here + p.control_net_enabled = False # we don't want to cause concurrence + p.init_images = [] + controlnet_frame_path = os.path.join(args.outdir, 'controlnet_inputframes', f"{frame_idx:05}.jpg") + controlnet_mask_frame_path = os.path.join(args.outdir, 'controlnet_maskframes', f"{frame_idx:05}.jpg") + cn_mask_np = None + cn_image_np = None + + if not os.path.exists(controlnet_frame_path) and not os.path.exists(controlnet_mask_frame_path): + print(f'\033[33mNeither the base nor the masking frames for ControlNet were found. Using the regular pipeline\033[0m') + from .deforum_controlnet_hardcode import restore_networks + unet = p.sd_model.model.diffusion_model + restore_networks(unet) + return process_images(p) + + if os.path.exists(controlnet_frame_path): + cn_image_np = Image.open(controlnet_frame_path).convert("RGB") + + if os.path.exists(controlnet_mask_frame_path): + cn_mask_np = Image.open(controlnet_mask_frame_path).convert("RGB") + + cn_args = { + "enabled": True, + "module": controlnet_args.controlnet_module, + "model": controlnet_args.controlnet_model, + "weight": controlnet_args.controlnet_weight, + "input_image": {'image': cn_image_np, 'mask': cn_mask_np}, + "scribble_mode": controlnet_args.controlnet_scribble_mode, + "resize_mode": controlnet_args.controlnet_resize_mode, + "rgbbgr_mode": controlnet_args.controlnet_rgbbgr_mode, + "lowvram": controlnet_args.controlnet_lowvram, + "processor_res": controlnet_args.controlnet_processor_res, + "threshold_a": controlnet_args.controlnet_threshold_a, + "threshold_b": controlnet_args.controlnet_threshold_b, + "guidance_strength": controlnet_args.controlnet_guidance_strength,"guidance_strength": controlnet_args.controlnet_guidance_strength, + } + + from .deforum_controlnet_hardcode import process + p.script_args = ( + cn_args["enabled"], + cn_args["module"], + cn_args["model"], + cn_args["weight"], + cn_args["input_image"], + cn_args["scribble_mode"], + cn_args["resize_mode"], + cn_args["rgbbgr_mode"], + cn_args["lowvram"], + cn_args["processor_res"], + cn_args["threshold_a"], + cn_args["threshold_b"], + cn_args["guidance_strength"], + ) + + table = Table(title="ControlNet params",padding=0, box=box.ROUNDED) + + field_names = [] + field_names += ["module", "model", "weight", "guidance", "scribble", "resize", "rgb->bgr", "proc res", "thr a", "thr b"] + for field_name in field_names: + table.add_column(field_name, justify="center") + + rows = [] + rows += [cn_args["module"], cn_args["model"], cn_args["weight"], cn_args["guidance_strength"], cn_args["scribble_mode"], 
cn_args["resize_mode"], cn_args["rgbbgr_mode"], cn_args["processor_res"], cn_args["threshold_a"], cn_args["threshold_b"]] + rows = [str(x) for x in rows] + + table.add_row(*rows) + + console.print(table) + + processed = process(p, *(p.script_args)) + + if processed is None: # the script just swaps the pipeline, so failing is OK for the first time + processed = process_images(p) + + if processed is None: # now it's definitely not OK + raise Exception("\033[31mFailed to process a frame with ControlNet enabled!\033[0m") + + p.close() + + return processed + +def process_img2img_with_controlnet(p, args, anim_args, loop_args, controlnet_args, root, frame_idx = 0): + p.control_net_enabled = False # we don't want to cause concurrence + controlnet_frame_path = os.path.join(args.outdir, 'controlnet_inputframes', f"{frame_idx:05}.jpg") + controlnet_mask_frame_path = os.path.join(args.outdir, 'controlnet_maskframes', f"{frame_idx:05}.jpg") + + print(f'Reading ControlNet base frame {frame_idx} at {controlnet_frame_path}') + print(f'Reading ControlNet mask frame {frame_idx} at {controlnet_mask_frame_path}') + + cn_mask_np = None + cn_image_np = None + + if not os.path.exists(controlnet_frame_path) and not os.path.exists(controlnet_mask_frame_path): + print(f'\033[33mNeither the base nor the masking frames for ControlNet were found. Using the regular pipeline\033[0m') + return process_images(p) + + if os.path.exists(controlnet_frame_path): + cn_image_np = np.array(Image.open(controlnet_frame_path).convert("RGB")).astype('uint8') + + if os.path.exists(controlnet_mask_frame_path): + cn_mask_np = np.array(Image.open(controlnet_mask_frame_path).convert("RGB")).astype('uint8') + + cn_args = { + "enabled": True, + "module": controlnet_args.controlnet_module, + "model": controlnet_args.controlnet_model, + "weight": controlnet_args.controlnet_weight, + "input_image": {'image': cn_image_np, 'mask': cn_mask_np}, + "scribble_mode": controlnet_args.controlnet_scribble_mode, + "resize_mode": controlnet_args.controlnet_resize_mode, + "rgbbgr_mode": controlnet_args.controlnet_rgbbgr_mode, + "lowvram": controlnet_args.controlnet_lowvram, + "processor_res": controlnet_args.controlnet_processor_res, + "threshold_a": controlnet_args.controlnet_threshold_a, + "threshold_b": controlnet_args.controlnet_threshold_b, + "guidance_strength": controlnet_args.controlnet_guidance_strength, + } + + from .deforum_controlnet_hardcode import process + p.script_args = ( + cn_args["enabled"], + cn_args["module"], + cn_args["model"], + cn_args["weight"], + cn_args["input_image"], + cn_args["scribble_mode"], + cn_args["resize_mode"], + cn_args["rgbbgr_mode"], + cn_args["lowvram"], + cn_args["processor_res"], + cn_args["threshold_a"], + cn_args["threshold_b"], + cn_args["guidance_strength"], + ) + + table = Table(title="ControlNet params",padding=0, box=box.ROUNDED) + + field_names = [] + field_names += ["module", "model", "weight", "guidance", "scribble", "resize", "rgb->bgr", "proc res", "thr a", "thr b"] + for field_name in field_names: + table.add_column(field_name, justify="center") + + rows = [] + rows += [cn_args["module"], cn_args["model"], cn_args["weight"], cn_args["guidance_strength"], cn_args["scribble_mode"], cn_args["resize_mode"], cn_args["rgbbgr_mode"], cn_args["processor_res"], cn_args["threshold_a"], cn_args["threshold_b"]] + rows = [str(x) for x in rows] + + table.add_row(*rows) + + console.print(table) + + processed = process(p, *(p.script_args)) + + if processed is None: # the script just swaps the pipeline, so failing 
is OK for the first time + processed = process_images(p) + + if processed is None: # now it's definitely not OK + raise Exception("\033[31mFailed to process a frame with ControlNet enabled!\033[0m") + + p.close() + + return processed + +import pathlib +from .video_audio_utilities import vid2frames + +def unpack_controlnet_vids(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root): + if controlnet_args.controlnet_input_video_chosen_file is not None and len(controlnet_args.controlnet_input_video_chosen_file.name) > 0: + print(f'Unpacking ControlNet base video') + # create a folder for the video input frames to live in + mask_in_frame_path = os.path.join(args.outdir, 'controlnet_inputframes') + os.makedirs(mask_in_frame_path, exist_ok=True) + + # save the video frames from mask video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...") + vid2frames(video_path=controlnet_args.controlnet_input_video_chosen_file.name, video_in_frame_path=mask_in_frame_path, n=anim_args.extract_nth_frame, overwrite=anim_args.overwrite_extracted_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame, numeric_files_output=True) + + print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}") + print(f'ControlNet base video unpacked!') + + if controlnet_args.controlnet_input_video_mask_chosen_file is not None and len(controlnet_args.controlnet_input_video_mask_chosen_file.name) > 0: + print(f'Unpacking ControlNet video mask') + # create a folder for the video input frames to live in + mask_in_frame_path = os.path.join(args.outdir, 'controlnet_maskframes') + os.makedirs(mask_in_frame_path, exist_ok=True) + + # save the video frames from mask video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...") + vid2frames(video_path=controlnet_args.controlnet_input_video_mask_chosen_file.name, video_in_frame_path=mask_in_frame_path, n=anim_args.extract_nth_frame, overwrite=anim_args.overwrite_extracted_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame, numeric_files_output=True) + + print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}") + print(f'ControlNet video mask unpacked!') + +def hide_ui_by_cn_status(choice): + return gr.update(visible=True) if choice else gr.update(visible=False) + +def build_sliders(cn_model): + if cn_model == "canny": + return [ + gr.update(label="Annotator resolution", value=512, minimum=64, maximum=2048, step=1, interactive=True), + gr.update(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1, interactive=True), + gr.update(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1, interactive=True), + gr.update(visible=True) + ] + elif cn_model == "mlsd": #Hough + return [ + gr.update(label="Hough Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True), + gr.update(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01, interactive=True), + gr.update(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01, interactive=True), + gr.update(visible=True) + ] + elif cn_model in ["hed", "fake_scribble"]: + return [ + gr.update(label="HED Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True), + 
gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=True) + ] + elif cn_model in ["openpose", "openpose_hand", "segmentation"]: + return [ + gr.update(label="Annotator Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True), + gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=True) + ] + elif cn_model == "depth": + return [ + gr.update(label="Midas Resolution", minimum=64, maximum=2048, value=384, step=1, interactive=True), + gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=True) + ] + elif cn_model == "depth_leres": + return [ + gr.update(label="LeReS Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True), + gr.update(label="Remove Near %", value=0, minimum=0, maximum=100, step=0.1, interactive=True), + gr.update(label="Remove Background %", value=0, minimum=0, maximum=100, step=0.1, interactive=True), + gr.update(visible=True) + ] + elif cn_model == "normal_map": + return [ + gr.update(label="Normal Resolution", minimum=64, maximum=2048, value=512, step=1, interactive=True), + gr.update(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01, interactive=True), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=True) + ] + elif cn_model == "none": + return [ + gr.update(label="Normal Resolution", value=64, minimum=64, maximum=2048, interactive=False), + gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=False) + ] + else: + return [ + gr.update(label="Annotator resolution", value=512, minimum=64, maximum=2048, step=1, interactive=True), + gr.update(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False), + gr.update(visible=True) + ] + + # def svgPreprocess(inputs): + # if (inputs): + # if (inputs['image'].startswith("data:image/svg+xml;base64,") and svgsupport): + # svg_data = base64.b64decode(inputs['image'].replace('data:image/svg+xml;base64,','')) + # drawing = svg2rlg(io.BytesIO(svg_data)) + # png_data = renderPM.drawToString(drawing, fmt='PNG') + # encoded_string = base64.b64encode(png_data) + # base64_str = str(encoded_string, "utf-8") + # base64_str = "data:image/png;base64,"+ base64_str + # inputs['image'] = base64_str + # return input_image.orgpreprocess(inputs) + # return None \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/deforum_controlnet_hardcode.py b/extensions/deforum/scripts/deforum_helpers/deforum_controlnet_hardcode.py new file mode 100644 index 0000000000000000000000000000000000000000..1446f77634a54c294eb1327786ae33c1ee7b4dcd --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/deforum_controlnet_hardcode.py @@ -0,0 +1,193 @@ +# TODO HACK FIXME HARDCODE — as using the scripts doesn't seem to work for some reason +deforum_latest_network = None +deforum_latest_params = (None, 'placeholder to trigger the model loading') 
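+# these module-level globals cache the last hooked ControlNet network, its (module, model) params and its input image between calls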
+deforum_input_image = None +from scripts.processor import unload_hed, unload_mlsd, unload_midas, unload_leres, unload_pidinet, unload_openpose, unload_uniformer, HWC3 +import modules.shared as shared +import modules.devices as devices +import modules.processing as processing +from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img +import numpy as np +from scripts.controlnet import update_cn_models, cn_models, cn_models_names +import os +import modules.scripts as scrpts +import torch +from scripts.cldm import PlugableControlModel +from scripts.adapter import PlugableAdapter +from scripts.utils import load_state_dict +from torchvision.transforms import Resize, InterpolationMode, CenterCrop, Compose +from einops import rearrange +cn_models_dir = os.path.join(scrpts.basedir(), "models") +default_conf_adapter = os.path.join(cn_models_dir, "sketch_adapter_v14.yaml") +default_conf = os.path.join(cn_models_dir, "cldm_v15.yaml") +unloadable = { + "hed": unload_hed, + "fake_scribble": unload_hed, + "mlsd": unload_mlsd, + "depth": unload_midas, + "depth_leres": unload_leres, + "normal_map": unload_midas, + "pidinet": unload_pidinet, + "openpose": unload_openpose, + "openpose_hand": unload_openpose, + "segmentation": unload_uniformer, +} +deforum_latest_model_hash = "" + +def restore_networks(unet): + global deforum_latest_network + global deforum_latest_params + if deforum_latest_network is not None: + print("restoring last networks") + deforum_input_image = None + deforum_latest_network.restore(unet) + deforum_latest_network = None + + last_module = deforum_latest_params[0] + if last_module is not None: + unloadable.get(last_module, lambda:None)() + +def process(p, *args): + + global deforum_latest_network + global deforum_latest_params + global deforum_input_image + global deforum_latest_model_hash + + unet = p.sd_model.model.diffusion_model + + enabled, module, model, weight, image, scribble_mode, \ + resize_mode, rgbbgr_mode, lowvram, pres, pthr_a, pthr_b, guidance_strength = args + + if not enabled: + restore_networks(unet) + return + + models_changed = deforum_latest_params[1] != model \ + or deforum_latest_model_hash != p.sd_model.sd_model_hash or deforum_latest_network == None \ + or (deforum_latest_network is not None and deforum_latest_network.lowvram != lowvram) + + deforum_latest_params = (module, model) + deforum_latest_model_hash = p.sd_model.sd_model_hash + if models_changed: + restore_networks(unet) + model_path = cn_models.get(model, None) + + if model_path is None: + raise RuntimeError(f"model not found: {model}") + + # trim '"' at start/end + if model_path.startswith("\"") and model_path.endswith("\""): + model_path = model_path[1:-1] + + if not os.path.exists(model_path): + raise ValueError(f"file not found: {model_path}") + + print(f"Loading preprocessor: {module}, model: {model}") + state_dict = load_state_dict(model_path) + network_module = PlugableControlModel + network_config = shared.opts.data.get("control_net_model_config", default_conf) + if any([k.startswith("body.") for k, v in state_dict.items()]): + # adapter model + network_module = PlugableAdapter + network_config = shared.opts.data.get("control_net_model_adapter_config", default_conf_adapter) + + network = network_module( + state_dict=state_dict, + config_path=network_config, + weight=weight, + lowvram=lowvram, + base_model=unet, + ) + network.to(p.sd_model.device, dtype=p.sd_model.dtype) + network.hook(unet, p.sd_model) + + print(f"ControlNet model {model} loaded.") + 
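# keep a reference to the hooked network so later frames can reuse it and restore_networks() can detach it from the UNet +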
deforum_latest_network = network + + if image is not None: + deforum_input_image = HWC3(image['image']) + if 'mask' in image and image['mask'] is not None and not ((image['mask'][:, :, 0]==0).all() or (image['mask'][:, :, 0]==255).all()): + print("using mask as input") + deforum_input_image = HWC3(image['mask'][:, :, 0]) + scribble_mode = True + else: + # use img2img init_image as default + deforum_input_image = getattr(p, "init_images", [None])[0] + if deforum_input_image is None: + raise ValueError('controlnet is enabled but no input image is given') + deforum_input_image = HWC3(np.asarray(deforum_input_image)) + + if scribble_mode: + detected_map = np.zeros_like(deforum_input_image, dtype=np.uint8) + detected_map[np.min(deforum_input_image, axis=2) < 127] = 255 + deforum_input_image = detected_map + + from scripts.processor import canny, midas, midas_normal, leres, hed, mlsd, openpose, pidinet, simple_scribble, fake_scribble, uniformer + + preprocessor = { + "none": lambda x, *args, **kwargs: x, + "canny": canny, + "depth": midas, + "depth_leres": leres, + "hed": hed, + "mlsd": mlsd, + "normal_map": midas_normal, + "openpose": openpose, + # "openpose_hand": openpose_hand, + "pidinet": pidinet, + "scribble": simple_scribble, + "fake_scribble": fake_scribble, + "segmentation": uniformer, + } + + preprocessor = preprocessor[deforum_latest_params[0]] + h, w, bsz = p.height, p.width, p.batch_size + if pres > 64: + detected_map = preprocessor(deforum_input_image, res=pres, thr_a=pthr_a, thr_b=pthr_b) + else: + detected_map = preprocessor(deforum_input_image) + detected_map = HWC3(detected_map) + + if module == "normal_map" or rgbbgr_mode: + control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(devices.get_device_for("controlnet")) / 255.0 + else: + control = torch.from_numpy(detected_map.copy()).float().to(devices.get_device_for("controlnet")) / 255.0 + + control = rearrange(control, 'h w c -> c h w') + detected_map = rearrange(torch.from_numpy(detected_map), 'h w c -> c h w') + if resize_mode == "Scale to Fit (Inner Fit)": + transform = Compose([ + Resize(h if h < w else w, interpolation=InterpolationMode.BICUBIC), + CenterCrop(size=(h, w)) + ]) + control = transform(control) + detected_map = transform(detected_map) + else: + control = Resize((h,w), interpolation=InterpolationMode.BICUBIC)(control) + detected_map = Resize((h,w), interpolation=InterpolationMode.BICUBIC)(detected_map) + + # for log use + detected_map = rearrange(detected_map, 'c h w -> h w c').numpy().astype(np.uint8) + + # control = torch.stack([control for _ in range(bsz)], dim=0) + deforum_latest_network.notify(control, weight, guidance_strength) + + if shared.opts.data.get("control_net_skip_img2img_processing") and hasattr(p, "init_images"): + swap_img2img_pipeline(p) + +def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img): + p.__class__ = processing.StableDiffusionProcessingTxt2Img + dummy = processing.StableDiffusionProcessingTxt2Img() + for k,v in dummy.__dict__.items(): + if hasattr(p, k): + continue + setattr(p, k, v) + diff --git a/extensions/deforum/scripts/deforum_helpers/deprecation_utils.py b/extensions/deforum/scripts/deforum_helpers/deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9485b1b39629ce1c0c1c584e1294e64e300c06db --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/deprecation_utils.py @@ -0,0 +1,20 @@ +# This file is used to map deprecated setting names in a dictionary +# and print a message containing the old and the 
new names +# if the latter is removed completely, put a warning + +# as of 2023-02-05 +# "histogram_matching" -> None + +deprecation_map = { + "histogram_matching": None, + "flip_2d_perspective": "enable_perspective_flip" +} + +def handle_deprecated_settings(settings_json): + for old_name, new_name in deprecation_map.items(): + if old_name in settings_json: + if new_name is None: + print(f"WARNING: Setting '{old_name}' has been removed. It will be discarded and the default value used instead!") + else: + print(f"WARNING: Setting '{old_name}' has been renamed to '{new_name}'. The saved settings file will reflect the change") + settings_json[new_name] = settings_json.pop(old_name) diff --git a/extensions/deforum/scripts/deforum_helpers/depth.py b/extensions/deforum/scripts/deforum_helpers/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..61a50459a4a3ed046ed1c4cdcbd914437026fc0d --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/depth.py @@ -0,0 +1,166 @@ +import math, os, subprocess +import cv2 +import hashlib +import numpy as np +import torch +import gc +import torchvision.transforms as T +from einops import rearrange, repeat +from PIL import Image +from infer import InferenceHelper +from midas.dpt_depth import DPTDepthModel +from midas.transforms import Resize, NormalizeImage, PrepareForNet +import torchvision.transforms.functional as TF +from .general_utils import checksum + +class DepthModel(): + def __init__(self, device): + self.adabins_helper = None + self.depth_min = 1000 + self.depth_max = -1000 + self.device = device + self.midas_model = None + self.midas_transform = None + + def load_adabins(self, models_path): + if not os.path.exists(os.path.join(models_path,'AdaBins_nyu.pt')): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(r"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt", models_path) + if checksum(os.path.join(models_path,'AdaBins_nyu.pt')) != "643db9785c663aca72f66739427642726b03acc6c4c1d3755a4587aa2239962746410d63722d87b49fc73581dbc98ed8e3f7e996ff7b9c0d56d0fbc98e23e41a": + raise Exception(r"Error while downloading AdaBins_nyu.pt. Please download from here: https://drive.google.com/file/d/1lvyZZbC9NLcS8a__YPcUP7rDiIpbRpoF and place in: " + models_path) + self.adabins_helper = InferenceHelper(models_path=models_path, dataset='nyu', device=self.device) + + def load_midas(self, models_path, half_precision=True): + if not os.path.exists(os.path.join(models_path, 'dpt_large-midas-2f21e586.pt')): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(r"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", models_path) + if checksum(os.path.join(models_path,'dpt_large-midas-2f21e586.pt')) != "fcc4829e65d00eeed0a38e9001770676535d2e95c8a16965223aba094936e1316d569563552a852d471f310f83f597e8a238987a26a950d667815e08adaebc06": + raise Exception(r"Error while downloading dpt_large-midas-2f21e586.pt. 
Please download from here: https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt and place in: " + models_path) + + self.midas_model = DPTDepthModel( + path=f"{models_path}/dpt_large-midas-2f21e586.pt", + backbone="vitl16_384", + non_negative=True, + ) + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + self.midas_transform = T.Compose([ + Resize( + 384, 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet() + ]) + + self.midas_model.eval() + if self.device == torch.device("cuda"): + self.midas_model = self.midas_model.to(memory_format=torch.channels_last) + if half_precision: + self.midas_model = self.midas_model.half() + self.midas_model.to(self.device) + + def predict(self, prev_img_cv2, anim_args, half_precision) -> torch.Tensor: + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + # predict depth with AdaBins + use_adabins = anim_args.midas_weight < 1.0 and self.adabins_helper is not None + if use_adabins: + MAX_ADABINS_AREA = 500000 + MIN_ADABINS_AREA = 448*448 + + # resize image if too large or too small + img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR)) + image_pil_area = w*h + resized = True + if image_pil_area > MAX_ADABINS_AREA: + scale = math.sqrt(MAX_ADABINS_AREA) / math.sqrt(image_pil_area) + depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.LANCZOS) # LANCZOS is good for downsampling + print(f" resized to {depth_input.width}x{depth_input.height}") + elif image_pil_area < MIN_ADABINS_AREA: + scale = math.sqrt(MIN_ADABINS_AREA) / math.sqrt(image_pil_area) + depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.BICUBIC) + print(f" resized to {depth_input.width}x{depth_input.height}") + else: + depth_input = img_pil + resized = False + + # predict depth and resize back to original dimensions + try: + with torch.no_grad(): + _, adabins_depth = self.adabins_helper.predict_pil(depth_input) + if resized: + adabins_depth = TF.resize( + torch.from_numpy(adabins_depth), + torch.Size([h, w]), + interpolation=TF.InterpolationMode.BICUBIC + ) + adabins_depth = adabins_depth.cpu().numpy() + adabins_depth = adabins_depth.squeeze() + except: + print(f" exception encountered, falling back to pure MiDaS") + use_adabins = False + torch.cuda.empty_cache() + + if self.midas_model is not None: + # convert image from 0->255 uint8 to 0->1 float for feeding to MiDaS + img_midas = prev_img_cv2.astype(np.float32) / 255.0 + img_midas_input = self.midas_transform({"image": img_midas})["image"] + + # MiDaS depth estimation implementation + sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0) + if self.device == torch.device("cuda"): + sample = sample.to(memory_format=torch.channels_last) + if half_precision: + sample = sample.half() + with torch.no_grad(): + midas_depth = self.midas_model.forward(sample) + midas_depth = torch.nn.functional.interpolate( + midas_depth.unsqueeze(1), + size=img_midas.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + midas_depth = midas_depth.cpu().numpy() + torch.cuda.empty_cache() + + # MiDaS makes the near values greater, and the far values lesser. Let's reverse that and try to align with AdaBins a bit better. 
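+ # (the 50.0 offset and 19.0 divisor look like empirical rescaling constants: subtracting from 50 flips near/far, dividing by 19 squeezes the range closer to the AdaBins output)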
+ midas_depth = np.subtract(50.0, midas_depth) + midas_depth = midas_depth / 19.0 + + # blend between MiDaS and AdaBins predictions + if use_adabins: + depth_map = midas_depth*anim_args.midas_weight + adabins_depth*(1.0-anim_args.midas_weight) + else: + depth_map = midas_depth + + depth_map = np.expand_dims(depth_map, axis=0) + depth_tensor = torch.from_numpy(depth_map).squeeze().to(self.device) + else: + depth_tensor = torch.ones((h, w), device=self.device) + + return depth_tensor + + def save(self, filename: str, depth: torch.Tensor): + depth = depth.cpu().numpy() + if len(depth.shape) == 2: + depth = np.expand_dims(depth, axis=0) + self.depth_min = min(self.depth_min, depth.min()) + self.depth_max = max(self.depth_max, depth.max()) + print(f" depth min:{depth.min()} max:{depth.max()}") + denom = max(1e-8, self.depth_max - self.depth_min) + temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c') + temp = repeat(temp, 'h w 1 -> h w c', c=3) + Image.fromarray(temp.astype(np.uint8)).save(filename) + + def to(self, device): + self.device = device + self.midas_model.to(device) + if self.adabins_helper is not None: + self.adabins_helper.to(device) + gc.collect() + torch.cuda.empty_cache() diff --git a/extensions/deforum/scripts/deforum_helpers/frame_interpolation.py b/extensions/deforum/scripts/deforum_helpers/frame_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..59d205fdc3294a39b46a9347aa174d2503d37e00 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/frame_interpolation.py @@ -0,0 +1,192 @@ +import os +from pathlib import Path +from rife.inference_video import run_rife_new_video_infer +from .video_audio_utilities import get_quick_vid_info, vid2frames, media_file_has_audio, extract_number, ffmpeg_stitch_video +from film_interpolation.film_inference import run_film_interp_infer +from .general_utils import duplicate_pngs_from_folder, checksum + +# gets 'RIFE v4.3', returns: 'RIFE43' +def extract_rife_name(string): + parts = string.split() + if len(parts) != 2 or parts[0] != "RIFE" or (parts[1][0] != "v" or not parts[1][1:].replace('.','').isdigit()): + raise ValueError("Input string should contain exactly 2 words, first word should be 'RIFE' and second word should start with 'v' followed by 2 numbers") + return "RIFE"+parts[1][1:].replace('.','') + +# This function usually gets a filename, and converts it to a legal linux/windows *folder* name +def clean_folder_name(string): + illegal_chars = ["/", "\\", "<", ">", ":", "\"", "|", "?", "*", "."] + for char in illegal_chars: + string = string.replace(char, "_") + return string + +def set_interp_out_fps(interp_x, slow_x_enabled, slom_x, in_vid_fps): + if interp_x == 'Disabled' or in_vid_fps in ('---', None, '', 'None'): + return '---' + + # clean_interp_x = extract_number(interp_x) + # clean_slom_x = extract_number(slom_x) + fps = float(in_vid_fps) * int(interp_x) + # if slom_x != -1: + if slow_x_enabled: + fps /= int(slom_x) + return int(fps) if fps.is_integer() else fps + +# get uploaded video frame count, fps, and return 3 valuees for the gradio UI: in fcount, in fps, out fps (using the set_interp_out_fps function above) +def gradio_f_interp_get_fps_and_fcount(vid_path, interp_x, slow_x_enabled, slom_x): + if vid_path is None: + return '---', '---', '---' + fps, fcount, resolution = get_quick_vid_info(vid_path.name) + expected_out_fps = set_interp_out_fps(interp_x, slow_x_enabled, slom_x, fps) + return (str(round(fps,2)) if fps is not None else '---', (round(fcount,2)) if 
fcount is not None else '---', round(expected_out_fps,2)) + +# handle call to interpolate an uploaded video from gradio button in args.py (the function that calls this func is named 'upload_vid_to_rife') +def process_interp_vid_upload_logic(file, engine, x_am, sl_enabled, sl_am, keep_imgs, f_location, f_crf, f_preset, in_vid_fps, f_models_path, vid_file_name): + + print("got a request to *frame interpolate* an existing video.") + + _, _, resolution = get_quick_vid_info(file.name) + folder_name = clean_folder_name(Path(vid_file_name).stem) + outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-interpolation', folder_name) + i = 1 + while os.path.exists(outdir_no_tmp): + outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-interpolation', folder_name + '_' + str(i)) + i += 1 + + outdir = os.path.join(outdir_no_tmp, 'tmp_input_frames') + os.makedirs(outdir, exist_ok=True) + + vid2frames(video_path=file.name, video_in_frame_path=outdir, overwrite=True, extract_from_frame=0, extract_to_frame=-1, numeric_files_output=True, out_img_format='png') + + # check if the uploaded vid has an audio stream. If it doesn't, set audio param to None so that ffmpeg won't try to add non-existing audio to final video. + audio_file_to_pass = None + if media_file_has_audio(file.name, f_location): + audio_file_to_pass = file.name + + process_video_interpolation(frame_interpolation_engine=engine, frame_interpolation_x_amount=x_am, frame_interpolation_slow_mo_enabled = sl_enabled,frame_interpolation_slow_mo_amount=sl_am, orig_vid_fps=in_vid_fps, deforum_models_path=f_models_path, real_audio_track=audio_file_to_pass, raw_output_imgs_path=outdir, img_batch_id=None, ffmpeg_location=f_location, ffmpeg_crf=f_crf, ffmpeg_preset=f_preset, keep_interp_imgs=keep_imgs, orig_vid_name=folder_name, resolution=resolution) + +# handle params before talking with the actual interpolation module (rifee/film, more to be added) +def process_video_interpolation(frame_interpolation_engine, frame_interpolation_x_amount, frame_interpolation_slow_mo_enabled, frame_interpolation_slow_mo_amount, orig_vid_fps, deforum_models_path, real_audio_track, raw_output_imgs_path, img_batch_id, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, keep_interp_imgs, orig_vid_name, resolution): + + # set initial output vid fps + fps = float(orig_vid_fps) * frame_interpolation_x_amount + + # re-calculate fps param to pass if slow_mo mode is enabled + if frame_interpolation_slow_mo_enabled: + fps = float(orig_vid_fps) * frame_interpolation_x_amount / int(frame_interpolation_slow_mo_amount) + # disable audio-adding by setting real_audio_track to None if slow-mo is enabled + if real_audio_track is not None and frame_interpolation_slow_mo_enabled: + real_audio_track = None + + if frame_interpolation_engine == 'None': + return + elif frame_interpolation_engine.startswith("RIFE"): + # make sure interp_x is valid and in range + if frame_interpolation_x_amount not in range(2, 11): + raise Error("frame_interpolation_x_amount must be between 2x and 10x") + + # set UHD to True if res' is 2K or higher + if resolution: + UHD = resolution[0] >= 2048 and resolution[1] >= 2048 + else: + UHD = False + # e.g from "RIFE v2.3 to RIFE23" + actual_model_folder_name = extract_rife_name(frame_interpolation_engine) + + # run actual rife interpolation and video stitching etc - the whole suite + run_rife_new_video_infer(interp_x_amount=frame_interpolation_x_amount, slow_mo_enabled = frame_interpolation_slow_mo_enabled, slow_mo_x_amount=frame_interpolation_slow_mo_amount, 
model=actual_model_folder_name, fps=fps, deforum_models_path=deforum_models_path, audio_track=real_audio_track, raw_output_imgs_path=raw_output_imgs_path, img_batch_id=img_batch_id, ffmpeg_location=ffmpeg_location, ffmpeg_crf=ffmpeg_crf, ffmpeg_preset=ffmpeg_preset, keep_imgs=keep_interp_imgs, orig_vid_name=orig_vid_name, UHD=UHD) + elif frame_interpolation_engine == 'FILM': + prepare_film_inference(deforum_models_path=deforum_models_path, x_am=frame_interpolation_x_amount, sl_enabled=frame_interpolation_slow_mo_enabled, sl_am=frame_interpolation_slow_mo_amount, keep_imgs=keep_interp_imgs, raw_output_imgs_path=raw_output_imgs_path, img_batch_id=img_batch_id, f_location=ffmpeg_location, f_crf=ffmpeg_crf, f_preset=ffmpeg_preset, fps=fps, audio_track=real_audio_track, orig_vid_name=orig_vid_name) + else: + print("Unknown Frame Interpolation engine chosen. Doing nothing.") + return + +def prepare_film_inference(deforum_models_path, x_am, sl_enabled, sl_am, keep_imgs, raw_output_imgs_path, img_batch_id, f_location, f_crf, f_preset, fps, audio_track, orig_vid_name): + import shutil + + parent_folder = os.path.dirname(raw_output_imgs_path) + grandparent_folder = os.path.dirname(parent_folder) + if orig_vid_name is not None: + interp_vid_path = os.path.join(parent_folder, str(orig_vid_name) +'_FILM_x' + str(x_am)) + else: + interp_vid_path = os.path.join(raw_output_imgs_path, str(img_batch_id) +'_FILM_x' + str(x_am)) + + film_model_name = 'film_net_fp16.pt' + film_model_folder = os.path.join(deforum_models_path,'film_interpolation') + film_model_path = os.path.join(film_model_folder, film_model_name) # actual full path to the film .pt model file + output_interp_imgs_folder = os.path.join(raw_output_imgs_path, 'interpolated_frames_film') + # set custom name depending on if we interpolate after a run, or interpolate a video (related/unrelated to deforum, we don't know) directly from within the interpolation tab + # interpolated_path = os.path.join(args.raw_output_imgs_path, 'interpolated_frames_rife') + if orig_vid_name is not None: # interpolating a video (deforum or unrelated) + custom_interp_path = "{}_{}".format(output_interp_imgs_folder, orig_vid_name) + else: # interpolating after a deforum run: + custom_interp_path = "{}_{}".format(output_interp_imgs_folder, img_batch_id) + + # interp_vid_path = os.path.join(raw_output_imgs_path, str(img_batch_id) + '_FILM_x' + str(x_am)) + img_path_for_ffmpeg = os.path.join(custom_interp_path, "frame_%05d.png") + + if sl_enabled: + interp_vid_path = interp_vid_path + '_slomo_x' + str(sl_am) + interp_vid_path = interp_vid_path + '.mp4' + + # In this folder we temporarily keep the original frames (converted/ copy-pasted and img format depends on scenario) + # the convertion case is done to avert a problem with 24 and 32 mixed outputs from the same animation run + temp_convert_raw_png_path = os.path.join(raw_output_imgs_path, "tmp_film_folder") + total_frames = duplicate_pngs_from_folder(raw_output_imgs_path, temp_convert_raw_png_path, img_batch_id, None) + check_and_download_film_model('film_net_fp16.pt', film_model_folder) # TODO: split this part + + # get number of in-between-frames to provide to FILM - mimics how RIFE works, we should get the same amount of total frames in the end + film_in_between_frames_count = calculate_frames_to_add(total_frames, x_am) + # Run actual FILM inference + run_film_interp_infer( + model_path = film_model_path, + input_folder = temp_convert_raw_png_path, + save_folder = custom_interp_path, # output folder is created in the 
infer part + inter_frames = film_in_between_frames_count) + + add_soundtrack = 'None' + if not audio_track is None: + add_soundtrack = 'File' + + print (f"*Passing interpolated frames to ffmpeg...*") + exception_raised = False + try: + ffmpeg_stitch_video(ffmpeg_location=f_location, fps=fps, outmp4_path=interp_vid_path, stitch_from_frame=0, stitch_to_frame=999999, imgs_path=img_path_for_ffmpeg, add_soundtrack=add_soundtrack, audio_path=audio_track, crf=f_crf, preset=f_preset) + except Exception as e: + exception_raised = True + print(f"An error occurred while stitching the video: {e}") + + if orig_vid_name and (keep_imgs or exception_raised): + shutil.move(custom_interp_path, parent_folder) + if not keep_imgs and not exception_raised: + if fps <= 450: # keep interp frames automatically if out_vid fps is above 450 + shutil.rmtree(custom_interp_path, ignore_errors=True) + # delete duplicated raw non-interpolated frames + shutil.rmtree(temp_convert_raw_png_path, ignore_errors=True) + # remove folder with raw (non-interpolated) vid input frames in case of input VID and not PNGs + if orig_vid_name: + shutil.rmtree(raw_output_imgs_path, ignore_errors=True) + +def check_and_download_film_model(model_name, model_dest_folder): + from basicsr.utils.download_util import load_file_from_url + if model_name == 'film_net_fp16.pt': + model_dest_path = os.path.join(model_dest_folder, model_name) + download_url = 'https://github.com/hithereai/frame-interpolation-pytorch/releases/download/film_net_fp16.pt/film_net_fp16.pt' + film_model_hash = '0a823815b111488ac2b7dd7fe6acdd25d35a22b703e8253587764cf1ee3f8f93676d24154d9536d2ce5bc3b2f102fb36dfe0ca230dfbe289d5cd7bde5a34ec12' + else: # Unknown FILM model + raise Exception("Got a request to download an unknown FILM model. Can't proceed.") + if os.path.exists(model_dest_path): + return + try: + os.makedirs(model_dest_folder, exist_ok=True) + # download film model from url + load_file_from_url(download_url, model_dest_folder) + # verify checksum + if checksum(model_dest_path) != film_model_hash: + raise Exception(f"Error while downloading {model_name}. Please download from: {download_url}, and put in: {model_dest_folder}") + except Exception as e: + raise Exception(f"Error while downloading {model_name}. Please download from: {download_url}, and put in: {model_dest_folder}") + +# get film no. 
of frames to add after each pic from tot frames in interp_x values +def calculate_frames_to_add(total_frames, interp_x): + frames_to_add = (total_frames * interp_x - total_frames) / (total_frames - 1) + return int(round(frames_to_add)) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/general_utils.py b/extensions/deforum/scripts/deforum_helpers/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..893a7ff2870a93c64822c3e1a9cbe8c51faca811 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/general_utils.py @@ -0,0 +1,32 @@ +import hashlib +def checksum(filename, hash_factory=hashlib.blake2b, chunk_num_blocks=128): + h = hash_factory() + with open(filename,'rb') as f: + while chunk := f.read(chunk_num_blocks*h.block_size): + h.update(chunk) + return h.hexdigest() + +def get_os(): + import platform + return {"Windows": "Windows", "Linux": "Linux", "Darwin": "Mac"}.get(platform.system(), "Unknown") + +# used in src/rife/inference_video.py and more, soon +def duplicate_pngs_from_folder(from_folder, to_folder, img_batch_id, orig_vid_name): + import os, cv2, shutil #, subprocess + #TODO: don't copy-paste at all if the input is a video (now it copy-pastes, and if input is deforum run is also converts to make sure no errors rise cuz of 24-32 bit depth differences) + temp_convert_raw_png_path = os.path.join(from_folder, to_folder) + if not os.path.exists(temp_convert_raw_png_path): + os.makedirs(temp_convert_raw_png_path) + + frames_handled = 0 + for f in os.listdir(from_folder): + if ('png' in f or 'jpg' in f) and '-' not in f and '_depth_' not in f and ((img_batch_id is not None and f.startswith(img_batch_id) or img_batch_id is None)): + frames_handled +=1 + original_img_path = os.path.join(from_folder, f) + if orig_vid_name is not None: + shutil.copy(original_img_path, temp_convert_raw_png_path) + else: + image = cv2.imread(original_img_path) + new_path = os.path.join(temp_convert_raw_png_path, f) + cv2.imwrite(new_path, image, [cv2.IMWRITE_PNG_COMPRESSION, 0]) + return frames_handled \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/generate.py b/extensions/deforum/scripts/deforum_helpers/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..4203405bdf3a06330a655ebc6b58c5bd9dcccca6 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/generate.py @@ -0,0 +1,244 @@ +import numpy as np +import cv2 +from PIL import Image +from .prompt import split_weighted_subprompts +from .load_images import load_img, prepare_mask, check_mask_for_errors +from .webui_sd_pipeline import get_webui_sd_pipeline +from .animation import sample_from_cv2, sample_to_cv2 +from .rich import console +#Webui +import cv2 +from .animation import sample_from_cv2, sample_to_cv2 +from modules import processing, sd_models +from modules.shared import opts, sd_model +from modules.processing import process_images, StableDiffusionProcessingTxt2Img +from .deforum_controlnet import is_controlnet_enabled, process_txt2img_with_controlnet, process_img2img_with_controlnet + +import math, json, itertools +import requests + +def load_mask_latent(mask_input, shape): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + + if isinstance(mask_input, str): # mask input is probably a file name + if mask_input.startswith('http://') or mask_input.startswith('https://'): + mask_image = 
Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA') + else: + mask_image = Image.open(mask_input).convert('RGBA') + elif isinstance(mask_input, Image.Image): + mask_image = mask_input + else: + raise Exception("mask_input must be a PIL image or a file name") + + mask_w_h = (shape[-1], shape[-2]) + mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS) + mask = mask.convert("L") + return mask + +def isJson(myjson): + try: + json.loads(myjson) + except ValueError as e: + return False + return True + +# Add pairwise implementation here not to upgrade +# the whole python to 3.10 just for one function +def pairwise_repl(iterable): + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + +def generate(args, anim_args, loop_args, controlnet_args, root, frame = 0, return_sample=False, sampler_name=None): + assert args.prompt is not None + + # Setup the pipeline + p = get_webui_sd_pipeline(args, root, frame) + p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame) + + if not args.use_init and args.strength > 0 and args.strength_0_no_init: + print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.") + print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n") + args.strength = 0 + processed = None + mask_image = None + init_image = None + image_init0 = None + + if loop_args.use_looper: + # TODO find out why we need to set this in the init tab + if args.strength == 0: + raise RuntimeError("Strength needs to be greater than 0 in Init tab and strength_0_no_init should *not* be checked") + if args.seed_behavior != "schedule": + raise RuntimeError("seed_behavior needs to be set to schedule in under 'Keyframes' tab --> 'Seed scheduling'") + if not isJson(loop_args.imagesToKeyframe): + raise RuntimeError("The images set for use with keyframe-guidance are not in a proper JSON format") + args.strength = loop_args.imageStrength + tweeningFrames = loop_args.tweeningFrameSchedule + blendFactor = .07 + colorCorrectionFactor = loop_args.colorCorrectionFactor + jsonImages = json.loads(loop_args.imagesToKeyframe) + framesToImageSwapOn = list(map(int, list(jsonImages.keys()))) + # find which image to show + frameToChoose = 0 + for swappingFrame in framesToImageSwapOn[1:]: + frameToChoose += (frame >= int(swappingFrame)) + + #find which frame to do our swapping on for tweening + skipFrame = 25 + for fs, fe in pairwise_repl(framesToImageSwapOn): + if fs <= frame <= fe: + skipFrame = fe - fs + + if frame % skipFrame <= tweeningFrames: # number of tweening frames + blendFactor = loop_args.blendFactorMax - loop_args.blendFactorSlope*math.cos((frame % tweeningFrames) / (tweeningFrames / 2)) + init_image2, _ = load_img(list(jsonImages.values())[frameToChoose], + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + image_init0 = list(jsonImages.values())[0] + + else: # they passed in a single init image + image_init0 = args.init_image + + + available_samplers = { + 'euler a':'Euler a', + 'euler':'Euler', + 'lms':'LMS', + 'heun':'Heun', + 'dpm2':'DPM2', + 'dpm2 a':'DPM2 a', + 'dpm++ 2s a':'DPM++ 2S a', + 'dpm++ 2m':'DPM++ 2M', + 'dpm++ sde':'DPM++ SDE', + 'dpm fast':'DPM fast', + 'dpm adaptive':'DPM adaptive', + 'lms karras':'LMS Karras' , + 'dpm2 karras':'DPM2 Karras', + 'dpm2 a karras':'DPM2 a Karras', + 'dpm++ 2s a karras':'DPM++ 2S a Karras', + 'dpm++ 2m karras':'DPM++ 2M Karras', + 'dpm++ sde karras':'DPM++ SDE Karras' + } + if sampler_name is not None: + if 
sampler_name in available_samplers.keys(): + args.sampler = available_samplers[sampler_name] + + if args.checkpoint is not None: + info = sd_models.get_closet_checkpoint_match(args.checkpoint) + if info is None: + raise RuntimeError(f"Unknown checkpoint: {args.checkpoint}") + sd_models.reload_model_weights(info=info) + + if args.init_sample is not None: + # TODO: cleanup init_sample remains later + img = args.init_sample + init_image = img + image_init0 = img + if loop_args.use_looper and isJson(loop_args.imagesToKeyframe): + init_image = Image.blend(init_image, init_image2, blendFactor) + correction_colors = Image.blend(init_image, init_image2, colorCorrectionFactor) + p.color_corrections = [processing.setup_color_correction(correction_colors)] + + # this is the first pass + elif loop_args.use_looper or (args.use_init and ((args.init_image != None and args.init_image != ''))): + init_image, mask_image = load_img(image_init0, # initial init image + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + + else: + + if anim_args.animation_mode != 'Interpolation': + print(f"Not using an init image (doing pure txt2img)") + p_txt = StableDiffusionProcessingTxt2Img( + sd_model=sd_model, + outpath_samples=root.tmp_deforum_run_duplicated_folder, + outpath_grids=root.tmp_deforum_run_duplicated_folder, + prompt=p.prompt, + styles=p.styles, + negative_prompt=p.negative_prompt, + seed=p.seed, + subseed=p.subseed, + subseed_strength=p.subseed_strength, + seed_resize_from_h=p.seed_resize_from_h, + seed_resize_from_w=p.seed_resize_from_w, + sampler_name=p.sampler_name, + batch_size=p.batch_size, + n_iter=p.n_iter, + steps=p.steps, + cfg_scale=p.cfg_scale, + width=p.width, + height=p.height, + restore_faces=p.restore_faces, + tiling=p.tiling, + enable_hr=None, + denoising_strength=None, + ) + # print dynamic table to cli + print_generate_table(args, anim_args, p_txt) + + if is_controlnet_enabled(controlnet_args): + processed = process_txt2img_with_controlnet(p, args, anim_args, loop_args, controlnet_args, root, frame) + else: + processed = processing.process_images(p_txt) + + if processed is None: + # Mask functions + if args.use_mask: + mask = args.mask_image + #assign masking options to pipeline + if mask is not None: + p.inpainting_mask_invert = args.invert_mask + p.inpainting_fill = args.fill + p.inpaint_full_res= args.full_res_mask + p.inpaint_full_res_padding = args.full_res_mask_padding + else: + mask = None + + assert not ( (mask is not None and args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True" + + p.init_images = [init_image] + p.image_mask = mask + p.image_cfg_scale = args.pix2pix_img_cfg_scale + + # print dynamic table to cli + print_generate_table(args, anim_args, p) + + if is_controlnet_enabled(controlnet_args): + processed = process_img2img_with_controlnet(p, args, anim_args, loop_args, controlnet_args, root, frame) + else: + processed = processing.process_images(p) + + if root.initial_info == None: + root.initial_seed = processed.seed + root.initial_info = processed.info + + if root.first_frame == None: + root.first_frame = processed.images[0] + + results = processed.images[0] + + return results + +def print_generate_table(args, anim_args, p): + from rich.table import Table + from rich import box + table = Table(padding=0, box=box.ROUNDED) + field_names = ["Steps", "CFG"] + if anim_args.animation_mode != 'Interpolation': + field_names.append("Denoise") + field_names 
+= ["Subseed", "Subs. str"] * (anim_args.enable_subseed_scheduling) + field_names += ["Sampler"] * anim_args.enable_sampler_scheduling + field_names += ["Checkpoint"] * anim_args.enable_checkpoint_scheduling + for field_name in field_names: + table.add_column(field_name, justify="center") + rows = [str(p.steps), str(p.cfg_scale)] + if anim_args.animation_mode != 'Interpolation': + rows.append(str(p.denoising_strength)) + rows += [str(p.subseed), str(p.subseed_strength)] * (anim_args.enable_subseed_scheduling) + rows += [p.sampler_name] * anim_args.enable_sampler_scheduling + rows += [str(args.checkpoint)] * anim_args.enable_checkpoint_scheduling + table.add_row(*rows) + + console.print(table) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/gradio_funcs.py b/extensions/deforum/scripts/deforum_helpers/gradio_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..2656647b66a0dca9c71759e98016251f22a02815 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/gradio_funcs.py @@ -0,0 +1,83 @@ +import gradio as gr +from .video_audio_utilities import extract_number, get_quick_vid_info + +def change_visibility_from_skip_video(choice): + return gr.update(visible=False) if choice else gr.update(visible=True) + +def update_r_upscale_factor(choice): + return gr.update(value='x4', choices = ['x4']) if choice != 'realesr-animevideov3' else gr.update(value='x2', choices = ['x2', 'x3', 'x4']) + +def change_perlin_visibility(choice): + return gr.update(visible=choice=="perlin") + +def change_color_coherence_video_every_N_frames_visibility(choice): + return gr.update(visible=choice=="Video Input") + +def change_seed_iter_visibility(choice): + return gr.update(visible=choice=="iter") + +def change_seed_schedule_visibility(choice): + return gr.update(visible=choice=="schedule") + +def disable_pers_flip_accord(choice): + return gr.update(visible=True) if choice in ['2D','3D'] else gr.update(visible=False) + +def change_max_frames_visibility(choice): + return gr.update(visible=choice != "Video Input") + +def change_diffusion_cadence_visibility(choice): + return gr.update(visible=choice not in ['Video Input', 'Interpolation']) + +def disble_3d_related_stuff(choice): + return gr.update(visible=False) if choice != '3D' else gr.update(visible=True) + +def enable_2d_related_stuff(choice): + return gr.update(visible=True) if choice == '2D' else gr.update(visible=False) + +def disable_by_interpolation(choice): + return gr.update(visible=False) if choice in ['Interpolation'] else gr.update(visible=True) + +def disable_by_video_input(choice): + return gr.update(visible=False) if choice in ['Video Input'] else gr.update(visible=True) + +def change_comp_mask_x_visibility(choice): + return gr.update(visible=choice != "None") + +def change_gif_button_visibility(choice): + return gr.update(visible=False, value=False) if int(choice) > 30 else gr.update(visible=True) + +def disable_by_hybrid_composite(choice): + return gr.update(visible=True) if choice else gr.update(visible=False) + +def disable_by_hybrid_composite_dynamic(choice, comp_mask_type): + if choice == True: + if comp_mask_type != 'None': + return gr.update(visible=True) + return gr.update(visible=False) + +def disable_by_comp_mask(choice): + return gr.update(visible=False) if choice == 'None' else gr.update(visible=True) + +def disable_by_non_optical_flow(choice): + return gr.update(visible=False) if choice != 'Optical Flow' else gr.update(visible=True) + +# Upscaling Gradio UI related funcs +def 
vid_upscale_gradio_update_stats(vid_path, upscale_factor): + if not vid_path: + return '---', '---', '---', '---' + factor = extract_number(upscale_factor) + fps, fcount, resolution = get_quick_vid_info(vid_path.name) + in_res_str = f"{resolution[0]}*{resolution[1]}" + out_res_str = f"{resolution[0] * factor}*{resolution[1] * factor}" + return fps, fcount, in_res_str, out_res_str +def update_upscale_out_res(in_res, upscale_factor): + if not in_res: + return '---' + factor = extract_number(upscale_factor) + w, h = [int(x) * factor for x in in_res.split('*')] + return f"{w}*{h}" +def update_upscale_out_res_by_model_name(in_res, upscale_model_name): + if not upscale_model_name or in_res == '---': + return '---' + factor = 2 if upscale_model_name == 'realesr-animevideov3' else 4 + return f"{int(in_res.split('*')[0]) * factor}*{int(in_res.split('*')[1]) * factor}" \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/human_masking.py b/extensions/deforum/scripts/deforum_helpers/human_masking.py new file mode 100644 index 0000000000000000000000000000000000000000..daabbe8f4436fdccba90c3a11c9439543446d187 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/human_masking.py @@ -0,0 +1,72 @@ +import os, cv2 +import torch +from pathlib import Path +from multiprocessing import freeze_support + +def extract_frames(input_video_path, output_imgs_path): + # Open the video file + vidcap = cv2.VideoCapture(input_video_path) + + # Get the total number of frames in the video + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Create the output directory if it does not exist + if not os.path.exists(output_imgs_path): + os.makedirs(output_imgs_path) + + # Extract the frames + for i in range(frame_count): + success, image = vidcap.read() + if success: + cv2.imwrite(os.path.join(output_imgs_path, f"frame{i}.png"), image) + print(f"{frame_count} frames extracted and saved to {output_imgs_path}") + + +def video2humanmasks(input_frames_path, output_folder_path, output_type, fps): + # freeze support is needed for video outputting + freeze_support() + + # check if input path exists and is a directory + if not os.path.exists(input_frames_path) or not os.path.isdir(input_frames_path): + raise ValueError("Invalid input path: {}".format(input_frames_path)) + + # check if output path exists and is a directory + if not os.path.exists(output_folder_path) or not os.path.isdir(output_folder_path): + raise ValueError("Invalid output path: {}".format(output_folder_path)) + + # check if output_type is valid + valid_output_types = ["video", "pngs", "both"] + if output_type.lower() not in valid_output_types: + raise ValueError("Invalid output type: {}. 
Must be one of {}".format(output_type, valid_output_types)) + + # try to predict where torch cache lives, so we can try and fetch models from cache in the next step + predicted_torch_model_cache_path = os.path.join(Path.home(), ".cache", "torch", "hub", "hithereai_RobustVideoMatting_master") + predicted_rvm_cache_testilfe = os.path.join(predicted_torch_model_cache_path, "hubconf.py") + + # try to fetch the models from cache, and only if it can't be find, download from the internet (to enable offline usage) + try: + # Try to fetch the models from cache + convert_video = torch.hub.load(predicted_torch_model_cache_path, "converter", source='local') + model = torch.hub.load(predicted_torch_model_cache_path, "mobilenetv3", source='local').cuda() + except: + # Download from the internet if not found in cache + convert_video = torch.hub.load("hithereai/RobustVideoMatting", "converter") + model = torch.hub.load("hithereai/RobustVideoMatting", "mobilenetv3").cuda() + + output_alpha_vid_path = os.path.join(output_folder_path, "human_masked_video.mp4") + # extract humans masks from the input folder' imgs. + # in this step PNGs will be extracted only if output_type is set to PNGs. Otherwise a video will be made, and in the case of Both, the video will be extracted in the next step to PNGs + convert_video( + model, + input_source=input_frames_path, # full path of the folder that contains all of the extracted input imgs + output_type='video' if output_type.upper() in ("VIDEO", "BOTH") else 'png_sequence', + output_alpha=output_alpha_vid_path if output_type.upper() in ("VIDEO", "BOTH") else output_folder_path, + output_video_mbps=4, + output_video_fps=fps, + downsample_ratio=None, # None for auto + seq_chunk=12, # Process n frames at once for better parallelism + progress=True # show extraction progress + ) + + if output_type.lower() == "both": + extract_frames(output_alpha_vid_path, output_folder_path) diff --git a/extensions/deforum/scripts/deforum_helpers/hybrid_video.py b/extensions/deforum/scripts/deforum_helpers/hybrid_video.py new file mode 100644 index 0000000000000000000000000000000000000000..76401712387cbda1bb29dbd6669fc9f774903c7e --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/hybrid_video.py @@ -0,0 +1,436 @@ +import cv2 +import os +import pathlib +import numpy as np +import random +from PIL import Image, ImageChops, ImageOps, ImageEnhance +from .video_audio_utilities import vid2frames, get_quick_vid_info, get_frame_name, get_next_frame +from .human_masking import video2humanmasks + +def delete_all_imgs_in_folder(folder_path): + files = list(pathlib.Path(folder_path).glob('*.jpg')) + files.extend(list(pathlib.Path(folder_path).glob('*.png'))) + for f in files: os.remove(f) + +def hybrid_generation(args, anim_args, root): + video_in_frame_path = os.path.join(args.outdir, 'inputframes') + hybrid_frame_path = os.path.join(args.outdir, 'hybridframes') + human_masks_path = os.path.join(args.outdir, 'human_masks') + + if anim_args.hybrid_generate_inputframes: + # create folders for the video input frames and optional hybrid frames to live in + os.makedirs(video_in_frame_path, exist_ok=True) + os.makedirs(hybrid_frame_path, exist_ok=True) + + # delete frames if overwrite = true + if anim_args.overwrite_extracted_frames: + delete_all_imgs_in_folder(hybrid_frame_path) + + # save the video frames from input video + print(f"Video to extract: {anim_args.video_init_path}") + print(f"Extracting video (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...") + video_fps = 
vid2frames(video_path=anim_args.video_init_path, video_in_frame_path=video_in_frame_path, n=anim_args.extract_nth_frame, overwrite=anim_args.overwrite_extracted_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame) + + # extract alpha masks of humans from the extracted input video imgs + if anim_args.hybrid_generate_human_masks != "None": + # create a folder for the human masks imgs to live in + print(f"Checking /creating a folder for the human masks") + os.makedirs(human_masks_path, exist_ok=True) + + # delete frames if overwrite = true + if anim_args.overwrite_extracted_frames: + delete_all_imgs_in_folder(human_masks_path) + + # in case that generate_input_frames isn't selected, we won't get the video fps rate as vid2frames isn't called, So we'll check the video fps in here instead + if not anim_args.hybrid_generate_inputframes: + _, video_fps, _ = get_quick_vid_info(anim_args.video_init_path) + + # calculate the correct fps of the masked video according to the original video fps and 'extract_nth_frame' + output_fps = video_fps/anim_args.extract_nth_frame + + # generate the actual alpha masks from the input imgs + print(f"Extracting alpha humans masks from the input frames") + video2humanmasks(video_in_frame_path, human_masks_path, anim_args.hybrid_generate_human_masks, output_fps) + + # determine max frames from length of input frames + anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')]) + print(f"Using {anim_args.max_frames} input frames from {video_in_frame_path}...") + + # get sorted list of inputfiles + inputfiles = sorted(pathlib.Path(video_in_frame_path).glob('*.jpg')) + + # use first frame as init + if anim_args.hybrid_use_first_frame_as_init_image: + for f in inputfiles: + args.init_image = str(f) + args.use_init = True + print(f"Using init_image from video: {args.init_image}") + break + + return args, anim_args, inputfiles + +def hybrid_composite(args, anim_args, frame_idx, prev_img, depth_model, hybrid_comp_schedules, root): + video_frame = os.path.join(args.outdir, 'inputframes', get_frame_name(anim_args.video_init_path) + f"{frame_idx:05}.jpg") + video_depth_frame = os.path.join(args.outdir, 'hybridframes', get_frame_name(anim_args.video_init_path) + f"_vid_depth{frame_idx:05}.jpg") + depth_frame = os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx-1:05}.png") + mask_frame = os.path.join(args.outdir, 'hybridframes', get_frame_name(anim_args.video_init_path) + f"_mask{frame_idx:05}.jpg") + comp_frame = os.path.join(args.outdir, 'hybridframes', get_frame_name(anim_args.video_init_path) + f"_comp{frame_idx:05}.jpg") + prev_frame = os.path.join(args.outdir, 'hybridframes', get_frame_name(anim_args.video_init_path) + f"_prev{frame_idx:05}.jpg") + prev_img = cv2.cvtColor(prev_img, cv2.COLOR_BGR2RGB) + prev_img_hybrid = Image.fromarray(prev_img) + video_image = Image.open(video_frame) + video_image = video_image.resize((args.W, args.H), Image.Resampling.LANCZOS) + hybrid_mask = None + + # composite mask types + if anim_args.hybrid_comp_mask_type == 'Depth': # get depth from last generation + hybrid_mask = Image.open(depth_frame) + elif anim_args.hybrid_comp_mask_type == 'Video Depth': # get video depth + video_depth = depth_model.predict(np.array(video_image), anim_args, root.half_precision) + depth_model.save(video_depth_frame, video_depth) + hybrid_mask = Image.open(video_depth_frame) + elif anim_args.hybrid_comp_mask_type == 'Blend': # create blend mask image + hybrid_mask = 
Image.blend(ImageOps.grayscale(prev_img_hybrid), ImageOps.grayscale(video_image), hybrid_comp_schedules['mask_blend_alpha']) + elif anim_args.hybrid_comp_mask_type == 'Difference': # create difference mask image + hybrid_mask = ImageChops.difference(ImageOps.grayscale(prev_img_hybrid), ImageOps.grayscale(video_image)) + + # optionally invert mask, if mask type is defined + if anim_args.hybrid_comp_mask_inverse and anim_args.hybrid_comp_mask_type != "None": + hybrid_mask = ImageOps.invert(hybrid_mask) + + # if a mask type is selected, make composition + if hybrid_mask == None: + hybrid_comp = video_image + else: + # ensure grayscale + hybrid_mask = ImageOps.grayscale(hybrid_mask) + # equalization before + if anim_args.hybrid_comp_mask_equalize in ['Before', 'Both']: + hybrid_mask = ImageOps.equalize(hybrid_mask) + # contrast + hybrid_mask = ImageEnhance.Contrast(hybrid_mask).enhance(hybrid_comp_schedules['mask_contrast']) + # auto contrast with cutoffs lo/hi + if anim_args.hybrid_comp_mask_auto_contrast: + hybrid_mask = autocontrast_grayscale(np.array(hybrid_mask), hybrid_comp_schedules['mask_auto_contrast_cutoff_low'], hybrid_comp_schedules['mask_auto_contrast_cutoff_high']) + hybrid_mask = Image.fromarray(hybrid_mask) + hybrid_mask = ImageOps.grayscale(hybrid_mask) + if anim_args.hybrid_comp_save_extra_frames: + hybrid_mask.save(mask_frame) + # equalization after + if anim_args.hybrid_comp_mask_equalize in ['After', 'Both']: + hybrid_mask = ImageOps.equalize(hybrid_mask) + # do compositing and save + hybrid_comp = Image.composite(prev_img_hybrid, video_image, hybrid_mask) + if anim_args.hybrid_comp_save_extra_frames: + hybrid_comp.save(comp_frame) + + # final blend of composite with prev_img, or just a blend if no composite is selected + hybrid_blend = Image.blend(prev_img_hybrid, hybrid_comp, hybrid_comp_schedules['alpha']) + if anim_args.hybrid_comp_save_extra_frames: + hybrid_blend.save(prev_frame) + + prev_img = cv2.cvtColor(np.array(hybrid_blend), cv2.COLOR_RGB2BGR) + + # restore to np array and return + return args, prev_img + +def get_matrix_for_hybrid_motion(frame_idx, dimensions, inputfiles, hybrid_motion): + img1 = cv2.cvtColor(get_resized_image_from_filename(str(inputfiles[frame_idx-1]), dimensions), cv2.COLOR_BGR2GRAY) + img2 = cv2.cvtColor(get_resized_image_from_filename(str(inputfiles[frame_idx]), dimensions), cv2.COLOR_BGR2GRAY) + matrix = get_transformation_matrix_from_images(img1, img2, hybrid_motion) + print(f"Calculating {hybrid_motion} RANSAC matrix for frames {frame_idx} to {frame_idx+1}") + return matrix + +def get_matrix_for_hybrid_motion_prev(frame_idx, dimensions, inputfiles, prev_img, hybrid_motion): + # first handle invalid images from cadence by returning default matrix + height, width = prev_img.shape[:2] + if height == 0 or width == 0 or prev_img != np.uint8: + return get_hybrid_motion_default_matrix(hybrid_motion) + else: + prev_img_gray = cv2.cvtColor(prev_img, cv2.COLOR_BGR2GRAY) + img = cv2.cvtColor(get_resized_image_from_filename(str(inputfiles[frame_idx]), dimensions), cv2.COLOR_BGR2GRAY) + matrix = get_transformation_matrix_from_images(prev_img_gray, img, hybrid_motion) + print(f"Calculating {hybrid_motion} RANSAC matrix for frames {frame_idx} to {frame_idx+1}") + return matrix + +def get_flow_for_hybrid_motion(frame_idx, dimensions, inputfiles, hybrid_frame_path, method, do_flow_visualization=False): + print(f"Calculating {method} optical flow for frames {frame_idx} to {frame_idx+1}") + i1 = get_resized_image_from_filename(str(inputfiles[frame_idx]), 
dimensions) + i2 = get_resized_image_from_filename(str(inputfiles[frame_idx+1]), dimensions) + flow = get_flow_from_images(i1, i2, method) + if do_flow_visualization: + save_flow_visualization(frame_idx, dimensions, flow, inputfiles, hybrid_frame_path) + return flow + +def get_flow_for_hybrid_motion_prev(frame_idx, dimensions, inputfiles, hybrid_frame_path, prev_img, method, do_flow_visualization=False): + print(f"Calculating {method} optical flow for frames {frame_idx} to {frame_idx+1}") + # first handle invalid images from cadence by returning default matrix + height, width = prev_img.shape[:2] + if height == 0 or width == 0: + flow = get_hybrid_motion_default_flow(dimensions) + else: + i1 = prev_img.astype(np.uint8) + i2 = get_resized_image_from_filename(str(inputfiles[frame_idx]), dimensions) + flow = get_flow_from_images(i1, i2, method) + if do_flow_visualization: + save_flow_visualization(frame_idx, dimensions, flow, inputfiles, hybrid_frame_path) + return flow + +def image_transform_ransac(image_cv2, xform, hybrid_motion, border_mode=cv2.BORDER_REPLICATE): + if hybrid_motion == "Perspective": + return image_transform_perspective(image_cv2, xform, border_mode=border_mode) + else: # Affine + return image_transform_affine(image_cv2, xform, border_mode=border_mode) + +def image_transform_optical_flow(img, flow, border_mode=cv2.BORDER_REPLICATE, flow_reverse=False): + if not flow_reverse: + flow = -flow + h, w = img.shape[:2] + flow[:, :, 0] += np.arange(w) + flow[:, :, 1] += np.arange(h)[:,np.newaxis] + return remap(img, flow, border_mode) + +def image_transform_affine(image_cv2, xform, border_mode=cv2.BORDER_REPLICATE): + return cv2.warpAffine( + image_cv2, + xform, + (image_cv2.shape[1],image_cv2.shape[0]), + borderMode=border_mode + ) + +def image_transform_perspective(image_cv2, xform, border_mode=cv2.BORDER_REPLICATE): + return cv2.warpPerspective( + image_cv2, + xform, + (image_cv2.shape[1], image_cv2.shape[0]), + borderMode=border_mode + ) + +def get_hybrid_motion_default_matrix(hybrid_motion): + if hybrid_motion == "Perspective": + arr = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + else: + arr = np.array([[1., 0., 0.], [0., 1., 0.]]) + return arr + +def get_hybrid_motion_default_flow(dimensions): + cols, rows = dimensions + flow = np.zeros((rows, cols, 2), np.float32) + return flow + +def get_transformation_matrix_from_images(img1, img2, hybrid_motion, max_corners=200, quality_level=0.01, min_distance=30, block_size=3): + # Detect feature points in previous frame + prev_pts = cv2.goodFeaturesToTrack(img1, + maxCorners=max_corners, + qualityLevel=quality_level, + minDistance=min_distance, + blockSize=block_size) + + if prev_pts is None or len(prev_pts) < 8 or img1 is None or img2 is None: + return get_hybrid_motion_default_matrix(hybrid_motion) + + # Get optical flow + curr_pts, status, err = cv2.calcOpticalFlowPyrLK(img1, img2, prev_pts, None) + + # Filter only valid points + idx = np.where(status==1)[0] + prev_pts = prev_pts[idx] + curr_pts = curr_pts[idx] + + if len(prev_pts) < 8 or len(curr_pts) < 8: + return get_hybrid_motion_default_matrix(hybrid_motion) + + if hybrid_motion == "Perspective": # Perspective - Find the transformation between points + transformation_matrix, mask = cv2.findHomography(prev_pts, curr_pts, cv2.RANSAC, 5.0) + return transformation_matrix + else: # Affine - Compute a rigid transformation (without depth, only scale + rotation + translation) + transformation_rigid_matrix, rigid_mask = cv2.estimateAffinePartial2D(prev_pts, curr_pts) + return 
transformation_rigid_matrix + +def get_flow_from_images(i1, i2, method): + if method =="DIS Medium": + r = get_flow_from_images_DIS(i1, i2, cv2.DISOPTICAL_FLOW_PRESET_MEDIUM) + elif method =="DIS Fast": + r = get_flow_from_images_DIS(i1, i2, cv2.DISOPTICAL_FLOW_PRESET_FAST) + elif method =="DIS UltraFast": + r = get_flow_from_images_DIS(i1, i2, cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST) + elif method == "DenseRLOF": # requires running opencv-contrib-python (full opencv) INSTEAD of opencv-python + r = get_flow_from_images_Dense_RLOF(i1, i2) + elif method == "SF": # requires running opencv-contrib-python (full opencv) INSTEAD of opencv-python + r = get_flow_from_images_SF(i1, i2) + elif method =="Farneback Fine": + r = get_flow_from_images_Farneback(i1, i2, 'fine') + else: # Farneback Normal: + r = get_flow_from_images_Farneback(i1, i2) + return r + +def get_flow_from_images_DIS(i1, i2, preset): + i1 = cv2.cvtColor(i1, cv2.COLOR_BGR2GRAY) + i2 = cv2.cvtColor(i2, cv2.COLOR_BGR2GRAY) + dis=cv2.DISOpticalFlow_create(preset) + return dis.calc(i1, i2, None) + +def get_flow_from_images_Dense_RLOF(i1, i2, last_flow=None): + return cv2.optflow.calcOpticalFlowDenseRLOF(i1, i2, flow = last_flow) + +def get_flow_from_images_SF(i1, i2, last_flow=None, layers = 3, averaging_block_size = 2, max_flow = 4): + return cv2.optflow.calcOpticalFlowSF(i1, i2, layers, averaging_block_size, max_flow) + +def get_flow_from_images_Farneback(i1, i2, preset="normal", last_flow=None, pyr_scale = 0.5, levels = 3, winsize = 15, iterations = 3, poly_n = 5, poly_sigma = 1.2, flags = 0): + flags = cv2.OPTFLOW_FARNEBACK_GAUSSIAN # Specify the operation flags + pyr_scale = 0.5 # The image scale (<1) to build pyramids for each image + if preset == "fine": + levels = 13 # The number of pyramid layers, including the initial image + winsize = 77 # The averaging window size + iterations = 13 # The number of iterations at each pyramid level + poly_n = 15 # The size of the pixel neighborhood used to find polynomial expansion in each pixel + poly_sigma = 0.8 # The standard deviation of the Gaussian used to smooth derivatives used as a basis for the polynomial expansion + else: # "normal" + levels = 5 # The number of pyramid layers, including the initial image + winsize = 21 # The averaging window size + iterations = 5 # The number of iterations at each pyramid level + poly_n = 7 # The size of the pixel neighborhood used to find polynomial expansion in each pixel + poly_sigma = 1.2 # The standard deviation of the Gaussian used to smooth derivatives used as a basis for the polynomial expansion + i1 = cv2.cvtColor(i1, cv2.COLOR_BGR2GRAY) + i2 = cv2.cvtColor(i2, cv2.COLOR_BGR2GRAY) + flags = 0 # flags = cv2.OPTFLOW_USE_INITIAL_FLOW + flow = cv2.calcOpticalFlowFarneback(i1, i2, last_flow, pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags) + return flow + +def save_flow_visualization(frame_idx, dimensions, flow, inputfiles, hybrid_frame_path): + flow_img_file = os.path.join(hybrid_frame_path, f"flow{frame_idx:05}.jpg") + flow_img = cv2.imread(str(inputfiles[frame_idx])) + flow_img = cv2.resize(flow_img, (dimensions[0], dimensions[1]), cv2.INTER_AREA) + flow_img = cv2.cvtColor(flow_img, cv2.COLOR_RGB2GRAY) + flow_img = cv2.cvtColor(flow_img, cv2.COLOR_GRAY2BGR) + flow_img = draw_flow_lines_in_grid_in_color(flow_img, flow) + flow_img = cv2.cvtColor(flow_img, cv2.COLOR_BGR2RGB) + cv2.imwrite(flow_img_file, flow_img) + print(f"Saved optical flow visualization: {flow_img_file}") + +def draw_flow_lines_in_grid_in_color(img, flow, 
step=8, magnitude_multiplier=1, min_magnitude = 1, max_magnitude = 10000): + flow = flow * magnitude_multiplier + h, w = img.shape[:2] + y, x = np.mgrid[step/2:h:step, step/2:w:step].reshape(2,-1).astype(int) + fx, fy = flow[y,x].T + lines = np.vstack([x, y, x+fx, y+fy]).T.reshape(-1, 2, 2) + lines = np.int32(lines + 0.5) + vis = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + vis = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR) + + mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1]) + hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) + hsv[...,0] = ang*180/np.pi/2 + hsv[...,1] = 255 + hsv[...,2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + vis = cv2.add(vis, bgr) + + # Iterate through the lines + for (x1, y1), (x2, y2) in lines: + # Calculate the magnitude of the line + magnitude = np.sqrt((x2 - x1)**2 + (y2 - y1)**2) + + # Only draw the line if it falls within the magnitude range + if min_magnitude <= magnitude <= max_magnitude: + b = int(bgr[y1, x1, 0]) + g = int(bgr[y1, x1, 1]) + r = int(bgr[y1, x1, 2]) + color = (b, g, r) + cv2.arrowedLine(vis, (x1, y1), (x2, y2), color, thickness=1, tipLength=0.1) + return vis + +def draw_flow_lines_in_color(img, flow, threshold=3, magnitude_multiplier=1, min_magnitude = 0, max_magnitude = 10000): + # h, w = img.shape[:2] + vis = img.copy() # Create a copy of the input image + + # Find the locations in the flow field where the magnitude of the flow is greater than the threshold + mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1]) + idx = np.where(mag > threshold) + + # Create HSV image + hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) + hsv[...,0] = ang*180/np.pi/2 + hsv[...,1] = 255 + hsv[...,2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + + # Convert HSV image to BGR + bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + + # Add color from bgr + vis = cv2.add(vis, bgr) + + # Draw an arrow at each of these locations to indicate the direction of the flow + for i, (y, x) in enumerate(zip(idx[0], idx[1])): + # Calculate the magnitude of the line + x2 = x + magnitude_multiplier * int(flow[y, x, 0]) + y2 = y + magnitude_multiplier * int(flow[y, x, 1]) + magnitude = np.sqrt((x2 - x)**2 + (y2 - y)**2) + + # Only draw the line if it falls within the magnitude range + if min_magnitude <= magnitude <= max_magnitude: + if i % random.randint(100, 200) == 0: + b = int(bgr[y, x, 0]) + g = int(bgr[y, x, 1]) + r = int(bgr[y, x, 2]) + color = (b, g, r) + cv2.arrowedLine(vis, (x, y), (x2, y2), color, thickness=1, tipLength=0.25) + + return vis + +def autocontrast_grayscale(image, low_cutoff=0, high_cutoff=100): + # Perform autocontrast on a grayscale np array image. + # Find the minimum and maximum values in the image + min_val = np.percentile(image, low_cutoff) + max_val = np.percentile(image, high_cutoff) + + # Scale the image so that the minimum value is 0 and the maximum value is 255 + image = 255 * (image - min_val) / (max_val - min_val) + + # Clip values that fall outside the range [0, 255] + image = np.clip(image, 0, 255) + + return image + +def get_resized_image_from_filename(im, dimensions): + img = cv2.imread(im) + return cv2.resize(img, (dimensions[0], dimensions[1]), cv2.INTER_AREA) + +def remap(img, flow, border_mode = cv2.BORDER_REFLECT_101): + # copyMakeBorder doesn't support wrap, but supports replicate. Replaces wrap with reflect101. 
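+ # Note: the incoming 'flow' is expected to be an absolute sampling map (image_transform_optical_flow adds the base x/y grid before calling remap).
+ # The image is padded by 25% per side, the map is shifted into the padded coordinate frame by extend_flow, sampled with cv2.remap, then center-cropped back to the original size.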
+ if border_mode == cv2.BORDER_WRAP: + border_mode = cv2.BORDER_REFLECT_101 + h, w = img.shape[:2] + displacement = int(h * 0.25), int(w * 0.25) + larger_img = cv2.copyMakeBorder(img, displacement[0], displacement[0], displacement[1], displacement[1], border_mode) + lh, lw = larger_img.shape[:2] + larger_flow = extend_flow(flow, lw, lh) + remapped_img = cv2.remap(larger_img, larger_flow, None, cv2.INTER_LINEAR, border_mode) + output_img = center_crop_image(remapped_img, w, h) + return output_img + +def center_crop_image(img, w, h): + y, x, _ = img.shape + width_indent = int((x - w) / 2) + height_indent = int((y - h) / 2) + cropped_img = img[height_indent:y-height_indent, width_indent:x-width_indent] + return cropped_img + +def extend_flow(flow, w, h): + # Get the shape of the original flow image + flow_h, flow_w = flow.shape[:2] + # Calculate the position of the image in the new image + x_offset = int((w - flow_w) / 2) + y_offset = int((h - flow_h) / 2) + # Generate the X and Y grids + x_grid, y_grid = np.meshgrid(np.arange(w), np.arange(h)) + # Create the new flow image and set it to the X and Y grids + new_flow = np.dstack((x_grid, y_grid)).astype(np.float32) + # Shift the values of the original flow by the size of the border + flow[:,:,0] += x_offset + flow[:,:,1] += y_offset + # Overwrite the middle of the grid with the original flow + new_flow[y_offset:y_offset+flow_h, x_offset:x_offset+flow_w, :] = flow + # Return the extended image + return new_flow + \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/image_sharpening.py b/extensions/deforum/scripts/deforum_helpers/image_sharpening.py new file mode 100644 index 0000000000000000000000000000000000000000..aea81449e3c61808c44de72450ed4d1c0ce88fc4 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/image_sharpening.py @@ -0,0 +1,22 @@ +import cv2 +import numpy as np + +def unsharp_mask(img, kernel_size=(5, 5), sigma=1.0, amount=1.0, threshold=0, mask=None): + if amount == 0: + return img + # Return a sharpened version of the image, using an unsharp mask. 
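+ # i.e. sharpened = img + amount * (img - blurred), clipped to [0, 255] and rounded back to uint8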
+ # If mask is not None, only areas under mask are handled + blurred = cv2.GaussianBlur(img, kernel_size, sigma) + sharpened = float(amount + 1) * img - float(amount) * blurred + sharpened = np.maximum(sharpened, np.zeros(sharpened.shape)) + sharpened = np.minimum(sharpened, 255 * np.ones(sharpened.shape)) + sharpened = sharpened.round().astype(np.uint8) + if threshold > 0: + low_contrast_mask = np.absolute(img - blurred) < threshold + np.copyto(sharpened, img, where=low_contrast_mask) + if mask is not None: + mask = np.array(mask) + masked_sharpened = cv2.bitwise_and(sharpened, sharpened, mask=mask) + masked_img = cv2.bitwise_and(img, img, mask=255-mask) + sharpened = cv2.add(masked_img, masked_sharpened) + return sharpened diff --git a/extensions/deforum/scripts/deforum_helpers/load_images.py b/extensions/deforum/scripts/deforum_helpers/load_images.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc5726f8aed86fb190ae15aa6098c3bcac8ec2c --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/load_images.py @@ -0,0 +1,102 @@ +import requests +import os +from PIL import Image, ImageOps +import cv2 +import numpy as np +import socket +import torchvision.transforms.functional as TF + +def load_img(path : str, shape=None, use_alpha_as_mask=False): + # use_alpha_as_mask: Read the alpha channel of the image as the mask image + image = load_image(path) + if use_alpha_as_mask: + image = image.convert('RGBA') + else: + image = image.convert('RGB') + + if shape is not None: + image = image.resize(shape, resample=Image.LANCZOS) + + mask_image = None + if use_alpha_as_mask: + # Split alpha channel into a mask_image + red, green, blue, alpha = Image.Image.split(image) + mask_image = alpha.convert('L') + image = image.convert('RGB') + + # check using init image alpha as mask if mask is not blank + extrema = mask_image.getextrema() + if (extrema == (0,0)) or extrema == (255,255): + print("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.") + print("ignoring alpha as mask.") + mask_image = None + + return image, mask_image + +def load_image(image_path :str): + image = None + if image_path.startswith('http://') or image_path.startswith('https://'): + try: + host = socket.gethostbyname("www.google.com") + s = socket.create_connection((host, 80), 2) + s.close() + except: + raise ConnectionError("There is no active internet connection available - please use local masks and init files only.") + + try: + response = requests.get(image_path, stream=True) + except requests.exceptions.RequestException as e: + raise ConnectionError("Failed to download image due to no internet connection. 
Error: {}".format(e)) + if response.status_code == 404 or response.status_code != 200: + raise ConnectionError("Init image url or mask image url is not valid") + image = Image.open(response.raw).convert('RGB') + else: + if not os.path.exists(image_path): + raise RuntimeError("Init image path or mask image path is not valid") + image = Image.open(image_path).convert('RGB') + + return image + +def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0): + """ + prepares mask for use in webui + """ + if isinstance(mask_input, Image.Image): + mask = mask_input + else : + mask = load_image(mask_input) + mask = mask.resize(mask_shape, resample=Image.LANCZOS) + if mask_brightness_adjust != 1: + mask = TF.adjust_brightness(mask, mask_brightness_adjust) + if mask_contrast_adjust != 1: + mask = TF.adjust_contrast(mask, mask_contrast_adjust) + mask = mask.convert('L') + return mask + +def check_mask_for_errors(mask_input, invert_mask=False): + extrema = mask_input.getextrema() + if (invert_mask): + if extrema == (255,255): + print("after inverting mask will be blank. ignoring mask") + return None + elif extrema == (0,0): + print("mask is blank. ignoring mask") + return None + else: + return mask_input + +def get_mask(args): + return check_mask_for_errors( + prepare_mask(args.mask_file, (args.W, args.H), args.mask_contrast_adjust, args.mask_brightness_adjust) + ) + +def get_mask_from_file(mask_file, args): + return check_mask_for_errors( + prepare_mask(mask_file, (args.W, args.H), args.mask_contrast_adjust, args.mask_brightness_adjust) + ) + +def blank_if_none(mask, w, h, mode): + return Image.new(mode, (w, h), (0)) if mask is None else mask + +def none_if_blank(mask): + return None if mask.getextrema() == (0,0) else mask diff --git a/extensions/deforum/scripts/deforum_helpers/noise.py b/extensions/deforum/scripts/deforum_helpers/noise.py new file mode 100644 index 0000000000000000000000000000000000000000..768f0e9f73ea50b3262c643b712730f614488895 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/noise.py @@ -0,0 +1,64 @@ +import torch +import numpy as np +from PIL import ImageOps +import math +from .animation import sample_to_cv2 +import cv2 + +deforum_noise_gen = torch.Generator(device='cpu') + +# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 +def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1) % 1 + angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=deforum_noise_gen) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1) + + tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) + dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]) + t = fade(grid[:shape[0], :shape[1]]) + return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + +def 
rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): + noise = torch.zeros(shape) + frequency = 1 + amplitude = 1 + for _ in range(int(octaves)): + noise += amplitude * rand_perlin_2d(shape, (frequency*res[0], frequency*res[1])) + frequency *= 2 + amplitude *= persistence + return noise + +def condition_noise_mask(noise_mask, invert_mask = False): + if invert_mask: + noise_mask = ImageOps.invert(noise_mask) + noise_mask = np.array(noise_mask.convert("L")) + noise_mask = noise_mask.astype(np.float32) / 255.0 + noise_mask = np.around(noise_mask, decimals=0) + noise_mask = torch.from_numpy(noise_mask) + #noise_mask = torch.round(noise_mask) + return noise_mask + +def add_noise(sample, noise_amt: float, seed: int, noise_type: str, noise_args, noise_mask = None, invert_mask = False): + deforum_noise_gen.manual_seed(seed) # Reproducibility + sample2dshape = (sample.shape[0], sample.shape[1]) #sample is cv2, so height - width + noise = torch.randn((sample.shape[2], sample.shape[0], sample.shape[1]), generator=deforum_noise_gen) # White noise + if noise_type == 'perlin': + # rand_perlin_2d_octaves is between -1 and 1, so we need to shift it to be between 0 and 1 + # print(sample.shape) + noise = noise * ((rand_perlin_2d_octaves(sample2dshape, (int(noise_args[0]), int(noise_args[1])), octaves=noise_args[2], persistence=noise_args[3]) + torch.ones(sample2dshape)) / 2) + if noise_mask is not None: + noise_mask = condition_noise_mask(noise_mask, invert_mask) + noise_to_add = sample_to_cv2(noise * noise_mask) + else: + noise_to_add = sample_to_cv2(noise) + sample = cv2.addWeighted(sample, 1-noise_amt, noise_to_add, noise_amt, 0) + + return sample diff --git a/extensions/deforum/scripts/deforum_helpers/parseq_adapter.py b/extensions/deforum/scripts/deforum_helpers/parseq_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..314b594ac25792358807bdb602cae7f97387edf4 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/parseq_adapter.py @@ -0,0 +1,164 @@ +import copy +import json +import logging +import operator +from operator import itemgetter + +import numpy as np +import pandas as pd +import requests + +from .animation_key_frames import DeformAnimKeys + +logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) + + +class ParseqAnimKeys(): + def __init__(self, parseq_args, anim_args): + + # Resolve manifest either directly from supplied value + # or via supplied URL + manifestOrUrl = parseq_args.parseq_manifest.strip() + if (manifestOrUrl.startswith('http')): + logging.info(f"Loading Parseq manifest from URL: {manifestOrUrl}") + try: + body = requests.get(manifestOrUrl).text + logging.debug(f"Loaded remote manifest: {body}") + self.parseq_json = json.loads(body) + + # Add the parseq manifest without the detailed frame data to parseq_args. + # This ensures it will be saved in the settings file, so that you can always + # see exactly what parseq prompts and keyframes were used, even if what the URL + # points to changes. 
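+ # Only a summary is stored: the bulky per-frame 'rendered_frames' / 'rendered_frames_meta' entries are stripped from the copy below.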
+ parseq_args.fetched_parseq_manifest_summary = copy.deepcopy(self.parseq_json) + if parseq_args.fetched_parseq_manifest_summary['rendered_frames']: + del parseq_args.fetched_parseq_manifest_summary['rendered_frames'] + if parseq_args.fetched_parseq_manifest_summary['rendered_frames_meta']: + del parseq_args.fetched_parseq_manifest_summary['rendered_frames_meta'] + + except Exception as e: + logging.error(f"Unable to load Parseq manifest from URL: {manifestOrUrl}") + raise e + else: + self.parseq_json = json.loads(manifestOrUrl) + + self.default_anim_keys = DeformAnimKeys(anim_args) + self.rendered_frames = self.parseq_json['rendered_frames'] + self.max_frame = self.get_max('frame') + count_defined_frames = len(self.rendered_frames) + expected_defined_frames = self.max_frame+1 # frames are 0-indexed + + self.required_frames = anim_args.max_frames + + if (expected_defined_frames != count_defined_frames): + logging.warning(f"There may be duplicated or missing frame data in the Parseq input: expected {expected_defined_frames} frames including frame 0 because the highest frame number is {self.max_frame}, but there are {count_defined_frames} frames defined.") + + if (anim_args.max_frames > count_defined_frames): + logging.info(f"Parseq data defines {count_defined_frames} frames, but the requested animation is {anim_args.max_frames} frames. The last Parseq frame definition will be duplicated to match the expected frame count.") + if (anim_args.max_frames < count_defined_frames): + logging.info(f"Parseq data defines {count_defined_frames} frames, but the requested animation is {anim_args.max_frames} frames. The last Parseq frame definitions will be ignored.") + else: + logging.info(f"Parseq data defines {count_defined_frames} frames.") + + # Parseq treats input values as absolute values. So if you want to + # progressively rotate 180 degrees over 4 frames, you specify: 45, 90, 135, 180. + # However, many animation parameters are relative to the previous frame if there is enough + # loopback strength. So if you want to rotate 180 degrees over 5 frames, the animation engine expects: + # 45, 45, 45, 45. Therefore, for such parameter, we use the fact that Parseq supplies delta values. 
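+ # e.g. with parseq_use_deltas enabled, the motion series below read the 'angle_delta' field (45, 45, 45, 45) rather than the absolute 'angle' field (45, 90, 135, 180).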
+ optional_delta = '_delta' if parseq_args.parseq_use_deltas else '' + self.angle_series = self.parseq_to_anim_series('angle' + optional_delta) + self.zoom_series = self.parseq_to_anim_series('zoom' + optional_delta) + self.translation_x_series = self.parseq_to_anim_series('translation_x' + optional_delta) + self.translation_y_series = self.parseq_to_anim_series('translation_y' + optional_delta) + self.translation_z_series = self.parseq_to_anim_series('translation_z' + optional_delta) + self.rotation_3d_x_series = self.parseq_to_anim_series('rotation_3d_x' + optional_delta) + self.rotation_3d_y_series = self.parseq_to_anim_series('rotation_3d_y' + optional_delta) + self.rotation_3d_z_series = self.parseq_to_anim_series('rotation_3d_z' + optional_delta) + self.perspective_flip_theta_series = self.parseq_to_anim_series('perspective_flip_theta' + optional_delta) + self.perspective_flip_phi_series = self.parseq_to_anim_series('perspective_flip_phi' + optional_delta) + self.perspective_flip_gamma_series = self.parseq_to_anim_series('perspective_flip_gamma' + optional_delta) + + # Non-motion animation args + self.perspective_flip_fv_series = self.parseq_to_anim_series('perspective_flip_fv') + self.noise_schedule_series = self.parseq_to_anim_series('noise') + self.strength_schedule_series = self.parseq_to_anim_series('strength') + self.sampler_schedule_series = self.parseq_to_anim_series('sampler_schedule') + self.contrast_schedule_series = self.parseq_to_anim_series('contrast') + self.cfg_scale_schedule_series = self.parseq_to_anim_series('scale') + self.steps_schedule_series = self.parseq_to_anim_series("steps_schedule") + self.seed_schedule_series = self.parseq_to_anim_series('seed') + self.fov_series = self.parseq_to_anim_series('fov') + self.near_series = self.parseq_to_anim_series('near') + self.far_series = self.parseq_to_anim_series('far') + self.prompts = self.parseq_to_anim_series('deforum_prompt') # formatted as "{positive} --neg {negative}" + self.subseed_series = self.parseq_to_anim_series('subseed') + self.subseed_strength_series = self.parseq_to_anim_series('subseed_strength') + self.kernel_schedule_series = self.parseq_to_anim_series('antiblur_kernel') + self.sigma_schedule_series = self.parseq_to_anim_series('antiblur_sigma') + self.amount_schedule_series = self.parseq_to_anim_series('antiblur_amount') + self.threshold_schedule_series = self.parseq_to_anim_series('antiblur_threshold') + + # Config: + # TODO this is currently ignored. User must ensure the output FPS set in parseq + # matches the one set in Deforum to avoid unexpected results. + self.config_output_fps = self.parseq_json['options']['output_fps'] + + def get_max(self, seriesName): + return max(self.rendered_frames, key=itemgetter(seriesName))[seriesName] + + def parseq_to_anim_series(self, seriesName): + + # Check if valus is present in first frame of JSON data. If not, assume it's undefined. + # The Parseq contract is that the first frame (at least) must define values for all fields. + try: + if self.rendered_frames[0][seriesName] is not None: + logging.info(f"Found {seriesName} in first frame of Parseq data. Assuming it's defined.") + except KeyError: + return None + + key_frame_series = pd.Series([np.nan for a in range(self.required_frames)]) + + for frame in self.rendered_frames: + frame_idx = frame['frame'] + if frame_idx < self.required_frames: + if not np.isnan(key_frame_series[frame_idx]): + logging.warning(f"Duplicate frame definition {frame_idx} detected for data {seriesName}. 
Latest wins.") + key_frame_series[frame_idx] = frame[seriesName] + + # If the animation will have more frames than Parseq defines, + # duplicate final value to match the required frame count. + while (frame_idx < self.required_frames): + key_frame_series[frame_idx] = operator.itemgetter(-1)(self.rendered_frames)[seriesName] + frame_idx += 1 + + return key_frame_series + + # fallback to anim_args if the series is not defined in the Parseq data + def __getattribute__(inst, name): + try: + definedField = super(ParseqAnimKeys, inst).__getattribute__(name) + except AttributeError: + # No field with this name has been explicitly extracted from the JSON data. + # It must be a new parameter. Let's see if it's in the raw JSON. + + # parseq doesn't use _series, _schedule or _schedule_series suffixes in the + # JSON data - remove them. + strippableSuffixes = ['_series', '_schedule'] + parseqName = name + while any(parseqName.endswith(suffix) for suffix in strippableSuffixes): + for suffix in strippableSuffixes: + if parseqName.endswith(suffix): + parseqName = parseqName[:-len(suffix)] + + # returns None if not defined in Parseq JSON data + definedField = inst.parseq_to_anim_series(parseqName) + if (definedField is not None): + # add the field to the instance so we don't compute it again. + setattr(inst, name, definedField) + + if (definedField is not None): + return definedField + else: + logging.info(f"Data for {name} not defined in Parseq data (looked for: '{parseqName}'). Falling back to standard Deforum values.") + return getattr(inst.default_anim_keys, name) + diff --git a/extensions/deforum/scripts/deforum_helpers/prompt.py b/extensions/deforum/scripts/deforum_helpers/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..f81aa7c281376933c95c854ed2ecc8ae99ad92a7 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/prompt.py @@ -0,0 +1,113 @@ +import re + +def check_is_number(value): + float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$' + return re.match(float_pattern, value) + +def parse_weight(match, frame = 0)->float: + import numexpr + w_raw = match.group("weight") + if w_raw == None: + return 1 + if check_is_number(w_raw): + return float(w_raw) + else: + t = frame + if len(w_raw) < 3: + print('the value inside `-characters cannot represent a math function') + return 1 + return float(numexpr.evaluate(w_raw[1:-1])) + +def split_weighted_subprompts(text, frame = 0): + """ + splits the prompt based on deforum webui implementation, moved from generate.py + """ + math_parser = re.compile(""" + (?P( + `[\S\s]*?`# a math function wrapped in `-characters + )) + """, re.VERBOSE) + + parsed_prompt = re.sub(math_parser, lambda m: str(parse_weight(m, frame)), text) + + negative_prompts = [] + positive_prompts = [] + + prompt_split = parsed_prompt.split("--neg") + if len(prompt_split) > 1: + positive_prompts, negative_prompts = parsed_prompt.split("--neg") #TODO: add --neg to vanilla Deforum for compat + else: + positive_prompts = prompt_split[0] + negative_prompts = "" + + return positive_prompts, negative_prompts + +def interpolate_prompts(animation_prompts, max_frames): + import numpy as np + import pandas as pd + # Get prompts sorted by keyframe + sorted_prompts = sorted(animation_prompts.items(), key=lambda item: int(item[0])) + + # Setup container for interpolated prompts + prompt_series = pd.Series([np.nan for a in range(max_frames)]) + + # For every keyframe prompt except the last + for i in range(0,len(sorted_prompts)-1): + + # Get current and next keyframe + 
current_frame = int(sorted_prompts[i][0]) + next_frame = int(sorted_prompts[i+1][0]) + + # Ensure there's no weird ordering issues or duplication in the animation prompts + # (unlikely because we sort above, and the json parser will strip dupes) + if current_frame>=next_frame: + print(f"WARNING: Sequential prompt keyframes {i}:{current_frame} and {i+1}:{next_frame} are not monotonously increasing; skipping interpolation.") + continue + + # Get current and next keyframes' positive and negative prompts (if any) + current_prompt = sorted_prompts[i][1] + next_prompt = sorted_prompts[i+1][1] + current_positive, current_negative, *_ = current_prompt.split("--neg") + [None] + next_positive, next_negative, *_ = next_prompt.split("--neg") + [None] + + # Calculate how much to shift the weight from current to next prompt at each frame + weight_step = 1/(next_frame-current_frame) + + # Apply weighted prompt interpolation for each frame between current and next keyframe + # using the syntax: prompt1 :weight1 AND prompt1 :weight2 --neg nprompt1 :weight1 AND nprompt1 :weight2 + # (See: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#composable-diffusion ) + for f in range(current_frame,next_frame): + next_weight = weight_step * (f-current_frame) + current_weight = 1 - next_weight + + # We will build the prompt incrementally depending on which prompts are present + prompt_series[f] = '' + + # Cater for the case where neither, either or both current & next have positive prompts: + if current_positive: + prompt_series[f] += f"{current_positive} :{current_weight}" + if current_positive and next_positive: + prompt_series[f] += f" AND " + if next_positive: + prompt_series[f] += f"{next_positive} :{next_weight}" + + # Cater for the case where neither, either or both current & next have negative prompts: + if current_negative or next_negative: + prompt_series[f] += " --neg " + if current_negative: + prompt_series[f] += f" {current_negative} :{current_weight}" + if current_negative and next_negative: + prompt_series[f] += f" AND " + if next_negative: + prompt_series[f] += f" {next_negative} :{next_weight}" + + # Set explicitly declared keyframe prompts (overwriting interpolated values at the keyframe idx). This ensures: + # - That final prompt is set, and + # - Gives us a chance to emit warnings if any keyframe prompts are already using composable diffusion + for i, prompt in animation_prompts.items(): + prompt_series[int(i)] = prompt + if ' AND ' in prompt: + print(f"WARNING: keyframe {i}'s prompt is using composable diffusion (aka the 'AND' keyword). This will cause unexpected behaviour with interpolation.") + + # Return the filled series, in case max_frames is greater than the last keyframe or any ranges were skipped. 
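+    # Illustrative example (added annotation, not part of the original code): with
+    # animation_prompts = {"0": "a cat", "10": "a dog"} and max_frames = 20, frame 5 becomes
+    # "a cat :0.5 AND a dog :0.5", frame 10 is overwritten with the explicit keyframe prompt
+    # "a dog", and frames 11-19 are forward-filled with "a dog" by the ffill()/bfill() below.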
+ return prompt_series.ffill().bfill() diff --git a/extensions/deforum/scripts/deforum_helpers/render.py b/extensions/deforum/scripts/deforum_helpers/render.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3d141f3e00216b530d05c205c5f94f0ad814ab --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/render.py @@ -0,0 +1,507 @@ +import os +import json +import pandas as pd +import cv2 +import numpy as np +from PIL import Image, ImageOps +from .rich import console + +from .generate import generate +from .noise import add_noise +from .animation import sample_from_cv2, sample_to_cv2, anim_frame_warp +from .animation_key_frames import DeformAnimKeys, LooperAnimKeys +from .video_audio_utilities import get_frame_name, get_next_frame +from .depth import DepthModel +from .colors import maintain_colors +from .parseq_adapter import ParseqAnimKeys +from .seed import next_seed +from .blank_frame_reroll import blank_frame_reroll +from .image_sharpening import unsharp_mask +from .load_images import get_mask, load_img, get_mask_from_file +from .hybrid_video import hybrid_generation, hybrid_composite +from .hybrid_video import get_matrix_for_hybrid_motion, get_matrix_for_hybrid_motion_prev, get_flow_for_hybrid_motion, get_flow_for_hybrid_motion_prev, image_transform_ransac, image_transform_optical_flow +from .save_images import save_image +from .composable_masks import compose_mask_with_check +from .settings import get_keys_to_exclude +from .deforum_controlnet import unpack_controlnet_vids, is_controlnet_enabled +# Webui +from modules.shared import opts, cmd_opts, state, sd_model +from modules import lowvram, devices, sd_hijack + +def render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root): + # handle hybrid video generation + if anim_args.animation_mode in ['2D','3D']: + if anim_args.hybrid_composite or anim_args.hybrid_motion in ['Affine', 'Perspective', 'Optical Flow']: + args, anim_args, inputfiles = hybrid_generation(args, anim_args, root) + # path required by hybrid functions, even if hybrid_comp_save_extra_frames is False + hybrid_frame_path = os.path.join(args.outdir, 'hybridframes') + + # handle controlnet video input frames generation + if is_controlnet_enabled(controlnet_args): + unpack_controlnet_vids(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root) + + # use parseq if manifest is provided + use_parseq = parseq_args.parseq_manifest != None and parseq_args.parseq_manifest.strip() + # expand key frame strings to values + keys = DeformAnimKeys(anim_args) if not use_parseq else ParseqAnimKeys(parseq_args, anim_args) + loopSchedulesAndData = LooperAnimKeys(loop_args, anim_args) + # resume animation + start_frame = 0 + if anim_args.resume_from_timestring: + for tmp in os.listdir(args.outdir): + if ".txt" in tmp : + pass + else: + filename = tmp.split("_") + # don't use saved depth maps to count number of frames + if anim_args.resume_timestring in filename and "depth" not in filename: + start_frame += 1 + #start_frame = start_frame - 1 + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to:\n{args.outdir}") + + # save settings for the batch + exclude_keys = get_keys_to_exclude('general') + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + args.__dict__["prompts"] = animation_prompts + s = {} + for d in 
[dict(args.__dict__), dict(anim_args.__dict__), dict(parseq_args.__dict__), dict(loop_args.__dict__)]: + for key, value in d.items(): + if key not in exclude_keys: + s[key] = value + json.dump(s, f, ensure_ascii=False, indent=4) + + # resume from timestring + if anim_args.resume_from_timestring: + args.timestring = anim_args.resume_timestring + + # Always enable pseudo-3d with parseq. No need for an extra toggle: + # Whether it's used or not in practice is defined by the schedules + if use_parseq: + anim_args.flip_2d_perspective = True + + # expand prompts out to per-frame + if use_parseq: + prompt_series = keys.prompts + else: + prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + for i, prompt in animation_prompts.items(): + prompt_series[int(i)] = prompt + prompt_series = prompt_series.ffill().bfill() + + # check for video inits + using_vid_init = anim_args.animation_mode == 'Video Input' + + # load depth model for 3D + predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps + predict_depths = predict_depths or (anim_args.hybrid_composite and anim_args.hybrid_comp_mask_type in ['Depth','Video Depth']) + if predict_depths: + depth_model = DepthModel('cpu' if cmd_opts.lowvram or cmd_opts.medvram else root.device) + depth_model.load_midas(root.models_path, root.half_precision) + if anim_args.midas_weight < 1.0: + depth_model.load_adabins(root.models_path) + # depth-based hybrid composite mask requires saved depth maps + if anim_args.hybrid_composite and anim_args.hybrid_comp_mask_type =='Depth': + anim_args.save_depth_maps = True + else: + depth_model = None + anim_args.save_depth_maps = False + + # state for interpolating between diffusion steps + turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence) + turbo_prev_image, turbo_prev_frame_idx = None, 0 + turbo_next_image, turbo_next_frame_idx = None, 0 + + # resume animation + prev_img = None + color_match_sample = None + if anim_args.resume_from_timestring: + last_frame = start_frame-1 + if turbo_steps > 1: + last_frame -= last_frame%turbo_steps + path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png") + img = cv2.imread(path) + #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Changed the colors on resume + prev_img = img + if anim_args.color_coherence != 'None': + color_match_sample = img + if turbo_steps > 1: + turbo_next_image, turbo_next_frame_idx = prev_img, last_frame + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + start_frame = last_frame+turbo_steps + + args.n_samples = 1 + frame_idx = start_frame + + # reset the mask vals as they are overwritten in the compose_mask algorithm + mask_vals = {} + noise_mask_vals = {} + + mask_vals['everywhere'] = Image.new('1', (args.W, args.H), 1) + noise_mask_vals['everywhere'] = Image.new('1', (args.W, args.H), 1) + + mask_image = None + + if args.use_init and args.init_image != None and args.init_image != '': + _, mask_image = load_img(args.init_image, + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + mask_vals['init_mask'] = mask_image + noise_mask_vals['init_mask'] = mask_image + + # Grab the first frame masks since they wont be provided until next frame + if mask_image is None and args.use_mask: + mask_vals['init_mask'] = get_mask(args) + noise_mask_vals['init_mask'] = get_mask(args) # TODO?: add a different default noise mask + + if anim_args.use_mask_video: + mask_vals['video_mask'] = 
get_mask_from_file(get_next_frame(args.outdir, anim_args.video_mask_path, frame_idx, True), args) + noise_mask_vals['video_mask'] = get_mask_from_file(get_next_frame(args.outdir, anim_args.video_mask_path, frame_idx, True), args) + else: + mask_vals['video_mask'] = None + noise_mask_vals['video_mask'] = None + + #Webui + state.job_count = anim_args.max_frames + + while frame_idx < anim_args.max_frames: + #Webui + state.job = f"frame {frame_idx + 1}/{anim_args.max_frames}" + state.job_no = frame_idx + 1 + if state.interrupted: + break + + print(f"\033[36mAnimation frame: \033[0m{frame_idx}/{anim_args.max_frames} ") + + noise = keys.noise_schedule_series[frame_idx] + strength = keys.strength_schedule_series[frame_idx] + scale = keys.cfg_scale_schedule_series[frame_idx] + contrast = keys.contrast_schedule_series[frame_idx] + kernel = int(keys.kernel_schedule_series[frame_idx]) + sigma = keys.sigma_schedule_series[frame_idx] + amount = keys.amount_schedule_series[frame_idx] + threshold = keys.threshold_schedule_series[frame_idx] + hybrid_comp_schedules = { + "alpha": keys.hybrid_comp_alpha_schedule_series[frame_idx], + "mask_blend_alpha": keys.hybrid_comp_mask_blend_alpha_schedule_series[frame_idx], + "mask_contrast": keys.hybrid_comp_mask_contrast_schedule_series[frame_idx], + "mask_auto_contrast_cutoff_low": int(keys.hybrid_comp_mask_auto_contrast_cutoff_low_schedule_series[frame_idx]), + "mask_auto_contrast_cutoff_high": int(keys.hybrid_comp_mask_auto_contrast_cutoff_high_schedule_series[frame_idx]), + } + scheduled_sampler_name = None + scheduled_clipskip = None + mask_seq = None + noise_mask_seq = None + if anim_args.enable_steps_scheduling and keys.steps_schedule_series[frame_idx] is not None: + args.steps = int(keys.steps_schedule_series[frame_idx]) + if anim_args.enable_sampler_scheduling and keys.sampler_schedule_series[frame_idx] is not None: + scheduled_sampler_name = keys.sampler_schedule_series[frame_idx].casefold() + if anim_args.enable_clipskip_scheduling and keys.clipskip_schedule_series[frame_idx] is not None: + scheduled_clipskip = int(keys.clipskip_schedule_series[frame_idx]) + if args.use_mask and keys.mask_schedule_series[frame_idx] is not None: + mask_seq = keys.mask_schedule_series[frame_idx] + if anim_args.use_noise_mask and keys.noise_mask_schedule_series[frame_idx] is not None: + noise_mask_seq = keys.noise_mask_schedule_series[frame_idx] + + if args.use_mask and not anim_args.use_noise_mask: + noise_mask_seq = mask_seq + + depth = None + + if anim_args.animation_mode == '3D' and (cmd_opts.lowvram or cmd_opts.medvram): + # Unload the main checkpoint and load the depth model + lowvram.send_everything_to_cpu() + sd_hijack.model_hijack.undo_hijack(sd_model) + devices.torch_gc() + depth_model.to(root.device) + + # emit in-between frames + if turbo_steps > 1: + tween_frame_start_idx = max(0, frame_idx-turbo_steps) + for tween_frame_idx in range(tween_frame_start_idx, frame_idx): + tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx) + print(f" Creating in-between frame: {tween_frame_idx}; tween:{tween:0.2f};") + + advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx + advance_next = tween_frame_idx > turbo_next_frame_idx + + if depth_model is not None: + assert(turbo_next_image is not None) + depth = depth_model.predict(turbo_next_image, anim_args, root.half_precision) + + if advance_prev: + turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, 
depth_model, depth=depth, device=root.device, half_precision=root.half_precision) + if advance_next: + turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device, half_precision=root.half_precision) + + # hybrid video motion - warps turbo_prev_image or turbo_next_image to match motion + if tween_frame_idx > 0: + if anim_args.hybrid_motion in ['Affine', 'Perspective']: + if anim_args.hybrid_motion_use_prev_img: + if advance_prev: + matrix = get_matrix_for_hybrid_motion_prev(tween_frame_idx, (args.W, args.H), inputfiles, turbo_prev_image, anim_args.hybrid_motion) + turbo_prev_image = image_transform_ransac(turbo_prev_image, matrix, anim_args.hybrid_motion, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + if advance_next: + matrix = get_matrix_for_hybrid_motion_prev(tween_frame_idx, (args.W, args.H), inputfiles, turbo_next_image, anim_args.hybrid_motion) + turbo_next_image = image_transform_ransac(turbo_next_image, matrix, anim_args.hybrid_motion, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + else: + matrix = get_matrix_for_hybrid_motion(tween_frame_idx-1, (args.W, args.H), inputfiles, anim_args.hybrid_motion) + if advance_prev: + turbo_prev_image = image_transform_ransac(turbo_prev_image, matrix, anim_args.hybrid_motion, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + if advance_next: + turbo_next_image = image_transform_ransac(turbo_next_image, matrix, anim_args.hybrid_motion, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + if anim_args.hybrid_motion in ['Optical Flow']: + if anim_args.hybrid_motion_use_prev_img: + if advance_prev: + flow = get_flow_for_hybrid_motion_prev(tween_frame_idx-1, (args.W, args.H), inputfiles, hybrid_frame_path, turbo_prev_image, anim_args.hybrid_flow_method, anim_args.hybrid_comp_save_extra_frames) + turbo_prev_image = image_transform_optical_flow(turbo_prev_image, flow, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + if advance_next: + flow = get_flow_for_hybrid_motion_prev(tween_frame_idx-1, (args.W, args.H), inputfiles, hybrid_frame_path, turbo_next_image, anim_args.hybrid_flow_method, anim_args.hybrid_comp_save_extra_frames) + turbo_next_image = image_transform_optical_flow(turbo_next_image, flow, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + else: + flow = get_flow_for_hybrid_motion(tween_frame_idx-1, (args.W, args.H), inputfiles, hybrid_frame_path, anim_args.hybrid_flow_method, anim_args.hybrid_comp_save_extra_frames) + if advance_prev: + turbo_prev_image = image_transform_optical_flow(turbo_prev_image, flow, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + if advance_next: + turbo_next_image = image_transform_optical_flow(turbo_next_image, flow, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + + turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx + + if turbo_prev_image is not None and tween < 1.0: + img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween + else: + img = turbo_next_image + + # intercept and override to grayscale + if anim_args.color_force_grayscale: + img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY) + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + filename = f"{args.timestring}_{tween_frame_idx:05}.png" + cv2.imwrite(os.path.join(args.outdir, filename), img) + if anim_args.save_depth_maps: + 
depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth) + if turbo_next_image is not None: + prev_img = turbo_next_image + + # apply transforms to previous frame + if prev_img is not None: + prev_img, depth = anim_frame_warp(prev_img, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device, half_precision=root.half_precision) + + # hybrid video motion - warps prev_img to match motion, usually to prepare for compositing + if frame_idx > 0: + if anim_args.hybrid_motion in ['Affine', 'Perspective']: + if anim_args.hybrid_motion_use_prev_img: + matrix = get_matrix_for_hybrid_motion_prev(frame_idx, (args.W, args.H), inputfiles, prev_img, anim_args.hybrid_motion) + else: + matrix = get_matrix_for_hybrid_motion(frame_idx-1, (args.W, args.H), inputfiles, anim_args.hybrid_motion) + prev_img = image_transform_ransac(prev_img, matrix, anim_args.hybrid_motion, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + if anim_args.hybrid_motion in ['Optical Flow']: + if anim_args.hybrid_motion_use_prev_img: + flow = get_flow_for_hybrid_motion_prev(frame_idx-1, (args.W, args.H), inputfiles, hybrid_frame_path, prev_img, anim_args.hybrid_flow_method, anim_args.hybrid_comp_save_extra_frames) + else: + flow = get_flow_for_hybrid_motion(frame_idx-1, (args.W, args.H), inputfiles, hybrid_frame_path, anim_args.hybrid_flow_method, anim_args.hybrid_comp_save_extra_frames) + prev_img = image_transform_optical_flow(prev_img, flow, cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE) + + # do hybrid video - composites video frame into prev_img (now warped if using motion) + if anim_args.hybrid_composite: + args, prev_img = hybrid_composite(args, anim_args, frame_idx, prev_img, depth_model, hybrid_comp_schedules, root) + + # apply color matching + if anim_args.color_coherence != 'None': + # video color matching + hybrid_available = anim_args.hybrid_composite or anim_args.hybrid_motion in ['Optical Flow', 'Affine', 'Perspective'] + if anim_args.color_coherence == 'Video Input' and hybrid_available: + video_color_coherence_frame = int(frame_idx) % int(anim_args.color_coherence_video_every_N_frames) == 0 + if video_color_coherence_frame: + prev_vid_img = Image.open(os.path.join(args.outdir, 'inputframes', get_frame_name(anim_args.video_init_path) + f"{frame_idx:05}.jpg")) + prev_vid_img = prev_vid_img.resize((args.W, args.H), Image.Resampling.LANCZOS) + color_match_sample = np.asarray(prev_vid_img) + color_match_sample = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2BGR) + if color_match_sample is None: + color_match_sample = prev_img.copy() + else: + prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) + + # intercept and override to grayscale + if anim_args.color_force_grayscale: + prev_img = cv2.cvtColor(prev_img, cv2.COLOR_BGR2GRAY) + prev_img = cv2.cvtColor(prev_img, cv2.COLOR_GRAY2BGR) + + # apply scaling + contrast_image = (prev_img * contrast).round().astype(np.uint8) + # anti-blur + if amount > 0: + contrast_image = unsharp_mask(contrast_image, (kernel, kernel), sigma, amount, threshold, mask_image if args.use_mask else None) + # apply frame noising + if args.use_mask or anim_args.use_noise_mask: + args.noise_mask = compose_mask_with_check(root, args, noise_mask_seq, noise_mask_vals, Image.fromarray(cv2.cvtColor(contrast_image, cv2.COLOR_BGR2RGB))) + noised_image = add_noise(contrast_image, noise, args.seed, anim_args.noise_type, + (anim_args.perlin_w, anim_args.perlin_h, 
anim_args.perlin_octaves, anim_args.perlin_persistence), + args.noise_mask, args.invert_mask) + + # use transformed previous frame as init for current + args.use_init = True + args.init_sample = Image.fromarray(cv2.cvtColor(noised_image, cv2.COLOR_BGR2RGB)) + args.strength = max(0.0, min(1.0, strength)) + + args.scale = scale + + # Pix2Pix Image CFG Scale - does *nothing* with non pix2pix checkpoints + args.pix2pix_img_cfg_scale = float(keys.pix2pix_img_cfg_scale_series[frame_idx]) + + # grab prompt for current frame + args.prompt = prompt_series[frame_idx] + + if args.seed_behavior == 'schedule' or use_parseq: + args.seed = int(keys.seed_schedule_series[frame_idx]) + + if anim_args.enable_checkpoint_scheduling: + args.checkpoint = keys.checkpoint_schedule_series[frame_idx] + else: + args.checkpoint = None + + #SubSeed scheduling + if anim_args.enable_subseed_scheduling: + args.subseed = int(keys.subseed_schedule_series[frame_idx]) + args.subseed_strength = float(keys.subseed_strength_schedule_series[frame_idx]) + + if use_parseq: + args.seed_enable_extras = True + args.subseed = int(keys.subseed_series[frame_idx]) + args.subseed_strength = keys.subseed_strength_series[frame_idx] + + prompt_to_print, *after_neg = args.prompt.strip().split("--neg") + prompt_to_print = prompt_to_print.strip() + after_neg = "".join(after_neg).strip() + + print(f"\033[32mSeed: \033[0m{args.seed}") + print(f"\033[35mPrompt: \033[0m{prompt_to_print}") + if after_neg and after_neg.strip(): + print(f"\033[91mNeg Prompt: \033[0m{after_neg}") + if not using_vid_init: + # print motion table to cli if anim mode = 2D or 3D + if anim_args.animation_mode in ['2D','3D']: + print_render_table(anim_args, keys, frame_idx) + + # grab init image for current frame + elif using_vid_init: + init_frame = get_next_frame(args.outdir, anim_args.video_init_path, frame_idx, False) + print(f"Using video init frame {init_frame}") + args.init_image = init_frame + if anim_args.use_mask_video: + mask_vals['video_mask'] = get_mask_from_file(get_next_frame(args.outdir, anim_args.video_mask_path, frame_idx, True), args) + + if args.use_mask: + args.mask_image = compose_mask_with_check(root, args, mask_seq, mask_vals, args.init_sample) if args.init_sample is not None else None # we need it only after the first frame anyway + + # setting up some arguments for the looper + loop_args.imageStrength = loopSchedulesAndData.image_strength_schedule_series[frame_idx] + loop_args.blendFactorMax = loopSchedulesAndData.blendFactorMax_series[frame_idx] + loop_args.blendFactorSlope = loopSchedulesAndData.blendFactorSlope_series[frame_idx] + loop_args.tweeningFrameSchedule = loopSchedulesAndData.tweening_frames_schedule_series[frame_idx] + loop_args.colorCorrectionFactor = loopSchedulesAndData.color_correction_factor_series[frame_idx] + loop_args.use_looper = loopSchedulesAndData.use_looper + loop_args.imagesToKeyframe = loopSchedulesAndData.imagesToKeyframe + + if scheduled_clipskip is not None: + opts.data["CLIP_stop_at_last_layers"] = scheduled_clipskip + + if anim_args.animation_mode == '3D' and (cmd_opts.lowvram or cmd_opts.medvram): + depth_model.to('cpu') + devices.torch_gc() + lowvram.setup_for_low_vram(sd_model, cmd_opts.medvram) + sd_hijack.model_hijack.hijack(sd_model) + + # sample the diffusion model + image = generate(args, anim_args, loop_args, controlnet_args, root, frame_idx, sampler_name=scheduled_sampler_name) + patience = 10 + + # intercept and override to grayscale + if anim_args.color_force_grayscale: + image = ImageOps.grayscale(image) + 
image = ImageOps.colorize(image, black ="black", white ="white") + + # reroll blank frame + if not image.getbbox(): + print("Blank frame detected! If you don't have the NSFW filter enabled, this may be due to a glitch!") + if args.reroll_blank_frames == 'reroll': + while not image.getbbox(): + print("Rerolling with +1 seed...") + args.seed += 1 + image = generate(args, anim_args, loop_args, controlnet_args, root, frame_idx, sampler_name=scheduled_sampler_name) + patience -= 1 + if patience == 0: + print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...") + state.interrupted = True + state.current_image = image + return + elif args.reroll_blank_frames == 'interrupt': + print("Interrupting to save your eyes...") + state.interrupted = True + state.current_image = image + image = blank_frame_reroll(image, args, root, frame_idx) + if image == None: + return + + opencv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + if not using_vid_init: + prev_img = opencv_image + + if turbo_steps > 1: + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + turbo_next_image, turbo_next_frame_idx = opencv_image, frame_idx + frame_idx += turbo_steps + else: + filename = f"{args.timestring}_{frame_idx:05}.png" + save_image(image, 'PIL', filename, args, video_args, root) + + if anim_args.save_depth_maps: + if cmd_opts.lowvram or cmd_opts.medvram: + lowvram.send_everything_to_cpu() + sd_hijack.model_hijack.undo_hijack(sd_model) + devices.torch_gc() + depth_model.to(root.device) + depth = depth_model.predict(opencv_image, anim_args, root.half_precision) + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth) + if cmd_opts.lowvram or cmd_opts.medvram: + depth_model.to('cpu') + devices.torch_gc() + lowvram.setup_for_low_vram(sd_model, cmd_opts.medvram) + sd_hijack.model_hijack.hijack(sd_model) + frame_idx += 1 + + state.current_image = image + + args.seed = next_seed(args) + +def print_render_table(anim_args, keys, frame_idx): + from rich.table import Table + from rich import box + table = Table(padding=0, box=box.ROUNDED) + field_names = [] + if anim_args.animation_mode == '2D': + short_zoom = round(keys.zoom_series[frame_idx], 6) + field_names += ["Angle", "Zoom"] + field_names += ["Tr X", "Tr Y"] + if anim_args.animation_mode == '3D': + field_names += ["Tr Z", "Ro X", "Ro Y", "Ro Z"] + if anim_args.enable_perspective_flip: + field_names += ["Pf T", "Pf P", "Pf G", "Pf F"] + for field_name in field_names: + table.add_column(field_name, justify="center") + + rows = [] + if anim_args.animation_mode == '2D': + rows += [str(keys.angle_series[frame_idx]),str(short_zoom)] + rows += [str(keys.translation_x_series[frame_idx]),str(keys.translation_y_series[frame_idx])] + if anim_args.animation_mode == '3D': + rows += [str(keys.translation_z_series[frame_idx]),str(keys.rotation_3d_x_series[frame_idx]),str(keys.rotation_3d_y_series[frame_idx]),str(keys.rotation_3d_z_series[frame_idx])] + if anim_args.enable_perspective_flip: + rows +=[str(keys.perspective_flip_theta_series[frame_idx]), str(keys.perspective_flip_phi_series[frame_idx]), str(keys.perspective_flip_gamma_series[frame_idx]), str(keys.perspective_flip_fv_series[frame_idx])] + table.add_row(*rows) + + console.print(table) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/render_modes.py b/extensions/deforum/scripts/deforum_helpers/render_modes.py new 
file mode 100644
index 0000000000000000000000000000000000000000..3b53601ffdef213d1c54a182631347193cf21949
--- /dev/null
+++ b/extensions/deforum/scripts/deforum_helpers/render_modes.py
@@ -0,0 +1,154 @@
+import os
+import pathlib
+import json
+from .render import render_animation
+from .seed import next_seed
+from .video_audio_utilities import vid2frames
+from .prompt import interpolate_prompts
+from .generate import generate
+from .animation_key_frames import DeformAnimKeys
+from .parseq_adapter import ParseqAnimKeys
+from .save_images import save_image
+from .settings import get_keys_to_exclude
+
+# Webui
+from modules.shared import opts, cmd_opts, state
+
+def render_input_video(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root):
+    # create a folder for the video input frames to live in
+    video_in_frame_path = os.path.join(args.outdir, 'inputframes')
+    os.makedirs(video_in_frame_path, exist_ok=True)
+
+    # save the video frames from the input video
+    print(f"Exporting Video Frames (1 of every {anim_args.extract_nth_frame}) to {video_in_frame_path}...")
+    vid2frames(video_path=anim_args.video_init_path, video_in_frame_path=video_in_frame_path, n=anim_args.extract_nth_frame, overwrite=anim_args.overwrite_extracted_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame)
+
+    # determine max frames from the number of extracted input frames
+    anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')])
+    args.use_init = True
+    print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}")
+
+    if anim_args.use_mask_video:
+        # create a folder for the mask video input frames to live in
+        mask_in_frame_path = os.path.join(args.outdir, 'maskframes')
+        os.makedirs(mask_in_frame_path, exist_ok=True)
+
+        # save the video frames from the mask video
+        print(f"Exporting Video Frames (1 of every {anim_args.extract_nth_frame}) to {mask_in_frame_path}...")
+        vid2frames(video_path=anim_args.video_mask_path, video_in_frame_path=mask_in_frame_path, n=anim_args.extract_nth_frame, overwrite=anim_args.overwrite_extracted_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame)
+        max_mask_frames = len([f for f in pathlib.Path(mask_in_frame_path).glob('*.jpg')])
+
+        # limit max frames if the video mask has fewer frames than the input video
+        if max_mask_frames < anim_args.max_frames:
+            anim_args.max_frames = max_mask_frames
+            print("Video mask contains fewer frames than the init video; limiting max frames to the number of mask frames.")
+        args.use_mask = True
+        args.overlay_mask = True
+
+    render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root)
+
+# Modified copy of the above that allows using a masking video without an init video.
+def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root): + # create a folder for the video input frames to live in + mask_in_frame_path = os.path.join(args.outdir, 'maskframes') + os.makedirs(mask_in_frame_path, exist_ok=True) + + # save the video frames from mask video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...") + vid2frames(video_path=anim_args.video_mask_path, video_in_frame_path=mask_in_frame_path, n=anim_args.extract_nth_frame, overwrite=anim_args.overwrite_extracted_frames, extract_from_frame=anim_args.extract_from_frame, extract_to_frame=anim_args.extract_to_frame) + args.use_mask = True + #args.overlay_mask = True + + # determine max frames from length of input frames + anim_args.max_frames = len([f for f in pathlib.Path(mask_in_frame_path).glob('*.jpg')]) + #args.use_init = True + print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}") + + render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root) + + +def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, animation_prompts, root): + + # use parseq if manifest is provided + use_parseq = parseq_args.parseq_manifest != None and parseq_args.parseq_manifest.strip() + + # expand key frame strings to values + keys = DeformAnimKeys(anim_args) if not use_parseq else ParseqAnimKeys(parseq_args, anim_args) + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving interpolation animation frames to {args.outdir}") + + # save settings for the batch + exclude_keys = get_keys_to_exclude('general') + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {} + for d in [dict(args.__dict__), dict(anim_args.__dict__), dict(parseq_args.__dict__)]: + for key, value in d.items(): + if key not in exclude_keys: + s[key] = value + json.dump(s, f, ensure_ascii=False, indent=4) + + # Compute interpolated prompts + if use_parseq: + print("Parseq prompts are assumed to already be interpolated - not doing any additional prompt interpolation") + prompt_series = keys.prompts + else: + print("Generating interpolated prompts for all frames") + prompt_series = interpolate_prompts(animation_prompts, anim_args.max_frames) + + state.job_count = anim_args.max_frames + frame_idx = 0 + # INTERPOLATION MODE + while frame_idx < anim_args.max_frames: + # print data to cli + prompt_to_print = prompt_series[frame_idx].strip() + if prompt_to_print.endswith("--neg"): + prompt_to_print = prompt_to_print[:-5] + print(f"\033[36mInterpolation frame: \033[0m{frame_idx}/{anim_args.max_frames} ") + print(f"\033[32mSeed: \033[0m{args.seed}") + print(f"\033[35mPrompt: \033[0m{prompt_to_print}") + + state.job = f"frame {frame_idx + 1}/{anim_args.max_frames}" + state.job_no = frame_idx + 1 + + if state.interrupted: + break + + # grab inputs for current frame generation + args.n_samples = 1 + args.prompt = prompt_series[frame_idx] + args.scale = keys.cfg_scale_schedule_series[frame_idx] + args.pix2pix_img_cfg_scale = keys.pix2pix_img_cfg_scale_series[frame_idx] + + if anim_args.enable_checkpoint_scheduling: + args.checkpoint = keys.checkpoint_schedule_series[frame_idx] + print(f"Checkpoint changed to: {args.checkpoint}") + else: + args.checkpoint = None + + if 
anim_args.enable_subseed_scheduling: + args.subseed = keys.subseed_schedule_series[frame_idx] + args.subseed_strength = keys.subseed_strength_schedule_series[frame_idx] + + if use_parseq: + anim_args.enable_subseed_scheduling = True + args.subseed = int(keys.subseed_series[frame_idx]) + args.subseed_strength = keys.subseed_strength_series[frame_idx] + + if args.seed_behavior == 'schedule' or use_parseq: + args.seed = int(keys.seed_schedule_series[frame_idx]) + + image = generate(args, anim_args, loop_args, controlnet_args, root, frame_idx) + filename = f"{args.timestring}_{frame_idx:05}.png" + + save_image(image, 'PIL', filename, args, video_args, root) + + state.current_image = image + + if args.seed_behavior != 'schedule': + args.seed = next_seed(args) + + frame_idx += 1 \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/rich.py b/extensions/deforum/scripts/deforum_helpers/rich.py new file mode 100644 index 0000000000000000000000000000000000000000..745d8c8bc41116fe1ead73e18569c075a03450e1 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/rich.py @@ -0,0 +1,2 @@ +from rich.console import Console +console = Console() \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/save_images.py b/extensions/deforum/scripts/deforum_helpers/save_images.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6c60c5bfec5947b0a9bf7f9fb87512e97e5ad6 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/save_images.py @@ -0,0 +1,80 @@ +from typing import List, Tuple +from einops import rearrange +import numpy as np, os, torch +from PIL import Image +from torchvision.utils import make_grid +import time + + +def get_output_folder(output_path, batch_folder): + out_path = os.path.join(output_path,time.strftime('%Y-%m')) + if batch_folder != "": + out_path = os.path.join(out_path, batch_folder) + os.makedirs(out_path, exist_ok=True) + return out_path + + +def save_samples( + args, x_samples: torch.Tensor, seed: int, n_rows: int +) -> Tuple[Image.Image, List[Image.Image]]: + """Function to save samples to disk. + Args: + args: Stable deforum diffusion arguments. + x_samples: Samples to save. + seed: Seed for the experiment. + n_rows: Number of rows in the grid. + Returns: + A tuple of the grid image and a list of the generated images. 
+ ( grid_image, generated_images ) + """ + + # save samples + images = [] + grid_image = None + if args.display_samples or args.save_samples: + for index, x_sample in enumerate(x_samples): + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + images.append(Image.fromarray(x_sample.astype(np.uint8))) + if args.save_samples: + images[-1].save( + os.path.join( + args.outdir, f"{args.timestring}_{index:02}_{seed}.png" + ) + ) + + # save grid + if args.display_grid or args.save_grid: + grid = torch.stack([x_samples], 0) + grid = rearrange(grid, "n b c h w -> (n b) c h w") + grid = make_grid(grid, nrow=n_rows, padding=0) + + # to image + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + grid_image = Image.fromarray(grid.astype(np.uint8)) + if args.save_grid: + grid_image.save( + os.path.join(args.outdir, f"{args.timestring}_{seed}_grid.png") + ) + + # return grid_image and individual sample images + return grid_image, images + +def save_image(image, image_type, filename, args, video_args, root): + if video_args.store_frames_in_ram: + root.frames_cache.append({'path':os.path.join(args.outdir, filename), 'image':image, 'image_type':image_type}) + else: + image.save(os.path.join(args.outdir, filename)) + +import cv2, gc + +def reset_frames_cache(root): + root.frames_cache = [] + gc.collect() + +def dump_frames_cache(root): + for image_cache in root.frames_cache: + if image_cache['image_type'] == 'cv2': + cv2.imwrite(image_cache['path'], image_cache['image']) + elif image_cache['image_type'] == 'PIL': + image_cache['image'].save(image_cache['path']) + # do not reset the cache since we're going to add frame erasing later function #TODO diff --git a/extensions/deforum/scripts/deforum_helpers/seed.py b/extensions/deforum/scripts/deforum_helpers/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..5c379d709bebdc16d654e872c4870e769a31e5ac --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/seed.py @@ -0,0 +1,26 @@ +import random + +def next_seed(args): + if args.seed_behavior == 'iter': + if args.seed_internal % args.seed_iter_N == 0: + args.seed += 1 + args.seed_internal += 1 + elif args.seed_behavior == 'ladder': + if args.seed_internal == 0: + args.seed += 2 + args.seed_internal = 1 + else: + args.seed -= 1 + args.seed_internal = 0 + elif args.seed_behavior == 'alternate': + if args.seed_internal == 0: + args.seed += 1 + args.seed_internal = 1 + else: + args.seed -= 1 + args.seed_internal = 0 + elif args.seed_behavior == 'fixed': + pass # always keep seed the same + else: + args.seed = random.randint(0, 2**32 - 1) + return args.seed \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/settings.py b/extensions/deforum/scripts/deforum_helpers/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..8aed504edb3246e60d4508bab6c92c3512f3ee80 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/settings.py @@ -0,0 +1,272 @@ +from math import ceil +import os +import json +import deforum_helpers.args as deforum_args +from .args import mask_fill_choices, DeforumArgs, DeforumAnimArgs +from .deprecation_utils import handle_deprecated_settings +import logging + +def get_keys_to_exclude(setting_type): + if setting_type == 'general': + return ["n_batch", "restore_faces", "seed_enable_extras", "save_samples", "display_samples", "show_sample_per_step", "filename_format", "from_img2img_instead_of_link", "scale", "subseed", "subseed_strength", "C", "f", "init_latent", "init_sample", 
"init_c", "noise_mask", "seed_internal"] + else: #video + return ["mp4_path", "image_path", "output_format","render_steps","path_name_modifier"] + +def load_args(args_dict,anim_args_dict, parseq_args_dict, loop_args_dict, controlnet_args_dict, custom_settings_file, root): + print(f"reading custom settings from {custom_settings_file}") + if not os.path.isfile(custom_settings_file): + print('The custom settings file does not exist. The in-notebook settings will be used instead') + else: + with open(custom_settings_file, "r") as f: + jdata = json.loads(f.read()) + handle_deprecated_settings(jdata) + root.animation_prompts = jdata["prompts"] + if "animation_prompts_positive" in jdata: + root.animation_prompts_positive = jdata["animation_prompts_positive"] + if "animation_prompts_negative" in jdata: + root.animation_prompts_negative = jdata["animation_prompts_negative"] + for i, k in enumerate(args_dict): + if k in jdata: + args_dict[k] = jdata[k] + else: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {args_dict[k]}") + for i, k in enumerate(anim_args_dict): + if k in jdata: + anim_args_dict[k] = jdata[k] + else: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {anim_args_dict[k]}") + for i, k in enumerate(parseq_args_dict): + if k in jdata: + parseq_args_dict[k] = jdata[k] + else: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {parseq_args_dict[k]}") + for i, k in enumerate(loop_args_dict): + if k in jdata: + loop_args_dict[k] = jdata[k] + else: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {loop_args_dict[k]}") + print(args_dict) + print(anim_args_dict) + print(parseq_args_dict) + print(loop_args_dict) + +# In gradio gui settings save/ load funs: +def save_settings(*args, **kwargs): + settings_path = args[0].strip() + data = {deforum_args.settings_component_names[i]: args[i+1] for i in range(0, len(deforum_args.settings_component_names))} + from deforum_helpers.args import pack_args, pack_anim_args, pack_parseq_args, pack_loop_args, pack_controlnet_args + args_dict = pack_args(data) + anim_args_dict = pack_anim_args(data) + parseq_dict = pack_parseq_args(data) + args_dict["prompts"] = json.loads(data['animation_prompts']) + args_dict["animation_prompts_positive"] = data['animation_prompts_positive'] + args_dict["animation_prompts_negative"] = data['animation_prompts_negative'] + loop_dict = pack_loop_args(data) + controlnet_dict = pack_controlnet_args(data) + + combined = {**args_dict, **anim_args_dict, **parseq_dict, **loop_dict, **controlnet_dict} + exclude_keys = get_keys_to_exclude('general') + ['controlnet_input_video_chosen_file', 'controlnet_input_video_mask_chosen_file'] + filtered_combined = {k: v for k, v in combined.items() if k not in exclude_keys} + + print(f"saving custom settings to {settings_path}") + with open(settings_path, "w") as f: + f.write(json.dumps(filtered_combined, ensure_ascii=False, indent=4)) + + return [""] + +def save_video_settings(*args, **kwargs): + video_settings_path = args[0].strip() + data = {deforum_args.video_args_names[i]: args[i+1] for i in range(0, len(deforum_args.video_args_names))} + from deforum_helpers.args import pack_video_args + video_args_dict = pack_video_args(data) + exclude_keys = get_keys_to_exclude('video') + filtered_data = video_args_dict if exclude_keys is None else {k: v for k, v in video_args_dict.items() if k not in exclude_keys} + print(f"saving video settings to 
{video_settings_path}") + with open(video_settings_path, "w") as f: + f.write(json.dumps(filtered_data, ensure_ascii=False, indent=4)) + return [""] + +def load_settings(*args, **kwargs): + settings_path = args[0].strip() + data = {deforum_args.settings_component_names[i]: args[i+1] for i in range(0, len(deforum_args.settings_component_names))} + print(f"reading custom settings from {settings_path}") + jdata = {} + if not os.path.isfile(settings_path): + print('The custom settings file does not exist. The values will be unchanged.') + return [data[name] for name in deforum_args.settings_component_names] + [""] + else: + with open(settings_path, "r") as f: + jdata = json.loads(f.read()) + handle_deprecated_settings(jdata) + ret = [] + if 'animation_prompts' in jdata: + jdata['prompts'] = jdata['animation_prompts']#compatibility with old versions + if 'animation_prompts_positive' in jdata: + data["animation_prompts_positive"] = jdata['animation_prompts_positive'] + if 'animation_prompts_negative' in jdata: + data["animation_prompts_negative"] = jdata['animation_prompts_negative'] + for key in data: + if key == 'sampler': + sampler_val = jdata[key] + if type(sampler_val) == int: + from modules.sd_samplers import samplers_for_img2img + ret.append(samplers_for_img2img[sampler_val].name) + else: + ret.append(sampler_val) + + elif key == 'fill': + if key in jdata: + fill_val = jdata[key] + if type(fill_val) == int: + ret.append(mask_fill_choices[fill_val]) + else: + ret.append(fill_val) + else: + fill_default = DeforumArgs()['fill'] + logging.debug(f"Fill not found in load file, using default value: {fill_default}") + ret.append(mask_fill_choices[fill_default]) + + elif key == 'reroll_blank_frames': + if key in jdata: + reroll_blank_frames_val = jdata[key] + ret.append(reroll_blank_frames_val) + else: + reroll_blank_frames_default = DeforumArgs()['reroll_blank_frames'] + logging.debug(f"Reroll blank frames not found in load file, using default value: {reroll_blank_frames_default}") + ret.append(reroll_blank_frames_default) + + elif key == 'noise_type': + if key in jdata: + noise_type_val = jdata[key] + ret.append(noise_type_val) + else: + noise_type_default = DeforumAnimArgs()['noise_type'] + logging.debug(f"Noise type not found in load file, using default value: {noise_type_default}") + ret.append(noise_type_default) + + elif key in jdata: + ret.append(jdata[key]) + else: + if key == 'animation_prompts': + ret.append(json.dumps(jdata['prompts'], ensure_ascii=False, indent=4)) + elif key == 'animation_prompts_positive' and 'animation_prompts_positive' in jdata: + ret.append(jdata['animation_prompts_positive']) + elif key == 'animation_prompts_negative' and 'animation_prompts_negative' in jdata: + ret.append(jdata['animation_prompts_negative']) + else: + ret.append(data[key]) + + #stuff + ret.append("") + + return ret + +def load_video_settings(*args, **kwargs): + video_settings_path = args[0].strip() + data = {deforum_args.video_args_names[i]: args[i+1] for i in range(0, len(deforum_args.video_args_names))} + print(f"reading custom video settings from {video_settings_path}") + jdata = {} + if not os.path.isfile(video_settings_path): + print('The custom video settings file does not exist. 
The values will be unchanged.') + return [data[name] for name in deforum_args.video_args_names] + [""] + else: + with open(video_settings_path, "r") as f: + jdata = json.loads(f.read()) + handle_deprecated_settings(jdata) + ret = [] + + for key in data: + if key == 'add_soundtrack': + add_soundtrack_val = jdata[key] + if type(add_soundtrack_val) == bool: + ret.append('File' if add_soundtrack_val else 'None') + else: + ret.append(add_soundtrack_val) + elif key in jdata: + ret.append(jdata[key]) + else: + ret.append(data[key]) + + #stuff + ret.append("") + + return ret + +import tqdm +from modules.shared import state, progress_print_out, opts, cmd_opts +class DeforumTQDM: + def __init__(self, args, anim_args, parseq_args): + self._tqdm = None + self._args = args + self._anim_args = anim_args + self._parseq_args = parseq_args + + def reset(self): + from .animation_key_frames import DeformAnimKeys + from .parseq_adapter import ParseqAnimKeys + deforum_total = 0 + # FIXME: get only amount of steps + use_parseq = self._parseq_args.parseq_manifest != None and self._parseq_args.parseq_manifest.strip() + keys = DeformAnimKeys(self._anim_args) if not use_parseq else ParseqAnimKeys(self._parseq_args, self._anim_args) + + start_frame = 0 + if self._anim_args.resume_from_timestring: + for tmp in os.listdir(self._args.outdir): + filename = tmp.split("_") + # don't use saved depth maps to count number of frames + if self._anim_args.resume_timestring in filename and "depth" not in filename: + start_frame += 1 + start_frame = start_frame - 1 + using_vid_init = self._anim_args.animation_mode == 'Video Input' + turbo_steps = 1 if using_vid_init else int(self._anim_args.diffusion_cadence) + if self._anim_args.resume_from_timestring: + last_frame = start_frame-1 + if turbo_steps > 1: + last_frame -= last_frame%turbo_steps + if turbo_steps > 1: + turbo_next_frame_idx = last_frame + turbo_prev_frame_idx = turbo_next_frame_idx + start_frame = last_frame+turbo_steps + frame_idx = start_frame + had_first = False + while frame_idx < self._anim_args.max_frames: + strength = keys.strength_schedule_series[frame_idx] + if not had_first and self._args.use_init and self._args.init_image != None and self._args.init_image != '': + deforum_total += int(ceil(self._args.steps * (1-strength))) + had_first = True + elif not had_first: + deforum_total += self._args.steps + had_first = True + else: + deforum_total += int(ceil(self._args.steps * (1-strength))) + + if turbo_steps > 1: + frame_idx += turbo_steps + else: + frame_idx += 1 + + self._tqdm = tqdm.tqdm( + desc="Deforum progress", + total=deforum_total, + position=1, + file=progress_print_out + ) + + def update(self): + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.update() + + def updateTotal(self, new_total): + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.total=new_total + + def clear(self): + if self._tqdm is not None: + self._tqdm.close() + self._tqdm = None diff --git a/extensions/deforum/scripts/deforum_helpers/src/adabins/__init__.py b/extensions/deforum/scripts/deforum_helpers/src/adabins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2a0eea190658f294d0a49363ea28543087bdf6 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/adabins/__init__.py @@ -0,0 +1 @@ +from .unet_adaptive_bins import UnetAdaptiveBins diff --git 
a/extensions/deforum/scripts/deforum_helpers/src/adabins/layers.py b/extensions/deforum/scripts/deforum_helpers/src/adabins/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..499cd8cc1ec5973da5718d184d36b187869f9c28 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/adabins/layers.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4): + super(PatchTransformerEncoder, self).__init__() + encoder_layers = nn.TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + self.positional_encodings = nn.Parameter(torch.rand(500, embedding_dim), requires_grad=True) + + def forward(self, x): + embeddings = self.embedding_convPxP(x).flatten(2) # .shape = n,c,s = n, embedding_dim, s + # embeddings = nn.functional.pad(embeddings, (1,0)) # extra special token at start ? + embeddings = embeddings + self.positional_encodings[:embeddings.shape[2], :].T.unsqueeze(0) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x + + +class PixelWiseDotProduct(nn.Module): + def __init__(self): + super(PixelWiseDotProduct, self).__init__() + + def forward(self, x, K): + n, c, h, w = x.size() + _, cout, ck = K.size() + assert c == ck, "Number of channels in x and Embedding dimension (at dim 2) of K matrix must match" + y = torch.matmul(x.view(n, c, h * w).permute(0, 2, 1), K.permute(0, 2, 1)) # .shape = n, hw, cout + return y.permute(0, 2, 1).view(n, cout, h, w) diff --git a/extensions/deforum/scripts/deforum_helpers/src/adabins/miniViT.py b/extensions/deforum/scripts/deforum_helpers/src/adabins/miniViT.py new file mode 100644 index 0000000000000000000000000000000000000000..8a619734aaa82e73fbe37800a6a1dd12e83020a2 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/adabins/miniViT.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn + +from .layers import PatchTransformerEncoder, PixelWiseDotProduct + + +class mViT(nn.Module): + def __init__(self, in_channels, n_query_channels=128, patch_size=16, dim_out=256, + embedding_dim=128, num_heads=4, norm='linear'): + super(mViT, self).__init__() + self.norm = norm + self.n_query_channels = n_query_channels + self.patch_transformer = PatchTransformerEncoder(in_channels, patch_size, embedding_dim, num_heads) + self.dot_product_layer = PixelWiseDotProduct() + + self.conv3x3 = nn.Conv2d(in_channels, embedding_dim, kernel_size=3, stride=1, padding=1) + self.regressor = nn.Sequential(nn.Linear(embedding_dim, 256), + nn.LeakyReLU(), + nn.Linear(256, 256), + nn.LeakyReLU(), + nn.Linear(256, dim_out)) + + def forward(self, x): + # n, c, h, w = x.size() + tgt = self.patch_transformer(x.clone()) # .shape = S, N, E + + x = self.conv3x3(x) + + regression_head, queries = tgt[0, ...], tgt[1:self.n_query_channels + 1, ...] 
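+        # Added annotation (not part of the original file): the first output token of the patch
+        # transformer acts as a global "regression head" used to predict the adaptive bin widths,
+        # while the next n_query_channels tokens serve as queries for the pixel-wise dot product below.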
+ + # Change from S, N, E to N, S, E + queries = queries.permute(1, 0, 2) + range_attention_maps = self.dot_product_layer(x, queries) # .shape = n, n_query_channels, h, w + + y = self.regressor(regression_head) # .shape = N, dim_out + if self.norm == 'linear': + y = torch.relu(y) + eps = 0.1 + y = y + eps + elif self.norm == 'softmax': + return torch.softmax(y, dim=1), range_attention_maps + else: + y = torch.sigmoid(y) + y = y / y.sum(dim=1, keepdim=True) + return y, range_attention_maps diff --git a/extensions/deforum/scripts/deforum_helpers/src/adabins/unet_adaptive_bins.py b/extensions/deforum/scripts/deforum_helpers/src/adabins/unet_adaptive_bins.py new file mode 100644 index 0000000000000000000000000000000000000000..733927795146fe13563d07d20fbb461da596a181 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/adabins/unet_adaptive_bins.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +from pathlib import Path + +from .miniViT import mViT + + +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleBN, self).__init__() + + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +class DecoderBN(nn.Module): + def __init__(self, num_features=2048, num_classes=1, bottleneck_features=2048): + super(DecoderBN, self).__init__() + features = int(num_features) + + self.conv2 = nn.Conv2d(bottleneck_features, features, kernel_size=1, stride=1, padding=1) + + self.up1 = UpSampleBN(skip_input=features // 1 + 112 + 64, output_features=features // 2) + self.up2 = UpSampleBN(skip_input=features // 2 + 40 + 24, output_features=features // 4) + self.up3 = UpSampleBN(skip_input=features // 4 + 24 + 16, output_features=features // 8) + self.up4 = UpSampleBN(skip_input=features // 8 + 16 + 8, output_features=features // 16) + + # self.up5 = UpSample(skip_input=features // 16 + 3, output_features=features//16) + self.conv3 = nn.Conv2d(features // 16, num_classes, kernel_size=3, stride=1, padding=1) + # self.act_out = nn.Softmax(dim=1) if output_activation == 'softmax' else nn.Identity() + + def forward(self, features): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[ + 11] + + x_d0 = self.conv2(x_block4) + + x_d1 = self.up1(x_d0, x_block3) + x_d2 = self.up2(x_d1, x_block2) + x_d3 = self.up3(x_d2, x_block1) + x_d4 = self.up4(x_d3, x_block0) + # x_d5 = self.up5(x_d4, features[0]) + out = self.conv3(x_d4) + # out = self.act_out(out) + # if with_features: + # return out, features[-1] + # elif with_intermediate: + # return out, [x_block0, x_block1, x_block2, x_block3, x_block4, x_d1, x_d2, x_d3, x_d4] + return out + + +class Encoder(nn.Module): + def __init__(self, backend): + super(Encoder, self).__init__() + self.original_model = backend + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + 
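+# Added annotation (not part of the original file): UnetAdaptiveBins wires the EfficientNet-based
+# Encoder, the DecoderBN above and the mViT head together. mViT returns normalised bin widths
+# (rescaled to [min_val, max_val] and cumulatively summed into bin edges) plus range attention
+# maps, which conv_out converts into a per-pixel softmax over bins; the predicted depth is the
+# probability-weighted sum of the bin centers.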
+class UnetAdaptiveBins(nn.Module): + def __init__(self, backend, n_bins=100, min_val=0.1, max_val=10, norm='linear'): + super(UnetAdaptiveBins, self).__init__() + self.num_classes = n_bins + self.min_val = min_val + self.max_val = max_val + self.encoder = Encoder(backend) + self.adaptive_bins_layer = mViT(128, n_query_channels=128, patch_size=16, + dim_out=n_bins, + embedding_dim=128, norm=norm) + + self.decoder = DecoderBN(num_classes=128) + self.conv_out = nn.Sequential(nn.Conv2d(128, n_bins, kernel_size=1, stride=1, padding=0), + nn.Softmax(dim=1)) + + def forward(self, x, **kwargs): + unet_out = self.decoder(self.encoder(x), **kwargs) + bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(unet_out) + out = self.conv_out(range_attention_maps) + + # Post process + # n, c, h, w = out.shape + # hist = torch.sum(out.view(n, c, h * w), dim=2) / (h * w) # not used for training + + bin_widths = (self.max_val - self.min_val) * bin_widths_normed # .shape = N, dim_out + bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_val) + bin_edges = torch.cumsum(bin_widths, dim=1) + + centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:]) + n, dout = centers.size() + centers = centers.view(n, dout, 1, 1) + + pred = torch.sum(out * centers, dim=1, keepdim=True) + + return bin_edges, pred + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + modules = [self.decoder, self.adaptive_bins_layer, self.conv_out] + for m in modules: + yield from m.parameters() + + @classmethod + def build(cls, n_bins, **kwargs): + basemodel_name = 'tf_efficientnet_b5_ap' + + print('Loading base model ()...'.format(basemodel_name), end='') + predicted_torch_model_cache_path = str(Path.home()) + '\\.cache\\torch\\hub\\rwightman_gen-efficientnet-pytorch_master' + predicted_gep_cache_testilfe = Path(predicted_torch_model_cache_path + '\\hubconf.py') + #print(f"predicted_gep_cache_testilfe: {predicted_gep_cache_testilfe}") + # try to fetch the models from cache, and only if it can't be find, download from the internet (to enable offline usage) + if os.path.isfile(predicted_gep_cache_testilfe): + basemodel = torch.hub.load(predicted_torch_model_cache_path, basemodel_name, pretrained=True, source = 'local') + else: + basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True) + print('Done.') + + # Remove last layer + print('Removing last two layers (global_pool & classifier).') + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + # Building Encoder-Decoder model + print('Building Encoder-Decoder model..', end='') + m = cls(basemodel, n_bins=n_bins, **kwargs) + print('Done.') + return m + + +if __name__ == '__main__': + model = UnetAdaptiveBins.build(100) + x = torch.rand(2, 3, 480, 640) + bins, pred = model(x) + print(bins.shape, pred.shape) diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/LICENSE b/extensions/deforum/scripts/deforum_helpers/src/clipseg/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..033c8f5ac0e93624d1890d03581d940b9fa850ac --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/LICENSE @@ -0,0 +1,21 @@ +MIT License + +This license does not apply to the model weights. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/Quickstart.ipynb b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Quickstart.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a06690b4606fa533c1cb2429610c983c989570b2 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Quickstart.ipynb @@ -0,0 +1,107 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import requests\n", + "\n", + "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n", + "! unzip -d weights -j weights.zip\n", + "from models.clipseg import CLIPDensePredT\n", + "from PIL import Image\n", + "from torchvision import transforms\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# load model\n", + "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n", + "model.eval();\n", + "\n", + "# non-strict, because we only stored decoder weights (not CLIP weights)\n", + "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load and normalize `example_image.jpg`. You can also load through an URL." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load and normalize image\n", + "input_image = Image.open('example_image.jpg')\n", + "\n", + "# or load from URL...\n", + "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n", + "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", + " transforms.Resize((352, 352)),\n", + "])\n", + "img = transform(input_image).unsqueeze(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Predict and visualize (this might take a few seconds if running without GPU support)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n", + "\n", + "# predict\n", + "with torch.no_grad():\n", + " preds = model(img.repeat(4,1,1,1), prompts)[0]\n", + "\n", + "# visualize prediction\n", + "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n", + "[a.axis('off') for a in ax.flatten()]\n", + "ax[0].imshow(input_image)\n", + "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n", + "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/Readme.md b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Readme.md new file mode 100644 index 0000000000000000000000000000000000000000..070220ae51af60a601657dfc6c8093ad3192ee28 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Readme.md @@ -0,0 +1,84 @@ +# Image Segmentation Using Text and Image Prompts +This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003). + +**The Paper has been accepted to CVPR 2022!** + +drawing + +The systems allows to create segmentation models without training based on: +- An arbitrary text query +- Or an image with a mask highlighting stuff or an object. + +### Quick Start + +In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension. +It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb) +(please note that the VM does not use a GPU, thus inference takes a few seconds). + + +### Dependencies +This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`). +Additional dependencies are hidden for double blind review. 
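+
+With these dependencies installed, the following is a minimal text-prompt inference sketch, condensed from `Quickstart.ipynb`. It assumes the `rd64-uni.pth` weights (see the Weights section below) have been extracted to `weights/`:
+
+```python
+import torch
+from PIL import Image
+from torchvision import transforms
+from models.clipseg import CLIPDensePredT
+
+# load the model; non-strict, because only decoder weights are stored (not CLIP weights)
+model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
+model.eval()
+model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False)
+
+# load and normalize an image
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    transforms.Resize((352, 352)),
+])
+img = transform(Image.open('example_image.jpg')).unsqueeze(0)
+
+# one forward pass per prompt (the image is repeated to match the number of prompts)
+prompts = ['a glass', 'wood']
+with torch.no_grad():
+    preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
+masks = torch.sigmoid(preds)  # one segmentation heatmap per prompt
+```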
+
+
+### Datasets
+
+* `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
+* `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
+* `PascalZeroShot`: Wrapper class for PascalZeroShot
+* `COCOWrapper`: Wrapper class for COCO.
+
+### Models
+
+* `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
+* `ViTDensePredT`: ViTSeg baseline with transformer-based decoder (ImageNet-pretrained ViT backbone instead of CLIP's visual encoder).
+
+### Third Party Dependencies
+For some of the datasets, third-party dependencies are required. Run the following commands in the `third_party` folder.
+```bash
+git clone https://github.com/cvlab-yonsei/JoEm
+git clone https://github.com/Jia-Research-Lab/PFENet.git
+git clone https://github.com/ChenyunWu/PhraseCutDataset.git
+git clone https://github.com/juhongm999/hsnet.git
+```
+
+### Weights
+
+The MIT license does not apply to these weights.
+
+We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
+```
+wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
+unzip -d weights -j weights.zip
+```
+
+
+### Training and Evaluation
+
+To train, use the `training.py` script with an experiment file and an experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by `configuration` and the first entry of `individual_configurations`. Model weights are written to `logs/`.
+
+For evaluation, use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment using `test_configuration` and the first configuration in `individual_configurations`.
+
+
+### Usage of PFENet Wrappers
+
+In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned into the root folder:
+`git clone https://github.com/Jia-Research-Lab/PFENet.git`
+
+
+### License
+
+The source code files in this repository (excluding model weights) are released under the MIT license.
+ +### Citation +``` +@InProceedings{lueddecke22_cvpr, + author = {L\"uddecke, Timo and Ecker, Alexander}, + title = {Image Segmentation Using Text and Image Prompts}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2022}, + pages = {7086-7096} +} + +``` diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/Tables.ipynb b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Tables.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..354f6b9e9a8dabd5f202f9cce4d5d4ca7759e8f9 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Tables.ipynb @@ -0,0 +1,349 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import clip\n", + "from evaluation_utils import norm, denorm\n", + "from general_utils import *\n", + "from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PhraseCut" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n", + "tab1 = pc[['name'] + cols]\n", + "for k in cols:\n", + " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", + "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", + "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n", + "print(tab1.to_latex(header=False, index=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For 0.1 threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n", + "tab1 = pc[['name'] + cols]\n", + "for k in cols:\n", + " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", + "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n", + "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n", + "print(tab1.to_latex(header=False, index=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# One-shot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pascal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", + "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n", + "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in 
tab1[0:4].mean(0).values), '\\\\\\\\')\n", + "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", + "\n", + "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", + "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n", + "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", + "\n", + "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", + "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", + "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Pascal Zero-shot (in one-shot setting)\n", + "\n", + "Using the same setting as one-shot (hence different from the other zero-shot benchmark)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", + "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", + "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", + "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", + "\n", + "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", + "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n", + "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", + "\n", + "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n", + "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n", + "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# without fixed thresholds...\n", + "\n", + "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n", + "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", + "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n", + "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", + "\n", + "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n", + "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n", + "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### COCO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n", + "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n", + "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n", + "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), 
'\\\\\\\\')\n", + "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n", + "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n", + "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Zero-shot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n", + "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n", + "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n", + "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ablation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n", + "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n", + " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n", + "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generalization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generalization = experiment('experiments/generalize.yaml').dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n", + " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n", + " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n", + " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n", + ")" + ] + } + ], + "metadata": { + "interpreter": { + "hash": 
"800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586" + }, + "kernelspec": { + "display_name": "env2", + "language": "python", + "name": "env2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/Visual_Feature_Engineering.ipynb b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Visual_Feature_Engineering.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e6acdacff75fd48011f2f8940f99c5124bee88cd --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/Visual_Feature_Engineering.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Systematic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import clip\n", + "from evaluation_utils import norm, denorm\n", + "from general_utils import *\n", + "from datasets.lvis_oneshot3 import LVIS_OneShot3\n", + "\n", + "clip_device = 'cuda'\n", + "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n", + "clip_model.eval();\n", + "\n", + "from models.clipseg import CLIPDensePredTMasked\n", + "\n", + "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n", + "clip_mask_model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n", + " text_class_labels=True, image_size=352, min_area=0.1,\n", + " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_data(lvis)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import json\n", + "\n", + "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n", + "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n", + "\n", + "objects_per_image = defaultdict(lambda : set())\n", + "for ann in lvis_raw['annotations']:\n", + " objects_per_image[ann['image_id']].add(ann['category_id'])\n", + " \n", + "for ann in lvis_val_raw['annotations']:\n", + " objects_per_image[ann['image_id']].add(ann['category_id']) \n", + " \n", + "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n", + "\n", + "del lvis_raw, lvis_val_raw" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#bs = 32\n", + "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from general_utils import get_batch\n", + "from functools import partial\n", + "from evaluation_utils import img_preprocess\n", + "import torch\n", + "\n", + "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n", + "\n", + " # base_words = [f'a photo of {x}' for x in ['a 
person', 'an animal', 'a knife', 'a cup']]\n", + "\n", + " all_prompts = []\n", + " \n", + " with torch.no_grad():\n", + " valid_sims = []\n", + " torch.manual_seed(571)\n", + " \n", + " if type(batches_or_dataset) == list:\n", + " loader = batches_or_dataset # already loaded\n", + " max_iter = float('inf')\n", + " else:\n", + " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n", + " max_iter = 50\n", + " \n", + " global batch\n", + " for i_batch, (batch, batch_y) in enumerate(loader):\n", + " \n", + " if i_batch >= max_iter: break\n", + " \n", + " processed_batch = process(batch)\n", + " if type(processed_batch) == dict:\n", + " \n", + " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n", + " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n", + " else:\n", + " processed_batch = process(batch).to(clip_device)\n", + " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n", + " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n", + " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n", + " \n", + " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", + " bs = len(batch[0])\n", + " for j in range(bs):\n", + " \n", + " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n", + " support_image = basename(lvis.samples[c][sid])\n", + " \n", + " img_objs = [o for o in objects_per_image[int(support_image)]]\n", + " img_objs = [o.replace('_', ' ') for o in img_objs]\n", + " \n", + " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n", + " if o != batch_y[2][j]]\n", + " \n", + " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n", + " all_prompts += [prompts]\n", + " \n", + " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n", + " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n", + "\n", + " global logits\n", + " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n", + "\n", + " global sim\n", + " sim = torch.softmax(logits, dim=-1)\n", + " \n", + " valid_sims += [sim]\n", + " \n", + " #valid_sims = torch.stack(valid_sims)\n", + " return valid_sims, all_prompts\n", + " \n", + "\n", + "def new_img_preprocess(x):\n", + " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n", + " \n", + "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n", + "get_similarities(lvis, lambda x: x[1]);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preprocessing_functions = [\n", + "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n", + "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n", + "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n", + "# ['colorize object red', partial(img_preprocess, colorize=True)],\n", + "# ['add red outline', partial(img_preprocess, outline=True)],\n", + " \n", + "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n", + "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n", + "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n", + "# ['BG blur', partial(img_preprocess, blur=3)],\n", + "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", + " \n", + "# ['crop large context', 
partial(img_preprocess, center_context=0.5)],\n", + "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n", + " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n", + " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n", + "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n", + "]\n", + "\n", + "preprocessing_functions = preprocessing_functions\n", + "\n", + "base, base_p = get_similarities(lvis, lambda x: x[1])\n", + "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for j in range(1):\n", + " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pandas import DataFrame\n", + "tab = dict()\n", + "for j, (name, _) in enumerate(preprocessing_functions):\n", + " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n", + " \n", + " \n", + "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visual" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from evaluation_utils import denorm, norm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_sample(filename, filename2):\n", + " from os.path import join\n", + " bp = expanduser('~/cloud/resources/sample_images')\n", + " tf = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", + " transforms.Resize(224),\n", + " transforms.CenterCrop(224)\n", + " ])\n", + " tf2 = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Resize(224),\n", + " transforms.CenterCrop(224)\n", + " ])\n", + " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n", + " inp1[1] = inp1[1].unsqueeze(0)\n", + " inp1[2] = inp1[2][:1] \n", + " return inp1\n", + "\n", + "def all_preprocessing(inp1):\n", + " return [\n", + " img_preprocess(inp1),\n", + " img_preprocess(inp1, colorize=True),\n", + " img_preprocess(inp1, outline=True), \n", + " img_preprocess(inp1, blur=3),\n", + " img_preprocess(inp1, bg_fac=0.1),\n", + " #img_preprocess(inp1, bg_fac=0.5),\n", + " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n", + " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n", + " ]\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torchvision import transforms\n", + "from PIL import Image\n", + "from matplotlib import pyplot as plt\n", + "from evaluation_utils import img_preprocess\n", + "import clip\n", + "\n", + "images_queries = [\n", + " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n", + " [load_sample('own_photos/IMG_2017s_square.jpg', 
'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n", + "]\n", + "\n", + "\n", + "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n", + "\n", + "for j, (images, objects) in enumerate(images_queries):\n", + " \n", + " joint_image = all_preprocessing(images)\n", + " \n", + " joint_image = torch.stack(joint_image)[:,0]\n", + " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n", + " image_features = clip_model.encode_image(joint_image)\n", + " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", + " \n", + " prompts = [f'a photo of a {obj}'for obj in objects]\n", + " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n", + " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n", + " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n", + " sim = torch.softmax(logits, dim=-1).detach().cpu()\n", + "\n", + " for i, img in enumerate(joint_image):\n", + " ax[2*j, i].axis('off')\n", + " \n", + " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n", + " ax[2*j+ 1, i].grid(True)\n", + " \n", + " ax[2*j + 1, i].set_ylim(0,1)\n", + " ax[2*j + 1, i].set_yticklabels([])\n", + " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n", + "# ax[1, i].set_xticklabels(objects, rotation=90)\n", + " for k in range(len(sim[i])):\n", + " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n", + " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env2", + "language": "python", + "name": "env2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/coco_wrapper.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/coco_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6844d63ef849ffa19e37c7a49f8cc2a590dc92a5 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/coco_wrapper.py @@ -0,0 +1,99 @@ +import pickle +from types import new_class +import torch +import numpy as np +import os +import json + +from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename +from random import shuffle, seed as set_seed +from PIL import Image + +from itertools import combinations +from torchvision import transforms +from torchvision.transforms.transforms import Resize + +from datasets.utils import blend_image_segmentation +from general_utils import get_from_repository + +COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball 
bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} + +class COCOWrapper(object): + + def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0, + with_class_label=False): + super().__init__() + + self.mask = mask + self.with_class_label = with_class_label + self.negative_prob = negative_prob + + from third_party.hsnet.data.coco import DatasetCOCO + + get_from_repository('COCO-20i', ['COCO-20i.tar']) + + foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl') + + def build_img_metadata_classwise(self): + with open(foldpath % (self.split, self.fold), 'rb') as f: + img_metadata_classwise = pickle.load(f) + return img_metadata_classwise + + + DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise + # DatasetCOCO.read_mask = read_mask + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean, std) + ]) + + self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False) + + self.all_classes = [self.coco.class_ids] + self.coco.base_path = join(expanduser('~/datasets/COCO-20i')) + + def __len__(self): + return len(self.coco) + + def __getitem__(self, i): + sample = self.coco[i] + + label_name = COCO_CLASSES[int(sample['class_id'])] + + img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0] + + if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob: + new_class_id = sample['class_id'] + while new_class_id == sample['class_id']: + sample2 = self.coco[torch.randint(0, len(self), (1,)).item()] + new_class_id = sample2['class_id'] + img_s = sample2['support_imgs'][0] + seg_s = torch.zeros_like(seg_s) + + mask = self.mask + if mask == 'separate': + supp = (img_s, seg_s) + elif mask == 'text_label': + # DEPRECATED + supp = [int(sample['class_id'])] + elif mask == 'text': + supp = [label_name] + else: + if mask.startswith('text_and_'): + mask = mask[9:] + label_add = [label_name] + else: + label_add = [] + + supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask) + + if self.with_class_label: + label = (torch.zeros(0), sample['class_id'],) + else: + label = (torch.zeros(0), ) + + return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pascal_classes.json b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pascal_classes.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8ad2b9ff453a88af7d50c412fd291ec6567644 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pascal_classes.json @@ -0,0 +1 @@ +[{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": 
["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}] \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pascal_zeroshot.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pascal_zeroshot.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cfd20f6b63a815776278cc0a4780a2e9b263c2 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pascal_zeroshot.py @@ -0,0 +1,60 @@ +from os.path import expanduser +import torch +import json +import torchvision +from general_utils import get_from_repository +from general_utils import log +from torchvision import transforms + +PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'], + ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'], + ['chair.n.01', 'pot_plant.n.01']] + + +class PascalZeroShot(object): + + def __init__(self, split, n_unseen, image_size=224) -> None: + super().__init__() + + import sys + sys.path.append('third_party/JoEm') + from third_party.JoEm.data_loader.dataset import VOCSegmentation + from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC + + self.pascal_classes = VOC + self.image_size = image_size + + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + ]) + + if split == 'train': + self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), + split=split, transform=True, transform_args=dict(base_size=312, crop_size=312), + ignore_bg=False, ignore_unseen=False, remv_unseen_img=True) + elif split == 'val': + self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen), + split=split, transform=False, + ignore_bg=False, ignore_unseen=False) + + self.unseen_idx = get_unseen_idx(n_unseen) + + def __len__(self): + return len(self.voc) + + def __getitem__(self, i): + + sample = self.voc[i] + label = sample['label'].long() + all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255] + class_indices = [l for l in all_labels] + class_names = [self.pascal_classes[l] for l in all_labels] + + image = self.transform(sample['image']) + + label = transforms.Resize((self.image_size, self.image_size), + interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0] + + return (image,), (label, ) + + diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pfe_dataset.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pfe_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab6833fea3cdbbe997228e540ea16486c899cbc --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/pfe_dataset.py @@ -0,0 +1,129 @@ +from os.path import expanduser +import torch +import json +from general_utils import get_from_repository +from datasets.lvis_oneshot3 import blend_image_segmentation +from general_utils import 
log + +PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))} + + +class PFEPascalWrapper(object): + + def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None): + import sys + # sys.path.append(expanduser('~/projects/new_one_shot')) + from third_party.PFENet.util.dataset import SemData + + get_from_repository('PascalVOC2012', ['Pascal5i.tar']) + + self.p_negative = p_negative + self.size = size + self.mode = mode + self.image_size = image_size + + if label_support in {True, False}: + log.warning('label_support argument is deprecated. Use mask instead.') + #raise ValueError() + + self.mask = mask + + value_scale = 255 + mean = [0.485, 0.456, 0.406] + mean = [item * value_scale for item in mean] + std = [0.229, 0.224, 0.225] + std = [item * value_scale for item in std] + + import third_party.PFENet.util.transform as transform + + if mode == 'val': + data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt') + + data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else [] + data_transform += [ + transform.ToTensor(), + transform.Normalize(mean=mean, std=std) + ] + + + elif mode == 'train': + data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt') + + assert image_size != 'original' + + data_transform = [ + transform.RandScale([0.9, 1.1]), + transform.RandRotate([-10, 10], padding=mean, ignore_label=255), + transform.RandomGaussianBlur(), + transform.RandomHorizontalFlip(), + transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255), + transform.ToTensor(), + transform.Normalize(mean=mean, std=std) + ] + + data_transform = transform.Compose(data_transform) + + self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'), + data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False) + + self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list + + # verify that subcls_list always has length 1 + # assert len(set([len(d[4]) for d in self.dataset])) == 1 + + print('actual length', len(self.dataset.data_list)) + + def __len__(self): + if self.mode == 'val': + return len(self.dataset.data_list) + else: + return len(self.dataset.data_list) + + def __getitem__(self, index): + if self.dataset.mode == 'train': + image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)] + elif self.dataset.mode == 'val': + image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)] + ori_label = torch.from_numpy(ori_label).unsqueeze(0) + + if self.image_size != 'original': + longerside = max(ori_label.size(1), ori_label.size(2)) + backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255 + backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label + label = backmask.clone().long() + else: + label = label.unsqueeze(0) + + # assert label.shape == (473, 473) + + if self.p_negative > 0: + if torch.rand(1).item() < self.p_negative: + while True: + idx = torch.randint(0, len(self.dataset.data_list), (1,)).item() + _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx] + if subcls_list[0] != subcls_list_tmp[0]: + break + + s_x = s_x[0] + s_y = (s_y == 1)[0] + label_fg = (label == 1).float() + val_mask = (label != 255).float() + + class_id = self.class_list[subcls_list[0]] + + label_name = 
PASCAL_CLASSES[class_id][0] + label_add = () + mask = self.mask + + if mask == 'text': + support = ('a photo of a ' + label_name + '.',) + elif mask == 'separate': + support = (s_x, s_y) + else: + if mask.startswith('text_and_'): + label_add = (label_name,) + mask = mask[9:] + + support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],) + + return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0]) diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/phrasecut.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/phrasecut.py new file mode 100644 index 0000000000000000000000000000000000000000..a88cf324b8df132667110be2b03eed1433c2671e --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/phrasecut.py @@ -0,0 +1,335 @@ + +import torch +import numpy as np +import os + +from os.path import join, isdir, isfile, expanduser +from PIL import Image + +from torchvision import transforms +from torchvision.transforms.transforms import Resize + +from torch.nn import functional as nnf +from general_utils import get_from_repository + +from skimage.draw import polygon2mask + + + +def random_crop_slices(origin_size, target_size): + """Gets slices of a random crop. """ + assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}' + + offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item() # range: 0 <= value < high + offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item() + + return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1]) + + +def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None): + + + best_crops = [] + best_crop_not_ok = float('-inf'), None, None + min_sum = 0 + + seg = seg.astype('bool') + + if min_frac is not None: + #min_sum = seg.sum() * min_frac + min_sum = seg.shape[0] * seg.shape[1] * min_frac + + for iteration in range(iterations): + sl_y, sl_x = random_crop_slices(seg.shape, image_size) + seg_ = seg[sl_y, sl_x] + sum_seg_ = seg_.sum() + + if sum_seg_ > min_sum: + + if best_of is None: + return sl_y, sl_x, False + else: + best_crops += [(sum_seg_, sl_y, sl_x)] + if len(best_crops) >= best_of: + best_crops.sort(key=lambda x:x[0], reverse=True) + sl_y, sl_x = best_crops[0][1:] + + return sl_y, sl_x, False + + else: + if sum_seg_ > best_crop_not_ok[0]: + best_crop_not_ok = sum_seg_, sl_y, sl_x + + else: + # return best segmentation found + return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,) + + +class PhraseCut(object): + + def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True, + min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None): + super().__init__() + + self.negative_prob = negative_prob + self.image_size = image_size + self.with_visual = with_visual + self.only_visual = only_visual + self.phrase_form = '{}' + self.mask = mask + self.aug_crop = aug_crop + + if aug_color: + self.aug_color = transforms.Compose([ + transforms.ColorJitter(0.5, 0.5, 0.2, 0.05), + ]) + else: + self.aug_color = None + + get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([ + isdir(join(local_dir, 'VGPhraseCut_v0')), + isdir(join(local_dir, 'VGPhraseCut_v0', 'images')), + isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')), + len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) 
in {108250, 108249} + ])) + + from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader + self.refvg_loader = RefVGLoader(split=split) + + # img_ids where the size in the annotations does not match actual size + invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530, + 150333, 286065, 285814, 498187, 285761, 498042]) + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + self.normalize = transforms.Normalize(mean, std) + + self.sample_ids = [(i, j) + for i in self.refvg_loader.img_ids + for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases'])) + if i not in invalid_img_ids] + + + # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']])) + + from nltk.stem import WordNetLemmatizer + wnl = WordNetLemmatizer() + + # Filter by class (if remove_classes is set) + if remove_classes is None: + pass + else: + from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo + from nltk.corpus import wordnet + + print('remove pascal classes...') + + get_data = self.refvg_loader.get_img_ref_data # shortcut + keep_sids = None + + if remove_classes[0] == 'pas5i': + subset_id = remove_classes[1] + from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS + avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]] + + + elif remove_classes[0] == 'zs': + stop = remove_classes[1] + + from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS + + avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set] + print(avoid) + + elif remove_classes[0] == 'aff': + # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02'] + # all_lemmas = set(['drink', 'sit', 'ride']) + avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting', + 'ride', 'rides', 'riding', + 'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven', + 'swim', 'swims', 'swimming', + 'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears'] + keep_sids = [(i, j) for i, j in self.sample_ids if + all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))] + + print('avoid classes:', avoid) + + + if keep_sids is None: + all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)] + all_lemmas = list(set(all_lemmas)) + all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas] + all_lemmas = set(all_lemmas) + + # divide into multi word and single word + all_lemmas_s = set(l for l in all_lemmas if ' ' not in l) + all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s) + + # new3 + phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids] + remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases) + if any(l in phrase for l in all_lemmas_m) or + len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0 + ) + keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids] + + print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}') + removed_ids = set(self.sample_ids) - set(keep_sids) + + print('Examples of removed', len(removed_ids)) + for i, j in list(removed_ids)[:20]: + print(i, get_data(i)['phrases'][j]) + + self.sample_ids = keep_sids + + from itertools import groupby + samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j)) + for i, j in self.sample_ids] + samples_by_phrase = sorted(samples_by_phrase) + 
samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0]) + + self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase} + + self.all_phrases = list(set(self.samples_by_phrase.keys())) + + + if self.only_visual: + assert self.with_visual + self.sample_ids = [(i, j) for i, j in self.sample_ids + if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1] + + # Filter by size (if min_size is set) + sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids] + image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids] + #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes] + self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)] + + if min_size: + print('filter by size') + + self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size] + + self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/')) + + def __len__(self): + return len(self.sample_ids) + + + def load_sample(self, sample_i, j): + + img_ref_data = self.refvg_loader.get_img_ref_data(sample_i) + + polys_phrase0 = img_ref_data['gt_Polygons'][j] + phrase = img_ref_data['phrases'][j] + phrase = self.phrase_form.format(phrase) + + masks = [] + for polys in polys_phrase0: + for poly in polys: + poly = [p[::-1] for p in poly] # swap x,y + masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)] + + seg = np.stack(masks).max(0) + img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg'))) + + min_shape = min(img.shape[:2]) + + if self.aug_crop: + sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05) + else: + sly, slx = slice(0, None), slice(0, None) + + seg = seg[sly, slx] + img = img[sly, slx] + + seg = seg.astype('uint8') + seg = torch.from_numpy(seg).view(1, 1, *seg.shape) + + if img.ndim == 2: + img = np.dstack([img] * 3) + + img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() + + seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0] + img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0] + + # img = img.permute([2,0, 1]) + img = img / 255.0 + + if self.aug_color is not None: + img = self.aug_color(img) + + img = self.normalize(img) + + + + return img, seg, phrase + + def __getitem__(self, i): + + sample_i, j = self.sample_ids[i] + + img, seg, phrase = self.load_sample(sample_i, j) + + if self.negative_prob > 0: + if torch.rand((1,)).item() < self.negative_prob: + + new_phrase = None + while new_phrase is None or new_phrase == phrase: + idx = torch.randint(0, len(self.all_phrases), (1,)).item() + new_phrase = self.all_phrases[idx] + phrase = new_phrase + seg = torch.zeros_like(seg) + + if self.with_visual: + # find a corresponding visual image + if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1: + idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item() + other_sample = self.samples_by_phrase[phrase][idx] + #print(other_sample) + img_s, seg_s, _ = self.load_sample(*other_sample) + + from datasets.utils import blend_image_segmentation + + if self.mask in {'separate', 'text_and_separate'}: + # assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:] + 
add_phrase = [phrase] if self.mask == 'text_and_separate' else [] + vis_s = add_phrase + [img_s, seg_s, True] + else: + if self.mask.startswith('text_and_'): + mask_mode = self.mask[9:] + label_add = [phrase] + else: + mask_mode = self.mask + label_add = [] + + masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0]) + vis_s = label_add + [masked_img_s, True] + + else: + # phrase is unique + vis_s = torch.zeros_like(img) + + if self.mask in {'separate', 'text_and_separate'}: + add_phrase = [phrase] if self.mask == 'text_and_separate' else [] + vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False] + elif self.mask.startswith('text_and_'): + vis_s = [phrase, vis_s, False] + else: + vis_s = [vis_s, False] + else: + assert self.mask == 'text' + vis_s = [phrase] + + seg = seg.unsqueeze(0).float() + + data_x = (img,) + tuple(vis_s) + + return data_x, (seg, torch.zeros(0), i) + + +class PhraseCutPlus(PhraseCut): + + def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None): + super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size, + remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/utils.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07b17449991c6ebe58776b64982c75ed30ca6c82 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/datasets/utils.py @@ -0,0 +1,68 @@ + +import numpy as np +import torch + + +def blend_image_segmentation(img, seg, mode, image_size=224): + + + if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}: + if isinstance(img, np.ndarray): + img = torch.from_numpy(img) + + if isinstance(seg, np.ndarray): + seg = torch.from_numpy(seg) + + if mode == 'overlay': + out = img * seg + out = [out.astype('float32')] + elif mode == 'highlight': + out = img * seg[None, :, :] * 0.85 + 0.15 * img + out = [out.astype('float32')] + elif mode == 'highlight2': + img = img / 2 + out = (img+0.1) * seg[None, :, :] + 0.3 * img + out = [out.astype('float32')] + elif mode == 'blur_highlight': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01] + elif mode == 'blur3_highlight': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01] + elif mode == 'blur3_highlight01': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01] + elif mode == 'blur_highlight_random': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01] + elif mode == 'crop': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()] + elif mode == 'crop_blur_highlight': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, 
image_size=image_size)[0].numpy()] + elif mode == 'crop_blur_highlight352': + from evaluation_utils import img_preprocess + out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()] + elif mode == 'shape': + out = [np.stack([seg[:, :]]*3).astype('float32')] + elif mode == 'concat': + out = [np.concatenate([img, seg[None, :, :]]).astype('float32')] + elif mode == 'image_only': + out = [img.astype('float32')] + elif mode == 'image_black': + out = [img.astype('float32')*0] + elif mode is None: + out = [img.astype('float32')] + elif mode == 'separate': + out = [img.astype('float32'), seg.astype('int64')] + elif mode == 'separate_img_black': + out = [img.astype('float32')*0, seg.astype('int64')] + elif mode == 'separate_seg_ones': + out = [img.astype('float32'), np.ones_like(seg).astype('int64')] + elif mode == 'separate_both_black': + out = [img.astype('float32')*0, seg.astype('int64')*0] + else: + raise ValueError(f'invalid mode: {mode}') + + return out \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/environment.yml b/extensions/deforum/scripts/deforum_helpers/src/clipseg/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..40bb99f79ee06ae85c5f7934d74997de295e95a3 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/environment.yml @@ -0,0 +1,15 @@ +name: clipseg-environment +channels: + - conda-forge + - pytorch +dependencies: + - numpy + - scipy + - matplotlib-base + - pip + - pip: + - --find-links https://download.pytorch.org/whl/torch_stable.html + - torch==1.10.0+cpu + - torchvision==0.11.1+cpu + - opencv-python + - git+https://github.com/openai/CLIP.git diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/evaluation_utils.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/evaluation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13e482f081106fce43305376a35b8c4b11676cd7 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/evaluation_utils.py @@ -0,0 +1,292 @@ +from torch.functional import Tensor +from general_utils import load_model +from torch.utils.data import DataLoader +import torch +import numpy as np + +def denorm(img): + + np_input = False + if isinstance(img, np.ndarray): + img = torch.from_numpy(img) + np_input = True + + mean = torch.Tensor([0.485, 0.456, 0.406]) + std = torch.Tensor([0.229, 0.224, 0.225]) + + img_denorm = (img*std[:,None,None]) + mean[:,None,None] + + if np_input: + img_denorm = np.clip(img_denorm.numpy(), 0, 1) + else: + img_denorm = torch.clamp(img_denorm, 0, 1) + + return img_denorm + + +def norm(img): + mean = torch.Tensor([0.485, 0.456, 0.406]) + std = torch.Tensor([0.229, 0.224, 0.225]) + return (img - mean[:,None,None]) / std[:,None,None] + + +def fast_iou_curve(p, g): + + g = g[p.sort().indices] + p = torch.sigmoid(p.sort().values) + + scores = [] + vals = np.linspace(0, 1, 50) + + for q in vals: + + n = int(len(g) * q) + + valid = torch.where(p > q)[0] + if len(valid) > 0: + n = int(valid[0]) + else: + n = len(g) + + fn = g[:n].sum() + tn = n - fn + tp = g[n:].sum() + fp = len(g) - n - tp + + iou = tp / (tp + fn + fp) + + precision = tp / (tp + fp) + recall = tp / (tp + fn) + + scores += [iou] + + return vals, scores + + +def fast_rp_curve(p, g): + + g = g[p.sort().indices] + p = torch.sigmoid(p.sort().values) + + precisions, recalls = [], [] + vals = np.linspace(p.min(), p.max(), 250) + + for q in p[::100000]: + + n = 
int(len(g) * q) + + valid = torch.where(p > q)[0] + if len(valid) > 0: + n = int(valid[0]) + else: + n = len(g) + + fn = g[:n].sum() + tn = n - fn + tp = g[n:].sum() + fp = len(g) - n - tp + + iou = tp / (tp + fn + fp) + + precision = tp / (tp + fp) + recall = tp / (tp + fn) + + precisions += [precision] + recalls += [recall] + + return recalls, precisions + + +# Image processing + +def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2, + brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224): + import cv2 + + rw = rect_width + + out = [] + for img, mask in zip(batch[1], batch[2]): + + img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img) + mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask) + + img *= brightness + img_bl = img + if blur > 0: # best 5 + img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1) + + if grayscale: + img_bl = img_bl[1][None] + + #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl + # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask) + img_inp = img*mask + (bg_fac) * img_bl * (1-mask) + + if rect: + _, bbox = crop_mask(img, mask, context=0.1) + img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None] + img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None] + img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] + img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None] + + + if center_context is not None: + img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size) + + if colorize: + img_gray = denorm(img) + img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY) + img_gray = torch.stack([torch.from_numpy(img_gray)]*3) + img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask) + img_inp = norm(img_inp) + + if outline: + cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + outline_img = np.zeros(mask.shape, dtype=np.uint8) + cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255)) + outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255. 
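+ # Draw the contour in red: outline_img is a 3 x H x W mask in [0, 1], so outline pixels take
+ # the pure red color while all other pixels keep the de-normalized image, which is then
+ # re-normalized below.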
+ img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img) + img_inp = norm(img_inp) + + out += [img_inp] + + return torch.stack(out) + + +def object_crop(img, mask, context=0.0, square=False, image_size=224): + img_crop, bbox = crop_mask(img, mask, context=context, square=square) + img_crop = pad_to_square(img_crop, channel_dim=0) + img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0) + return img_crop + + +def crop_mask(img, mask, context=0.0, square=False): + + assert img.shape[1:] == mask.shape + + bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()] + bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()] + bbox = [int(x) for x in bbox] + + width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) + + # square mask + if square: + bbox[0] = int(max(0, bbox[0] - context * height)) + bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) + bbox[2] = int(max(0, bbox[2] - context * width)) + bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) + + width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) + if height > width: + bbox[2] = int(max(0, (bbox[2] - 0.5*height))) + bbox[3] = bbox[2] + height + else: + bbox[0] = int(max(0, (bbox[0] - 0.5*width))) + bbox[1] = bbox[0] + width + else: + bbox[0] = int(max(0, bbox[0] - context * height)) + bbox[1] = int(min(mask.size(0), bbox[1] + context * height)) + bbox[2] = int(max(0, bbox[2] - context * width)) + bbox[3] = int(min(mask.size(1), bbox[3] + context * width)) + + width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0]) + img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]] + return img_crop, bbox + + +def pad_to_square(img, channel_dim=2, fill=0): + """ + + + add padding such that a squared image is returned """ + + from torchvision.transforms.functional import pad + + if channel_dim == 2: + img = img.permute(2, 0, 1) + elif channel_dim == 0: + pass + else: + raise ValueError('invalid channel_dim') + + h, w = img.shape[1:] + pady1 = pady2 = padx1 = padx2 = 0 + + if h > w: + padx1 = (h - w) // 2 + padx2 = h - w - padx1 + elif w > h: + pady1 = (w - h) // 2 + pady2 = w - h - pady1 + + img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant') + + if channel_dim == 2: + img_padded = img_padded.permute(1, 2, 0) + + return img_padded + + +# qualitative + +def split_sentence(inp, limit=9): + t_new, current_len = [], 0 + for k, t in enumerate(inp.split(' ')): + current_len += len(t) + 1 + t_new += [t+' '] + # not last + if current_len > limit and k != len(inp.split(' ')) - 1: + current_len = 0 + t_new += ['\n'] + + t_new = ''.join(t_new) + return t_new + + +from matplotlib import pyplot as plt + + +def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None): + + row_off = 0 if labels is None else 1 + _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2))) + [a.axis('off') for a in ax.flatten()] + + if labels is not None: + for j in range(len(labels)): + t_new = split_sentence(labels[j], limit=6) + ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale) + + + for i in range(len(imgs)): + ax[i + row_off,0].imshow(imgs[i]) + for j in range(len(preds)): + img = preds[j][i][0].detach().cpu().numpy() + + if gt_labels is not None and labels[j] == gt_labels[i]: + print(j, labels[j], gt_labels[i]) + edgecolor = 'red' + if aps is not 
None: + ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8) + else: + edgecolor = 'k' + + rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none", + edgecolor=edgecolor, linewidth=3) + ax[i + row_off,1 + j].add_patch(rect) + + if vmax is None: + this_vmax = 1 + elif vmax == 'per_prompt': + this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))]) + elif vmax == 'per_image': + this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))]) + + ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap) + + + # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max()) + plt.tight_layout() + plt.subplots_adjust(wspace=0.05, hspace=0.05) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/example_image.jpg b/extensions/deforum/scripts/deforum_helpers/src/clipseg/example_image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7e47c8223599929bc45f0d64f22ca581b3442dae --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/example_image.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bffef4c76aa9ad56fa072c3b61f1733ea60a204994c5eae2262621e8c8edd686 +size 91493 diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/ablation.yaml b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/ablation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..133c7f096095da6a085618b418e1b692ca9a1a2c --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/ablation.yaml @@ -0,0 +1,84 @@ +configuration: + batch_size: 64 + optimizer: torch.optim.AdamW + + lr: 0.001 + + trainer: experiment_setup.train_loop + scorer: experiment_setup.score + model: models.clipseg.CLIPDensePredT + + lr_scheduler: cosine + T_max: 20000 + eta_min: 0.0001 + + max_iterations: 20000 # <-########################################## + val_interval: null + + # dataset + dataset: datasets.phrasecut.PhraseCut # <----------------- + split_mode: pascal_test + split: train + mask: text_and_crop_blur_highlight352 + image_size: 352 + negative_prob: 0.2 + mix_text_max: 0.5 + + # general + mix: True # <----------------- + prompt: shuffle+ + norm_cond: True + mix_text_min: 0.0 + with_visual: True + + # model + version: 'ViT-B/16' + extract_layers: [3, 7, 9] + reduce_dim: 64 + depth: 3 + fix_shift: False # <-########################################## + + loss: torch.nn.functional.binary_cross_entropy_with_logits + amp: True + +test_configuration_common: + normalize: True + image_size: 352 + batch_size: 32 + sigmoid: True + split: test + label_support: True + +test_configuration: + + - + name: pc + metric: metrics.FixedIntervalMetrics + test_dataset: phrasecut + mask: text + + - + name: pc-vis + metric: metrics.FixedIntervalMetrics + test_dataset: phrasecut + mask: crop_blur_highlight352 + with_visual: True + visual_only: True + + +columns: [name, +pc_fgiou_best, pc_miou_best, pc_fgiou_0.5, +pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5, +duration] + + +individual_configurations: + +- {name: rd64-uni} +- {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003} +- {name: rd64-no-negatives, negative_prob: 0.0} +- {name: rd64-neg0.5, negative_prob: 0.5} +- {name: rd64-no-visual, with_visual: False, mix: False} +- {name: rd16-uni, reduce_dim: 16} +- {name: rd64-layer3, extract_layers: [3], depth: 1} +- {name: rd64-blur-highlight, mask: 
text_and_blur_highlight, test_configuration: {mask: blur_highlight}} \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/coco.yaml b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e240eebda8a2b96a20335732b1091b40e818e9d --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/coco.yaml @@ -0,0 +1,101 @@ +configuration: + batch_size: 64 + optimizer: torch.optim.AdamW + + lr: 0.001 + + trainer: experiment_setup.train_loop + scorer: experiment_setup.score + model: models.clipseg.CLIPDensePredT + + lr_scheduler: cosine + T_max: 20000 + eta_min: 0.0001 + + max_iterations: 20000 + val_interval: null + + # dataset + dataset: datasets.coco_wrapper.COCOWrapper + # split_mode: pascal_test + split: train + mask: text_and_blur3_highlight01 + image_size: 352 + normalize: True + pre_crop_image_size: [sample, 1, 1.5] + aug: 1new + + # general + mix: True + prompt: shuffle+ + norm_cond: True + mix_text_min: 0.0 + + # model + out: 1 + extract_layers: [3, 7, 9] + reduce_dim: 64 + depth: 3 + fix_shift: False + + loss: torch.nn.functional.binary_cross_entropy_with_logits + amp: True + +test_configuration_common: + normalize: True + image_size: 352 + # max_iterations: 10 + batch_size: 8 + sigmoid: True + test_dataset: coco + metric: metrics.FixedIntervalMetrics + +test_configuration: + + - + name: coco_t + mask: text + + - + name: coco_h + mask: blur3_highlight01 + + - + name: coco_h2 + mask: crop_blur_highlight352 + + +columns: [i, name, +coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5, +coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5, +coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t, +train_loss, duration, date +] + +individual_configurations: + + +- {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} + + +- {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} + + +# ViT +- {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} +- {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: 
models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} +- {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} +- {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001} + + +# BASELINE +- {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} +- {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000} \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/pascal_1shot.yaml b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/pascal_1shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..287915c5554d92f095349815b7d356e831831b58 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/pascal_1shot.yaml @@ -0,0 +1,101 @@ +configuration: + batch_size: 64 + optimizer: torch.optim.AdamW + + lr: 0.001 + + trainer: experiment_setup.train_loop + scorer: experiment_setup.score + model: models.clipseg.CLIPDensePredT + + lr_scheduler: cosine + T_max: 20000 + eta_min: 0.0001 + + max_iterations: 20000 # <-########################################## + val_interval: null + + # dataset + dataset: datasets.phrasecut.PhraseCut + split_mode: pascal_test + mode: train + mask: text_and_crop_blur_highlight352 + image_size: 352 + normalize: True + pre_crop_image_size: [sample, 1, 1.5] + aug: 1new + with_visual: True + split: train + + # general + mix: True + prompt: shuffle+ + norm_cond: True + mix_text_min: 0.0 + + # model + out: 1 + version: 'ViT-B/16' + extract_layers: [3, 7, 9] + reduce_dim: 64 + depth: 3 + + loss: torch.nn.functional.binary_cross_entropy_with_logits + amp: True + +test_configuration_common: + normalize: True + image_size: 352 + metric: metrics.FixedIntervalMetrics + batch_size: 1 + test_dataset: pascal + sigmoid: True + # max_iterations: 250 + +test_configuration: + + - + name: pas_t + mask: text + + - + name: pas_h + mask: blur3_highlight01 + + - + name: pas_h2 + mask: crop_blur_highlight352 + + +columns: [name, +pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct, +pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct, +pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t, +train_loss, duration, date +] + +individual_configurations: + +- {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, 
test_configuration: {splits: [0], custom_threshold: 0.24}} +- {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}} +- {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}} +- {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}} + + +- {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}} +- {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}} +- {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}} +- {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}} + + +# baseline +- {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}} +- {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}} +- {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}} +- {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}} + +# ViT +- {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}} +- {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}} +- {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}} +- {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}} diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/phrasecut.yaml b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/phrasecut.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad1bc5e111651836874ba997043442bd50303669 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/experiments/phrasecut.yaml @@ -0,0 +1,80 @@ +configuration: + batch_size: 64 + optimizer: torch.optim.AdamW + + lr: 0.001 + + trainer: experiment_setup.train_loop + scorer: experiment_setup.score + model: models.clipseg.CLIPDensePredT + + lr_scheduler: cosine + T_max: 20000 + eta_min: 0.0001 + + max_iterations: 20000 + val_interval: null + + # dataset + dataset: datasets.phrasecut.PhraseCut # <----------------- + split_mode: pascal_test + split: train + 
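# text_and_crop_blur_highlight352 = text prompt plus a cropped, blur-highlighted visual support image rendered at 352 px (see blend_image_segmentation / img_preprocess) +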
mask: text_and_crop_blur_highlight352 + image_size: 352 + normalize: True + pre_crop_image_size: [sample, 1, 1.5] + aug: 1new + + # general + mix: False # <----------------- + prompt: shuffle+ + norm_cond: True + mix_text_min: 0.0 + + # model + out: 1 + extract_layers: [3, 7, 9] + reduce_dim: 64 + depth: 3 + fix_shift: False + + loss: torch.nn.functional.binary_cross_entropy_with_logits + amp: True + +test_configuration_common: + normalize: True + image_size: 352 + batch_size: 32 + # max_iterations: 5 + # max_iterations: 150 + +test_configuration: + + - + name: pc # old: phrasecut + metric: metrics.FixedIntervalMetrics + test_dataset: phrasecut + split: test + mask: text + label_support: True + sigmoid: True + + +columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date] + + +individual_configurations: + +# important ones + + +- {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5} + +# this was accedentally trained using old mask +- {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01} +- {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False} +# this was accedentally trained using old mask +- {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01} + +- {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003} +- {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001} \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/general_utils.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ba5552ef2b8340ae9daa29924e2c95255bf7de --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/general_utils.py @@ -0,0 +1,272 @@ +import json +import inspect +import torch +import os +import sys +import yaml +from shutil import copy, copytree +from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename + + +class Logger(object): + + def __getattr__(self, k): + return print + +log = Logger() + +def training_config_from_cli_args(): + experiment_name = sys.argv[1] + experiment_id = int(sys.argv[2]) + + yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) + + config = yaml_config['configuration'] + config = {**config, **yaml_config['individual_configurations'][experiment_id]} + config = AttributeDict(config) + return config + + +def score_config_from_cli_args(): + experiment_name = sys.argv[1] + experiment_id = int(sys.argv[2]) + + + yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader) + + config = yaml_config['test_configuration_common'] + + if type(yaml_config['test_configuration']) == list: + test_id = int(sys.argv[3]) + config = {**config, **yaml_config['test_configuration'][test_id]} + else: + config = {**config, **yaml_config['test_configuration']} + + if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]: + config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']} + + train_checkpoint_id = 
yaml_config['individual_configurations'][experiment_id]['name'] + + config = AttributeDict(config) + return config, train_checkpoint_id + + +def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository', + local_dir='~/datasets'): + """ copies files from repository to local folder. + + repo_files: list of filenames or list of tuples [filename, target path] + + e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']) + will create a folder 'MyDataset' in local_dir, and extract the content of + '/data/dataset1.tar' to /MyDataset/other/path. + """ + + local_dir = realpath(join(expanduser(local_dir), local_name)) + + dataset_exists = True + + # check if folder is available + if not isdir(local_dir): + dataset_exists = False + + if integrity_check is not None: + try: + integrity_ok = integrity_check(local_dir) + except BaseException: + integrity_ok = False + + if integrity_ok: + log.hint('Passed custom integrity check') + else: + log.hint('Custom integrity check failed') + + dataset_exists = dataset_exists and integrity_ok + + if not dataset_exists: + + repo_dir = realpath(expanduser(repo_dir)) + + for i, filename in enumerate(repo_files): + + if type(filename) == str: + origin, target = filename, filename + archive_target = join(local_dir, basename(origin)) + extract_target = join(local_dir) + else: + origin, target = filename + archive_target = join(local_dir, dirname(target), basename(origin)) + extract_target = join(local_dir, dirname(target)) + + archive_origin = join(repo_dir, origin) + + log.hint(f'copy: {archive_origin} to {archive_target}') + + # make sure the path exists + os.makedirs(dirname(archive_target), exist_ok=True) + + if os.path.isfile(archive_target): + # only copy if size differs + if os.path.getsize(archive_target) != os.path.getsize(archive_origin): + log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}') + copy(archive_origin, archive_target) + else: + copy(archive_origin, archive_target) + + extract_archive(archive_target, extract_target, noarchive_ok=True) + + # concurrent processes might have deleted the file + if os.path.isfile(archive_target): + os.remove(archive_target) + + +def extract_archive(filename, target_folder=None, noarchive_ok=False): + from subprocess import run, PIPE + + if filename.endswith('.tgz') or filename.endswith('.tar'): + command = f'tar -xf {filename}' + command += f' -C {target_folder}' if target_folder is not None else '' + elif filename.endswith('.tar.gz'): + command = f'tar -xzf {filename}' + command += f' -C {target_folder}' if target_folder is not None else '' + elif filename.endswith('zip'): + command = f'unzip {filename}' + command += f' -d {target_folder}' if target_folder is not None else '' + else: + if noarchive_ok: + return + else: + raise ValueError(f'unsuppored file ending of {filename}') + + log.hint(command) + result = run(command.split(), stdout=PIPE, stderr=PIPE) + if result.returncode != 0: + print(result.stdout, result.stderr) + + +class AttributeDict(dict): + """ + An extended dictionary that allows access to elements as atttributes and counts + these accesses. This way, we know if some attributes were never used. 
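+ 
+ Example (names and values are illustrative only):
+ cfg = AttributeDict(lr=0.001, depth=3)
+ _ = cfg.lr               # counted access
+ cfg.unused_keys()        # -> ['depth']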
+ """ + + def __init__(self, *args, **kwargs): + from collections import Counter + super().__init__(*args, **kwargs) + self.__dict__['counter'] = Counter() + + def __getitem__(self, k): + self.__dict__['counter'][k] += 1 + return super().__getitem__(k) + + def __getattr__(self, k): + self.__dict__['counter'][k] += 1 + return super().get(k) + + def __setattr__(self, k, v): + return super().__setitem__(k, v) + + def __delattr__(self, k, v): + return super().__delitem__(k, v) + + def unused_keys(self, exceptions=()): + return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions] + + def assume_no_unused_keys(self, exceptions=()): + if len(self.unused_keys(exceptions=exceptions)) > 0: + log.warning('Unused keys:', self.unused_keys(exceptions=exceptions)) + + +def get_attribute(name): + import importlib + + if name is None: + raise ValueError('The provided attribute is None') + + name_split = name.split('.') + mod = importlib.import_module('.'.join(name_split[:-1])) + return getattr(mod, name_split[-1]) + + + +def filter_args(input_args, default_args): + + updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()} + used_args = {k: v for k, v in input_args.items() if k in default_args} + unused_args = {k: v for k, v in input_args.items() if k not in default_args} + + return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args) + + +def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False): + + config = json.load(open(join('logs', checkpoint_id, 'config.json'))) + + if model_args != 'from_config' and type(model_args) != dict: + raise ValueError('model_args must either be "from_config" or a dictionary of values') + + model_cls = get_attribute(config['model']) + + # load model + if model_args == 'from_config': + _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) + + model = model_cls(**model_args) + + if weights_file is None: + weights_file = realpath(join('logs', checkpoint_id, 'weights.pth')) + else: + weights_file = realpath(join('logs', checkpoint_id, weights_file)) + + if isfile(weights_file): + weights = torch.load(weights_file) + for _, w in weights.items(): + assert not torch.any(torch.isnan(w)), 'weights contain NaNs' + model.load_state_dict(weights, strict=strict) + else: + raise FileNotFoundError(f'model checkpoint {weights_file} was not found') + + if with_config: + return model, config + + return model + + +class TrainingLogger(object): + + def __init__(self, model, log_dir, config=None, *args): + super().__init__() + self.model = model + self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None + + os.makedirs('logs/', exist_ok=True) + os.makedirs(self.base_path, exist_ok=True) + + if config is not None: + json.dump(config, open(join(self.base_path, 'config.json'), 'w')) + + def iter(self, i, **kwargs): + if i % 100 == 0 and 'loss' in kwargs: + loss = kwargs['loss'] + print(f'iteration {i}: loss {loss:.4f}') + + def save_weights(self, only_trainable=False, weight_file='weights.pth'): + if self.model is None: + raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.') + + weights_path = join(self.base_path, weight_file) + + weight_dict = self.model.state_dict() + + if only_trainable: + weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad} + + torch.save(weight_dict, weights_path) + log.info(f'Saved 
weights to {weights_path}') + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """ automatically stop processes if used in a context manager """ + pass \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/metrics.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..85473387dd355dd1d9b6b158e1b0431b13a1c1f3 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/metrics.py @@ -0,0 +1,271 @@ +from torch.functional import Tensor +from general_utils import log +from collections import defaultdict +import numpy as np + +import torch +from torch.nn import functional as nnf + + +class BaseMetric(object): + + def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True, + eval_validation=True): + self._names = tuple(metric_names) + self._eval_intermediate = eval_intermediate + self._eval_validation = eval_validation + + self._pred_range = pred_range + self._pred_index = pred_index + self._gt_index = gt_index + + self.predictions = [] + self.ground_truths = [] + + def eval_intermediate(self): + return self._eval_intermediate + + def eval_validation(self): + return self._eval_validation + + def names(self): + return self._names + + def add(self, predictions, ground_truth): + raise NotImplementedError + + def value(self): + raise NotImplementedError + + def scores(self): + # similar to value but returns dict + value = self.value() + if type(value) == dict: + return value + else: + assert type(value) in {list, tuple} + return list(zip(self.names(), self.value())) + + def _get_pred_gt(self, predictions, ground_truth): + pred = predictions[self._pred_index] + gt = ground_truth[self._gt_index] + + if self._pred_range is not None: + pred = pred[:, self._pred_range[0]: self._pred_range[1]] + + return pred, gt + + +class FixedIntervalMetrics(BaseMetric): + + def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None, + resize_pred=None, n_values=51, custom_threshold=None): + + + super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh')) + self.intersections = [] + self.unions = [] + # self.threshold = threshold + self.sigmoid = sigmoid + self.resize_to = resize_to + self.resize_pred = resize_pred # resize prediction to match ground truth + self.class_count = defaultdict(lambda: 0) + self.per_class = defaultdict(lambda : [0,0]) + self.ignore_mask = ignore_mask + self.custom_threshold = custom_threshold + + self.scores_ap = [] + self.scores_iou = [] + self.gts, self.preds = [], [] + self.classes = [] + + # [1:-1] ignores 0 and 1 + self.threshold_values = np.linspace(0, 1, n_values)[1:-1] + + self.metrics = dict(tp=[], fp=[], fn=[], tn=[]) + + def add(self, pred, gt): + + pred_batch = pred[0].cpu() + + if self.sigmoid: + pred_batch = torch.sigmoid(pred_batch) + + gt_batch = gt[0].cpu() + mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch)) + cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch) + + if self.resize_to is not None: + gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest') + pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False) + + if isinstance(cls_batch, torch.Tensor): + cls_batch = cls_batch.cpu().numpy().tolist() + + assert len(gt_batch) == 
len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}' + + for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch): + + if self.resize_pred: + predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True) + + p = predictions.flatten() + g = ground_truth.flatten() + + assert len(p) == len(g) + + if mask is not None: + m = mask.flatten().bool() + p = p[m] + g = g[m] + + p_sorted = p.sort() + p = p_sorted.values + g = g[p_sorted.indices] + + tps, fps, fns, tns = [], [], [], [] + for thresh in self.threshold_values: + + valid = torch.where(p > thresh)[0] + if len(valid) > 0: + n = int(valid[0]) + else: + n = len(g) + + fn = int(g[:n].sum()) + tp = int(g[n:].sum()) + fns += [fn] + tns += [n - fn] + tps += [tp] + fps += [len(g) - n - tp] + + self.metrics['tp'] += [tps] + self.metrics['fp'] += [fps] + self.metrics['fn'] += [fns] + self.metrics['tn'] += [tns] + + self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls] + + def value(self): + + import time + t_start = time.time() + + if set(self.classes) == set([None]): + all_classes = None + log.warning('classes were not provided, cannot compute mIoU') + else: + all_classes = set(int(c) for c in self.classes) + # log.info(f'compute metrics for {len(all_classes)} classes') + + summed = {k: [sum([self.metrics[k][i][j] + for i in range(len(self.metrics[k]))]) + for j in range(len(self.threshold_values))] + for k in self.metrics.keys()} + + if all_classes is not None: + + assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn']) + # group by class + metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes} + for i in range(len(self.metrics['tp'])): + for k in self.metrics.keys(): + metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]] + + # sum over all instances within the classes + summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()} + + + # Compute average precision + + assert (np.array(summed['fp']) + np.array(summed['tp']) ).sum(), 'no predictions is made' + + # only consider values where a prediction is made + precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values)) + if summed['tp'][j] + summed['fp'][j] > 0] + recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values)) + if summed['tp'][j] + summed['fp'][j] > 0] + + # remove duplicate recall-precision-pairs (and sort by recall value) + recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0])) + + from scipy.integrate import simps + ap = simps(precisions, recalls) + + # Compute best IoU + fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))] + + biniou_scores = [ + 0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) + + 0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j])) + for j in range(len(self.threshold_values)) + ] + + index_0p5 = self.threshold_values.tolist().index(0.5) + index_0p1 = self.threshold_values.tolist().index(0.1) + index_0p2 = self.threshold_values.tolist().index(0.2) + index_0p3 = self.threshold_values.tolist().index(0.3) + + if self.custom_threshold is not None: + index_ct = 
self.threshold_values.tolist().index(self.custom_threshold) + + if all_classes is not None: + # mean IoU + mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j]) + for c in all_classes]) + for j in range(len(self.threshold_values))] + + mean_iou_dict = { + 'miou_best': max(mean_ious) if all_classes is not None else None, + 'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None, + 'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None, + 'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None, + 'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None, + 'miou_best_t': self.threshold_values[np.argmax(mean_ious)], + 'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None, + 'mean_iou_scores': mean_ious, + } + + print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s') + + return { + 'ap': ap, + + # fgiou + 'fgiou_best': max(fgiou_scores), + 'fgiou_0.5': fgiou_scores[index_0p5], + 'fgiou_0.1': fgiou_scores[index_0p1], + 'fgiou_0.2': fgiou_scores[index_0p2], + 'fgiou_0.3': fgiou_scores[index_0p3], + 'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)], + + # mean iou + + + # biniou + 'biniou_best': max(biniou_scores), + 'biniou_0.5': biniou_scores[index_0p5], + 'biniou_0.1': biniou_scores[index_0p1], + 'biniou_0.2': biniou_scores[index_0p2], + 'biniou_0.3': biniou_scores[index_0p3], + 'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)], + + # custom threshold + 'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None, + 'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None, + 'ct': self.custom_threshold, + + # statistics + 'fgiou_scores': fgiou_scores, + 'biniou_scores': biniou_scores, + 'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))), + 'summed_statistics': summed, + 'summed_by_cls_statistics': summed_by_cls, + + **mean_iou_dict + } + + # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh' + + # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls} + diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/models/clipseg.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/models/clipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..430953620b0737a2db5cf6ccd9a8f98d6f518feb --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/models/clipseg.py @@ -0,0 +1,552 @@ +import math +from os.path import basename, dirname, join, isfile +import torch +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules.activation import ReLU + + +def precompute_clip_vectors(): + + from trails.initialization import init_dataset + lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True, + reduce_factor=None, add_bar=False, negative_prob=0.5) + + all_names = list(lvis.category_names.values()) + + import clip + from models.clip_prompts import imagenet_templates + clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0] + prompt_vectors = {} + for name in all_names[:100]: + with torch.no_grad(): + 
conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates] + text_tokens = clip.tokenize(conditionals).cuda() + cond = clip_model.encode_text(text_tokens).cpu() + + for cond, vec in zip(conditionals, cond): + prompt_vectors[cond] = vec.cpu() + + import pickle + + pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb')) + + +def get_prompt_list(prompt): + if prompt == 'plain': + return ['{}'] + elif prompt == 'fixed': + return ['a photo of a {}.'] + elif prompt == 'shuffle': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + elif prompt == 'shuffle_clip': + from models.clip_prompts import imagenet_templates + return imagenet_templates + else: + raise ValueError('Invalid value for prompt') + + +def forward_multihead_attention(x, b, with_aff=False, attn_mask=None): + """ + Simplified version of multihead attention (taken from torch source code but without tons of if clauses). + The mlp and layer norm come from CLIP. + x: input. + b: multihead attention module. + """ + + x_ = b.ln_1(x) + q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1) + tgt_len, bsz, embed_dim = q.size() + + head_dim = embed_dim // b.attn.num_heads + scaling = float(head_dim) ** -0.5 + + q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + + q = q * scaling + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2 + if attn_mask is not None: + + + attn_mask_type, attn_mask = attn_mask + n_heads = attn_output_weights.size(0) // attn_mask.size(0) + attn_mask = attn_mask.repeat(n_heads, 1) + + if attn_mask_type == 'cls_token': + # the mask only affects similarities compared to the readout-token. + attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...] + # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0] + + if attn_mask_type == 'all': + # print(attn_output_weights.shape, attn_mask[:, None].shape) + attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None] + + + attn_output_weights = torch.softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = b.attn.out_proj(attn_output) + + x = x + attn_output + x = x + b.mlp(b.ln_2(x)) + + if with_aff: + return x, attn_output_weights + else: + return x + + +class CLIPDenseBase(nn.Module): + + def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens): + super().__init__() + + import clip + + # prec = torch.FloatTensor + self.clip_model, _ = clip.load(version, device='cpu', jit=False) + self.model = self.clip_model.visual + + # if not None, scale conv weights such that we obtain n_tokens. 
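+ # e.g. for a (hypothetical) 352x352 input with n_tokens=16, visual_forward below resizes the
+ # patch-embedding kernel to 352 // 16 = 22 and uses stride 22, giving a 16x16 token grid.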
+ self.n_tokens = n_tokens + + for p in self.clip_model.parameters(): + p.requires_grad_(False) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + self.reduce = nn.Linear(768, reduce_dim) + + self.prompt_list = get_prompt_list(prompt) + + # precomputed prompts + import pickle + if isfile('precomputed_prompt_vectors.pickle'): + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + else: + self.precomputed_prompts = dict() + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + + with torch.no_grad(): + + inp_size = x_inp.shape[2:] + + if self.n_tokens is not None: + stride2 = x_inp.shape[2] // self.n_tokens + conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) + x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) + else: + x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] + + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + + standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 + + if x.shape[1] != standard_n_tokens: + new_shape = int(math.sqrt(x.shape[1]-1)) + x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] + else: + x = x + self.model.positional_embedding.to(x.dtype) + + x = self.model.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + activations, affinities = [], [] + for i, res_block in enumerate(self.model.transformer.resblocks): + + if mask is not None: + mask_layer, mask_type, mask_tensor = mask + if mask_layer == i or mask_layer == 'all': + # import ipdb; ipdb.set_trace() + size = int(math.sqrt(x.shape[0] - 1)) + + attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) + + else: + attn_mask = None + else: + attn_mask = None + + x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) + + if i in extract_layers: + affinities += [aff_per_head] + + #if self.n_tokens is not None: + # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] + #else: + activations += [x] + + if len(extract_layers) > 0 and i == max(extract_layers) and skip: + print('early skip') + break + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_post(x[:, 0, :]) + + if self.model.proj is not None: + x = x @ self.model.proj + + return x, activations, affinities + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is 
not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + if self.shift_vector is not None: + return cond + self.shift_vector + else: + return cond + + +def clip_load_untrained(version): + assert version == 'ViT-B/16' + from clip.model import CLIP + from clip.clip import _MODELS, _download + model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval() + state_dict = model.state_dict() + + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) + + +class CLIPDensePredT(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False, + add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + self.rev_activations = rev_activations + + depth = 
len(extract_layers) + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + self.version = version + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + if fix_shift: + # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False) + self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False) + # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False) + else: + self.shift_vector = None + + if trans_conv is None: + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + else: + # explicitly define transposed conv kernel size + trans_conv_ks = (trans_conv, trans_conv) + + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + self.prompt_list = get_prompt_list(prompt) + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + _activations = activations[::-1] if not self.rev_activations else activations + + a = None + for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + a = self.trans_conv(a) + + if self.n_tokens is not None: + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, + + + +class CLIPDensePredTMasked(CLIPDensePredT): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, + prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False, + refine=None, limit_to_clip_only=False, 
upsample=False, add_calibration=False, n_tokens=None): + + super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim, + n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond, + fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only, + limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration, + n_tokens=n_tokens) + + def visual_forward_masked(self, img_s, seg_s): + return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) + + def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False): + + if seg_s is None: + cond = cond_or_img_s + else: + img_s = cond_or_img_s + + with torch.no_grad(): + cond, _, _ = self.visual_forward_masked(img_s, seg_s) + + return super().forward(img_q, cond, return_features=return_features) + + + +class CLIPDenseBaseline(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', cond_layer=0, + extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed', + reduce_cond=None, limit_to_clip_only=False, n_tokens=None): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + device = 'cpu' + + # self.cond_layer = cond_layer + self.extract_layer = extract_layer + self.limit_to_clip_only = limit_to_clip_only + self.shift_vector = None + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + assert reduce2_dim is not None + + self.reduce2 = nn.Sequential( + nn.Linear(reduce_dim, reduce2_dim), + nn.ReLU(), + nn.Linear(reduce2_dim, reduce_dim) + ) + + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + + def forward(self, inp_image, conditional=None, return_features=False): + + inp_image = inp_image.to(self.model.positional_embedding.device) + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer]) + + a = activations[0] + a = self.reduce(a) + a = self.film_mul(cond) * a + self.film_add(cond) + + if self.reduce2 is not None: + a = self.reduce2(a) + + # the original model would execute a transformer block here + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + a = self.trans_conv(a) + + if return_features: + return a, visual_q, cond, activations + else: + return a, + + +class CLIPSegMultiLabel(nn.Module): + + def __init__(self, model) -> None: + super().__init__() + + from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC + + self.pascal_classes = VOC + + from models.clipseg import CLIPDensePredT + from general_utils import load_model + # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False) + self.clipseg = load_model(model, strict=False) + + self.clipseg.eval() + + def forward(self, x): + + bs = x.shape[0] + out = torch.ones(21, bs, 352, 352).to(x.device) * -10 + + for class_id, class_name in enumerate(self.pascal_classes): + + fac = 3 if class_name == 'background' else 1 + + with torch.no_grad(): + pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac + + out[class_id] += pred + + + out = out.permute(1, 0, 2, 3) + + return out + + # construct output tensor + \ No newline at end of file diff --git 
a/extensions/deforum/scripts/deforum_helpers/src/clipseg/models/vitseg.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/models/vitseg.py new file mode 100644 index 0000000000000000000000000000000000000000..d3231e5b6ecb81863eec0c021865ea52e9b70a15 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/models/vitseg.py @@ -0,0 +1,286 @@ +import math +from posixpath import basename, dirname, join +# import clip +from clip.model import convert_weights +import torch +import json +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules import activation +from torch.nn.modules.activation import ReLU +from torchvision import transforms + +normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + +from torchvision.models import ResNet + + +def process_prompts(conditional, prompt_list, conditional_map): + # DEPRECATED + + # randomly sample a synonym + words = [conditional_map[int(i)] for i in conditional] + words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] + words = [w.replace('_', ' ') for w in words] + + if prompt_list is not None: + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + else: + prompts = ['a photo of {}'] * (len(words)) + + return [promt.format(w) for promt, w in zip(prompts, words)] + + +class VITDenseBase(nn.Module): + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + with torch.no_grad(): + + x_inp = nnf.interpolate(x_inp, (384, 384)) + + x = self.model.patch_embed(x_inp) + cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.model.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.model.pos_drop(x + self.model.pos_embed) + + activations = [] + for i, block in enumerate(self.model.blocks): + x = block(x) + + if i in extract_layers: + # permute to be compatible with CLIP + activations += [x.permute(1,0,2)] + + x = self.model.norm(x) + x = self.model.head(self.model.pre_logits(x[:, 0])) + + # again for CLIP compatibility + # x = x.permute(1, 0, 2) + + return x, activations, None + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = 
self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + return cond + + +class VITDensePredT(VITDenseBase): + + def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False, + add_calibration=False, process_cond=None, not_pretrained=False): + super().__init__() + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + import timm + self.model = timm.create_model('vit_base_patch16_384', pretrained=True) + self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond) + + for p in self.model.parameters(): + p.requires_grad_(False) + + import clip + self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False) + # del self.clip_model.visual + + + self.token_shape = (14, 14) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + # self.film = AVAILABLE_BLOCKS['film'](512, 128) + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + # DEPRECATED + # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))} + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + trans_conv_ks = (16, 16) + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + if prompt == 'fixed': + self.prompt_list = ['a photo of a {}.'] + elif prompt == 'shuffle': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of 
a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + elif prompt == 'shuffle_clip': + from models.clip_prompts import imagenet_templates + self.prompt_list = imagenet_templates + + if process_cond is not None: + if process_cond == 'clamp' or process_cond[0] == 'clamp': + + val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2 + + def clamp_vec(x): + return torch.clamp(x, -val, val) + + self.process_cond = clamp_vec + + elif process_cond.endswith('.pth'): + + shift = torch.load(process_cond) + def add_shift(x): + return x + shift.to(x.device) + + self.process_cond = add_shift + + import pickle + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + # inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + inp_image_size = inp_image.shape[2:] + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + a = None + for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + if self.trans_conv is not None: + a = self.trans_conv(a) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + a = nnf.interpolate(a, inp_image_size) + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/overview.png b/extensions/deforum/scripts/deforum_helpers/src/clipseg/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..1f77bc7743746108fa34bfbaa6e9e8c4db213c8c --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c46050cecb3068b9b4b263458b5d0da154a95a7a47b8ce312ba402f0dda3cbfe +size 53964 diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/score.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/score.py new file mode 100644 index 0000000000000000000000000000000000000000..944f17d3978919f63cca7f8c16274bb3b1dd8c79 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/score.py @@ -0,0 +1,453 @@ +from torch.functional import Tensor + +import torch +import inspect +import json +import yaml +import time +import sys + +from general_utils import log + +import numpy as np +from os.path import expanduser, join, isfile, realpath + +from torch.utils.data import DataLoader + +from metrics import 
FixedIntervalMetrics + +from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args + + +DATASET_CACHE = dict() + +def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False): + + config = json.load(open(join('logs', checkpoint_id, 'config.json'))) + + if model_args != 'from_config' and type(model_args) != dict: + raise ValueError('model_args must either be "from_config" or a dictionary of values') + + model_cls = get_attribute(config['model']) + + # load model + if model_args == 'from_config': + _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters) + + model = model_cls(**model_args) + + if weights_file is None: + weights_file = realpath(join('logs', checkpoint_id, 'weights.pth')) + else: + weights_file = realpath(join('logs', checkpoint_id, weights_file)) + + if isfile(weights_file) and not ignore_weights: + weights = torch.load(weights_file) + for _, w in weights.items(): + assert not torch.any(torch.isnan(w)), 'weights contain NaNs' + model.load_state_dict(weights, strict=strict) + else: + if not ignore_weights: + raise FileNotFoundError(f'model checkpoint {weights_file} was not found') + + if with_config: + return model, config + + return model + + +def compute_shift2(model, datasets, seed=123, repetitions=1): + """ computes shift """ + + model.eval() + model.cuda() + + import random + random.seed(seed) + + preds, gts = [], [] + for i_dataset, dataset in enumerate(datasets): + + loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False) + + max_iterations = int(repetitions * len(dataset.dataset.data_list)) + + with torch.no_grad(): + + i, losses = 0, [] + for i_all, (data_x, data_y) in enumerate(loader): + + data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x] + data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y] + + pred, = model(data_x[0], data_x[1], data_x[2]) + preds += [pred.detach()] + gts += [data_y] + + i += 1 + if max_iterations and i >= max_iterations: + break + + from metrics import FixedIntervalMetrics + n_values = 51 + thresholds = np.linspace(0, 1, n_values)[1:-1] + metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values) + + for p, y in zip(preds, gts): + metric.add(p.unsqueeze(1), y) + + best_idx = np.argmax(metric.value()['fgiou_scores']) + best_thresh = thresholds[best_idx] + + return best_thresh + + +def get_cached_pascal_pfe(split, config): + from datasets.pfe_dataset import PFEPascalWrapper + try: + dataset = DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] + except KeyError: + dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support) + DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] = dataset + return dataset + + + + +def main(): + config, train_checkpoint_id = score_config_from_cli_args() + + metrics = score(config, train_checkpoint_id, None) + + for dataset in metrics.keys(): + for k in metrics[dataset]: + if type(metrics[dataset][k]) in {float, int}: + print(dataset, f'{k:<16} {metrics[dataset][k]:.3f}') + + +def score(config, train_checkpoint_id, train_config): + + config = AttributeDict(config) + + print(config) + + # use training dataset and loss + train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json'))) + + cp_str = f'_{config.iteration_cp}' if 
config.iteration_cp is not None else '' + + + model_cls = get_attribute(train_config['model']) + + _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters) + + model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}} + + strict_models = {'ConditionBase4', 'PFENetWrapper'} + model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args, + weights_file=f'weights{cp_str}.pth', ) + + + model.eval() + model.cuda() + + metric_args = dict() + + if 'threshold' in config: + if config.metric.split('.')[-1] == 'SkLearnMetrics': + metric_args['threshold'] = config.threshold + + if 'resize_to' in config: + metric_args['resize_to'] = config.resize_to + + if 'sigmoid' in config: + metric_args['sigmoid'] = config.sigmoid + + if 'custom_threshold' in config: + metric_args['custom_threshold'] = config.custom_threshold + + if config.test_dataset == 'pascal': + + loss_fn = get_attribute(train_config.loss) + # assume that if no split is specified in train_config, test on all splits, + + if 'splits' in config: + splits = config.splits + else: + if 'split' in train_config and type(train_config.split) == int: + # unless train_config has a split set, in that case assume train mode in training + splits = [train_config.split] + assert train_config.mode == 'train' + else: + splits = [0,1,2,3] + + log.info('Test on these splits', splits) + + scores = dict() + for split in splits: + + shift = config.shift if 'shift' in config else 0 + + # automatic shift + if shift == 'auto': + shift_compute_t = time.time() + shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac) + log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s') + + dataset = get_cached_pascal_pfe(split, config) + + eval_start_t = time.time() + + loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False) + + assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1' + + metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args) + + with torch.no_grad(): + + i, losses = 0, [] + for i_all, (data_x, data_y) in enumerate(loader): + + data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x] + data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y] + + if config.mask == 'separate': # for old CondBase model + pred, = model(data_x[0], data_x[1], data_x[2]) + else: + # assert config.mask in {'text', 'highlight'} + pred, _, _, _ = model(data_x[0], data_x[1], return_features=True) + + # loss = loss_fn(pred, data_y[0]) + metric.add(pred.unsqueeze(1) + shift, data_y) + + # losses += [float(loss)] + + i += 1 + if config.max_iterations and i >= config.max_iterations: + break + + #scores[split] = {m: s for m, s in zip(metric.names(), metric.value())} + + log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.') + + print(metric.value()['mean_iou_scores']) + + scores[split] = metric.scores() + + log.info(f'Completed split {split}') + + key_prefix = config['name'] if 'name' in config else 'pas' + + all_keys = set.intersection(*[set(v.keys()) for v in scores.values()]) + + valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) 
for v in scores.values())] + + return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}} + + + if config.test_dataset == 'coco': + from datasets.coco_wrapper import COCOWrapper + + coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask, + with_class_label=True) + + log.info('Dataset length', len(coco_dataset)) + loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False) + + metric = get_attribute(config.metric)(resize_pred=True, **metric_args) + + shift = config.shift if 'shift' in config else 0 + + with torch.no_grad(): + + i, losses = 0, [] + for i_all, (data_x, data_y) in enumerate(loader): + data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x] + data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y] + + if config.mask == 'separate': # for old CondBase model + pred, = model(data_x[0], data_x[1], data_x[2]) + else: + # assert config.mask in {'text', 'highlight'} + pred, _, _, _ = model(data_x[0], data_x[1], return_features=True) + + metric.add([pred + shift], data_y) + + i += 1 + if config.max_iterations and i >= config.max_iterations: + break + + key_prefix = config['name'] if 'name' in config else 'coco' + return {key_prefix: metric.scores()} + #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}} + + + if config.test_dataset == 'phrasecut': + from datasets.phrasecut import PhraseCut + + only_visual = config.only_visual is not None and config.only_visual + with_visual = config.with_visual is not None and config.with_visual + + dataset = PhraseCut('test', + image_size=train_config.image_size, + mask=config.mask, + with_visual=with_visual, only_visual=only_visual, aug_crop=False, + aug_color=False) + + loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False) + metric = get_attribute(config.metric)(resize_pred=True, **metric_args) + + shift = config.shift if 'shift' in config else 0 + + + with torch.no_grad(): + + i, losses = 0, [] + for i_all, (data_x, data_y) in enumerate(loader): + data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x] + data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y] + + pred, _, _, _ = model(data_x[0], data_x[1], return_features=True) + metric.add([pred + shift], data_y) + + i += 1 + if config.max_iterations and i >= config.max_iterations: + break + + key_prefix = config['name'] if 'name' in config else 'phrasecut' + return {key_prefix: metric.scores()} + #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}} + + if config.test_dataset == 'pascal_zs': + from third_party.JoEm.model.metric import Evaluator + from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC + from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS + + from models.clipseg import CLIPSegMultiLabel + + n_unseen = train_config.remove_classes[1] + + pz = PascalZeroShot('val', n_unseen, image_size=352) + m = CLIPSegMultiLabel(model=train_config.name).cuda() + m.eval(); + + print(len(pz), n_unseen) + print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set]) + + print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)]) + print('seen', [VOC[i] for i in get_seen_idx(n_unseen)]) + + loader = DataLoader(pz, batch_size=8) + evaluator = 
Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen)) + + for i, (data_x, data_y) in enumerate(loader): + pred = m(data_x[0].cuda()) + evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy()) + + if config.max_iter is not None and i > config.max_iter: + break + + scores = evaluator.Mean_Intersection_over_Union() + key_prefix = config['name'] if 'name' in config else 'pas_zs' + + return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}} + + elif config.test_dataset in {'same_as_training', 'affordance'}: + loss_fn = get_attribute(train_config.loss) + + metric_cls = get_attribute(config.metric) + metric = metric_cls(**metric_args) + + if config.test_dataset == 'same_as_training': + dataset_cls = get_attribute(train_config.dataset) + elif config.test_dataset == 'affordance': + dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance') + dataset_name = 'aff' + else: + dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot') + dataset_name = 'lvis' + + _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters) + + dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation + + if model.__class__.__name__ == 'PFENetWrapper': + dataset_args['image_size'] = config.image_size + + log.info('init dataset', str(dataset_cls)) + dataset = dataset_cls(**dataset_args) + + log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}') + + data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle) + + # explicitly set prompts + if config.prompt == 'plain': + model.prompt_list = ['{}'] + elif config.prompt == 'fixed': + model.prompt_list = ['a photo of a {}.'] + elif config.prompt == 'shuffle': + model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif config.prompt == 'shuffle_clip': + from models.clip_prompts import imagenet_templates + model.prompt_list = imagenet_templates + + config.assume_no_unused_keys(exceptions=['max_iterations']) + + t_start = time.time() + + with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9) + i, losses = 0, [] + for data_x, data_y in data_loader: + + data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x] + data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y] + + if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}: + pred, = model(data_x[0], data_x[1], data_x[2]) + visual_q = None + else: + pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True) + + loss = loss_fn(pred, data_y[0]) + + metric.add([pred], data_y) + + losses += [float(loss)] + + i += 1 + if config.max_iterations and i >= config.max_iterations: + break + + # scores = {m: s for m, s in zip(metric.names(), metric.value())} + scores = metric.scores() + + keys = set(scores.keys()) + if dataset.negative_prob > 0 and 'mIoU' in keys: + keys.remove('mIoU') + + name_mask = dataset.mask.replace('text_label', 'txt')[:3] + name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob) + + score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}' + + scores = {score_name: {k: v for k,v in scores.items() if k in keys}} + scores[score_name].update({'test_loss': np.mean(losses)}) + + log.info(f'Evaluation took {time.time() - t_start:.1f}s') + + return scores + else: + raise ValueError('invalid test dataset') + + + + + + + + + +if __name__ == 
'__main__': + main() \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/setup.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfa1567b1da2beec4eda74513cfdf23f8464242 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup + +with open("README.md", "r", encoding="utf-8") as readme_file: + readme = readme_file.read() + +requirements = [ + "numpy", + "scipy", + "matplotlib", + "torch", + "torchvision", + "opencv-python", + "CLIP @ git+https://github.com/openai/CLIP.git" +] + +setup( + name='clipseg', + packages=['clipseg'], + package_dir={'clipseg': 'models'}, + package_data={'clipseg': [ + "../weights/*.pth", + ]}, + version='0.0.1', + url='https://github.com/timojl/clipseg', + python_requires='>=3.9', + install_requires=requirements, + description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".', + long_description=readme, + long_description_content_type="text/markdown", +) diff --git a/extensions/deforum/scripts/deforum_helpers/src/clipseg/training.py b/extensions/deforum/scripts/deforum_helpers/src/clipseg/training.py new file mode 100644 index 0000000000000000000000000000000000000000..4e07418261e4e14e81f99cb1e9422d6d11309854 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/clipseg/training.py @@ -0,0 +1,266 @@ +import torch +import inspect +import json +import yaml +import math +import os +import sys + +from general_utils import log + +import numpy as np +from functools import partial +from os.path import expanduser, join, isfile, basename + +from torch.cuda.amp import autocast, GradScaler +from torch.optim.lr_scheduler import LambdaLR +from contextlib import nullcontext +from torch.utils.data import DataLoader + +from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args + + +def cosine_warmup_lr(i, warmup=10, max_iter=90): + """ Cosine LR with Warmup """ + if i < warmup: + return (i+1)/(warmup+1) + else: + return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup)))) + + +def validate(model, dataset, config): + data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False) + + metric_class, use_metric = config.val_metric_class, config.use_val_metric + loss_fn = get_attribute(config.loss) + + model.eval() + model.cuda() + + if metric_class is not None: + metric = get_attribute(metric_class)() + + with torch.no_grad(): + + i, losses = 0, [] + for data_x, data_y in data_loader: + + data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x] + data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y] + + prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',)) + pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True) + + if metric_class is not None: + metric.add([pred], data_y) + + # pred = model(data_x[0], prompts) + # loss = loss_fn(pred[0], data_y[0]) + loss = loss_fn(pred, data_y[0]) + losses += [float(loss)] + + i += 1 + + if config.val_max_iterations is not None and i > config.val_max_iterations: + break + + if use_metric is None: + return np.mean(losses), {}, False + else: + metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {} + return np.mean(losses), metric_scores, True + + +def main(): + + config = 
training_config_from_cli_args()
+
+    val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
+
+    model_cls = get_attribute(config.model)
+    _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
+    model = model_cls(**model_args).cuda()
+
+    dataset_cls = get_attribute(config.dataset)
+    _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
+
+    dataset = dataset_cls(**dataset_args)
+
+    log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
+
+    if val_interval is not None:
+        dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
+        _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
+        print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
+
+        dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
+
+    # optimizer
+    opt_cls = get_attribute(config.optimizer)
+    if config.optimizer == 'torch.optim.SGD':
+        opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
+    else:
+        opt_args = {}
+    opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
+
+    if config.lr_scheduler == 'cosine':
+        assert config.T_max is not None and config.eta_min is not None
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
+    elif config.lr_scheduler == 'warmup_cosine':
+        lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=config.max_iterations, warmup=config.warmup))
+    else:
+        lr_scheduler = None
+
+    batch_size, max_iterations = config.batch_size, config.max_iterations
+
+    loss_fn = get_attribute(config.loss)
+
+    if config.amp:
+        log.info('Using AMP')
+        autocast_fn = autocast
+        scaler = GradScaler()
+    else:
+        autocast_fn, scaler = nullcontext, None
+
+    save_only_trainable = True
+    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
+
+    # disable config when hyperparam. opt. to avoid writing logs.
+    tracker_config = config if not config.hyperparameter_optimization else None
+
+    with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
+
+        i = 0
+        while True:
+            for data_x, data_y in data_loader:
+
+                # between caption and output feature.
+                # 1. Sample random captions
+                # 2.
Check alignment with CLIP + + # randomly mix text and visual support conditionals + if config.mix: + + assert config.mask.startswith('text_and') + + with autocast_fn(): + # data_x[1] = text label + prompts = model.sample_prompts(data_x[1]) + + # model.clip_model() + + text_cond = model.compute_conditional(prompts) + if model.__class__.__name__ == 'CLIPDensePredTMasked': + # when mask=='separate' + visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda()) + else: + # data_x[2] = visual prompt + visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda()) + + max_txt = config.mix_text_max if config.mix_text_max is not None else 1 + batch_size = text_cond.shape[0] + + # sample weights for each element in batch + text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None] + text_weights = text_weights.cuda() + + if dataset.__class__.__name__ == 'PhraseCut': + # give full weight to text where support_image is invalid + visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3] + text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1) + + cond = text_cond * text_weights + visual_s_cond * (1 - text_weights) + + else: + # no mix + + if model.__class__.__name__ == 'CLIPDensePredTMasked': + # compute conditional vector using CLIP masking + with autocast_fn(): + assert config.mask == 'separate' + cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda()) + else: + cond = data_x[1] + if isinstance(cond, torch.Tensor): + cond = cond.cuda() + + with autocast_fn(): + visual_q = None + + pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True) + + loss = loss_fn(pred, data_y[0].cuda()) + + if torch.isnan(loss) or torch.isinf(loss): + # skip if loss is nan + log.warning('Training stopped due to inf/nan loss.') + sys.exit(-1) + + extra_loss = 0 + loss += extra_loss + + opt.zero_grad() + + if scaler is None: + loss.backward() + opt.step() + else: + scaler.scale(loss).backward() + scaler.step(opt) + scaler.update() + + if lr_scheduler is not None: + lr_scheduler.step() + if i % 2000 == 0: + current_lr = [g['lr'] for g in opt.param_groups][0] + log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)') + + logger.iter(i=i, loss=loss) + i += 1 + + if i >= max_iterations: + + if not isfile(join(logger.base_path, 'weights.pth')): + # only write if no weights were already written + logger.save_weights(only_trainable=save_only_trainable) + + sys.exit(0) + + + if config.checkpoint_iterations is not None and i in config.checkpoint_iterations: + logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth') + + + if val_interval is not None and i % val_interval == val_interval - 1: + + val_loss, val_scores, maximize = validate(model, dataset_val, config) + + if len(val_scores) > 0: + + score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items()) + + if maximize and val_scores[config.use_val_metric] > best_val_score: + logger.save_weights(only_trainable=save_only_trainable) + best_val_score = val_scores[config.use_val_metric] + + elif not maximize and val_scores[config.use_val_metric] < best_val_score: + logger.save_weights(only_trainable=save_only_trainable) + best_val_score = val_scores[config.use_val_metric] + + else: + score_str = '' + # if no score is used, fall back to loss + if val_loss < best_val_loss: + logger.save_weights(only_trainable=save_only_trainable) + 
best_val_loss = val_loss + + log.info(f'Validation loss: {val_loss}' + score_str) + logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores) + model.train() + + print('epoch complete') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/film_interpolation/film_inference.py b/extensions/deforum/scripts/deforum_helpers/src/film_interpolation/film_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..12a07696d4abaf2d3d145e75e044a8879f69be0e --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/film_interpolation/film_inference.py @@ -0,0 +1,118 @@ +import os +from glob import glob +import bisect +from tqdm import tqdm +import torch +import numpy as np +import cv2 +from .film_util import load_image +import time +from types import SimpleNamespace +import warnings +warnings.filterwarnings("ignore") + +def run_film_interp_infer( + model_path = None, + input_folder = None, + save_folder = None, + inter_frames = None): + + args = SimpleNamespace() + args.model_path = model_path + args.input_folder = input_folder + args.save_folder = save_folder + args.inter_frames = inter_frames + + # Check if the folder exists + if not os.path.exists(args.input_folder): + print(f"Error: Folder '{args.input_folder}' does not exist.") + return + # Check if the folder contains any PNG or JPEG images + if not any([f.endswith(".png") or f.endswith(".jpg") for f in os.listdir(args.input_folder)]): + print(f"Error: Folder '{args.input_folder}' does not contain any PNG or JPEG images.") + return + + start_time = time.time() # Timer START + + # Sort Jpg/Png images by name + image_paths = sorted(glob(os.path.join(args.input_folder, "*.[jJ][pP][gG]")) + glob(os.path.join(args.input_folder, "*.[pP][nN][gG]"))) + print(f"Total frames to FILM-interpolate: {len(image_paths)}. 
Total frame-pairs: {len(image_paths)-1}.") + + model = torch.jit.load(args.model_path, map_location='cpu') + model.eval() + + for i in tqdm(range(len(image_paths) - 1), desc='FILM progress'): + img1 = image_paths[i] + img2 = image_paths[i+1] + img_batch_1, crop_region_1 = load_image(img1) + img_batch_2, crop_region_2 = load_image(img2) + img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2) + img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2) + + model = model.half() + model = model.cuda() + + save_path = os.path.join(args.save_folder, f"{i}_to_{i+1}.jpg") + + results = [ + img_batch_1, + img_batch_2 + ] + + idxes = [0, inter_frames + 1] + remains = list(range(1, inter_frames + 1)) + + splits = torch.linspace(0, 1, inter_frames + 2) + + inner_loop_progress = tqdm(range(len(remains)), leave=False, disable=True) + for _ in inner_loop_progress: + starts = splits[idxes[:-1]] + ends = splits[idxes[1:]] + distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs() + matrix = torch.argmin(distances).item() + start_i, step = np.unravel_index(matrix, distances.shape) + end_i = start_i + 1 + + x0 = results[start_i] + x1 = results[end_i] + + x0 = x0.half() + x1 = x1.half() + x0 = x0.cuda() + x1 = x1.cuda() + + dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) + + with torch.no_grad(): + prediction = model(x0, x1, dt) + insert_position = bisect.bisect_left(idxes, remains[step]) + idxes.insert(insert_position, remains[step]) + results.insert(insert_position, prediction.clamp(0, 1).cpu().float()) + inner_loop_progress.update(1) + del remains[step] + inner_loop_progress.close() + # create output folder for interoplated imgs to live in + os.makedirs(args.save_folder, exist_ok=True) + + y1, x1, y2, x2 = crop_region_1 + frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy()[y1:y2, x1:x2].copy() for tensor in results] + + existing_files = os.listdir(args.save_folder) + if len(existing_files) > 0: + existing_numbers = [int(file.split("_")[1].split(".")[0]) for file in existing_files] + next_number = max(existing_numbers) + 1 + else: + next_number = 0 + + outer_loop_count = i + for i, frame in enumerate(frames): + frame_path = os.path.join(args.save_folder, f"frame_{next_number:05d}.png") + # last pair, save all frames including the last one + if len(image_paths) - 2 == outer_loop_count: + cv2.imwrite(frame_path, frame) + else: # not last pair, don't save the last frame + if not i == len(frames) - 1: + cv2.imwrite(frame_path, frame) + next_number += 1 + + print(f"Interpolation \033[0;32mdone\033[0m in {time.time()-start_time:.2f} seconds!") \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/film_interpolation/film_util.py b/extensions/deforum/scripts/deforum_helpers/src/film_interpolation/film_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e510758e53ced0af433fc14f63bf9b504e256544 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/film_interpolation/film_util.py @@ -0,0 +1,161 @@ +"""Various utilities used in the film_net frame interpolator model.""" +from typing import List, Optional + +import cv2 +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +def pad_batch(batch, align): + height, width = batch.shape[1:3] + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width 
% align != 0 else 0 + + crop_region = [height_to_pad >> 1, width_to_pad >> 1, height + (height_to_pad >> 1), width + (width_to_pad >> 1)] + batch = np.pad(batch, ((0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0)), mode='constant') + return batch, crop_region + + +def load_image(path, align=64): + image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255) + image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align) + return image_batch, crop_region + + +def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]: + """Builds an image pyramid from a given image. + + The original image is included in the pyramid and the rest are generated by + successively halving the resolution. + + Args: + image: the input image. + options: film_net options object + + Returns: + A list of images starting from the finest with options.pyramid_levels items + """ + + pyramid = [] + for i in range(pyramid_levels): + pyramid.append(image) + if i < pyramid_levels - 1: + image = F.avg_pool2d(image, 2, 2) + return pyramid + + +def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: + """Backward warps the image using the given flow. + + Specifically, the output pixel in batch b, at position x, y will be computed + as follows: + (flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0]) + output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x) + + Note that the flow vectors are expected as [x, y], e.g. x in position 0 and + y in position 1. + + Args: + image: An image with shape BxHxWxC. + flow: A flow with shape BxHxWx2, with the two channels denoting the relative + offset in order: (dx, dy). + Returns: + A warped image. + """ + flow = -flow.flip(1) + + dtype = flow.dtype + device = flow.device + + # warped = tfa_image.dense_image_warp(image, flow) + # Same as above but with pytorch + ls1 = 1 - 1 / flow.shape[3] + ls2 = 1 - 1 / flow.shape[2] + + normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor( + [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None] + normalized_flow2 = torch.stack([ + torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1], + torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0], + ], dim=3) + + warped = F.grid_sample(image, normalized_flow2, + mode='bilinear', padding_mode='border', align_corners=False) + return warped.reshape(image.shape) + + +def multiply_pyramid(pyramid: List[torch.Tensor], + scalar: torch.Tensor) -> List[torch.Tensor]: + """Multiplies all image batches in the pyramid by a batch of scalars. + + Args: + pyramid: Pyramid of image batches. + scalar: Batch of scalars. + + Returns: + An image pyramid with all images multiplied by the scalar. + """ + # To multiply each image with its corresponding scalar, we first transpose + # the batch of images from BxHxWxC-format to CxHxWxB. This can then be + # multiplied with a batch of scalars, then we transpose back to the standard + # BxHxWxC form. 
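+    # (In this PyTorch port the transpose step is skipped: the multiplication below
+    # relies on plain broadcasting, which assumes `scalar` is already shaped so it
+    # broadcasts against each image batch, e.g. a [B, 1, 1, 1] tensor.)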
+ return [image * scalar for image in pyramid] + + +def flow_pyramid_synthesis( + residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: + """Converts a residual flow pyramid into a flow pyramid.""" + flow = residual_pyramid[-1] + flow_pyramid: List[torch.Tensor] = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + level_size = residual_flow.shape[2:4] + flow = F.interpolate(2 * flow, size=level_size, mode='bilinear') + flow = residual_flow + flow + flow_pyramid.insert(0, flow) + return flow_pyramid + + +def pyramid_warp(feature_pyramid: List[torch.Tensor], + flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: + """Warps the feature pyramid using the flow pyramid. + + Args: + feature_pyramid: feature pyramid starting from the finest level. + flow_pyramid: flow fields, starting from the finest level. + + Returns: + Reverse warped feature pyramid. + """ + warped_feature_pyramid = [] + for features, flow in zip(feature_pyramid, flow_pyramid): + warped_feature_pyramid.append(warp(features, flow)) + return warped_feature_pyramid + + +def concatenate_pyramids(pyramid1: List[torch.Tensor], + pyramid2: List[torch.Tensor]) -> List[torch.Tensor]: + """Concatenates each pyramid level together in the channel dimension.""" + result = [] + for features1, features2 in zip(pyramid1, pyramid2): + result.append(torch.cat([features1, features2], dim=1)) + return result + + +def conv(in_channels, out_channels, size, activation: Optional[str] = 'relu'): + # Since PyTorch doesn't have an in-built activation in Conv2d, we use a + # Sequential layer to combine Conv2d and Leaky ReLU in one module. + _conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=size, + padding='same') + if activation is None: + return _conv + assert activation == 'relu' + return nn.Sequential( + _conv, + nn.LeakyReLU(.2) + ) diff --git a/extensions/deforum/scripts/deforum_helpers/src/infer.py b/extensions/deforum/scripts/deforum_helpers/src/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..4474e972a7c2ad25e1078b6549805dc26164fdbb --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/infer.py @@ -0,0 +1,165 @@ +import glob +import os + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from torchvision import transforms +from tqdm import tqdm + +import model_io +import utils +from adabins import UnetAdaptiveBins + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +class ToTensor(object): + def __init__(self): + self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def __call__(self, image, target_size=(640, 480)): + # image = image.resize(target_size) + image = self.to_tensor(image) + image = self.normalize(image) + return image + + def to_tensor(self, pic): + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError( + 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class InferenceHelper: + def __init__(self, models_path, dataset='nyu', device='cuda:0'): + self.toTensor = ToTensor() + self.device = device + if dataset == 'nyu': + self.min_depth = 1e-3 + self.max_depth = 10 + self.saving_factor = 1000 # used to save in 16 bit + model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth) + pretrained_path = os.path.join(models_path, "AdaBins_nyu.pt") + elif dataset == 'kitti': + self.min_depth = 1e-3 + self.max_depth = 80 + self.saving_factor = 256 + model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth) + pretrained_path = os.path.join(models_path, "AdaBins_kitti.pt") + else: + raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset)) + + model, _, _ = model_io.load_checkpoint(pretrained_path, model) + model.eval() + self.model = model.to(self.device) + + @torch.no_grad() + def predict_pil(self, pil_image, visualized=False): + # pil_image = pil_image.resize((640, 480)) + img = np.asarray(pil_image) / 255. + + img = self.toTensor(img).unsqueeze(0).float().to(self.device) + bin_centers, pred = self.predict(img) + + if visualized: + viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma') + # pred = np.asarray(pred*1000, dtype='uint16') + viz = Image.fromarray(viz) + return bin_centers, pred, viz + return bin_centers, pred + + @torch.no_grad() + def predict(self, image): + bins, pred = self.model(image) + pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth) + + # Flip + image = torch.Tensor(np.array(image.cpu().numpy())[..., ::-1].copy()).to(self.device) + pred_lr = self.model(image)[-1] + pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth) + + # Take average of original and mirror + final = 0.5 * (pred + pred_lr) + final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:], + mode='bilinear', align_corners=True).cpu().numpy() + + final[final < self.min_depth] = self.min_depth + final[final > self.max_depth] = self.max_depth + final[np.isinf(final)] = self.max_depth + final[np.isnan(final)] = self.min_depth + + centers = 0.5 * (bins[:, 1:] + bins[:, :-1]) + centers = centers.cpu().squeeze().numpy() + centers = centers[centers > self.min_depth] + centers = centers[centers < self.max_depth] + + return centers, final + + @torch.no_grad() + def predict_dir(self, test_dir, out_dir): + os.makedirs(out_dir, exist_ok=True) + transform = ToTensor() + all_files = glob.glob(os.path.join(test_dir, "*")) + self.model.eval() + for f in tqdm(all_files): + image = np.asarray(Image.open(f), dtype='float32') / 255. 
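+            # pixel values are scaled to [0, 1] above; the ToTensor transform below
+            # additionally applies the ImageNet mean/std normalization defined in
+            # ToTensor.__init__ before the image is sent to the depth model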
+ image = transform(image).unsqueeze(0).to(self.device) + + centers, final = self.predict(image) + # final = final.squeeze().cpu().numpy() + + final = (final * self.saving_factor).astype('uint16') + basename = os.path.basename(f).split('.')[0] + save_path = os.path.join(out_dir, basename + ".png") + + Image.fromarray(final.squeeze()).save(save_path) + + def to(self, device): + self.device = device + self.model.to(device) + + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from time import time + + img = Image.open("test_imgs/classroom__rgb_00283.jpg") + start = time() + inferHelper = InferenceHelper() + centers, pred = inferHelper.predict_pil(img) + print(f"took :{time() - start}s") + plt.imshow(pred.squeeze(), cmap='magma_r') + plt.show() diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/base_model.py b/extensions/deforum/scripts/deforum_helpers/src/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/blocks.py b/extensions/deforum/scripts/deforum_helpers/src/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, 
groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/dpt_depth.py b/extensions/deforum/scripts/deforum_helpers/src/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/midas_net.py 
b/extensions/deforum/scripts/deforum_helpers/src/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/midas_net_custom.py b/extensions/deforum/scripts/deforum_helpers/src/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
+This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/transforms.py b/extensions/deforum/scripts/deforum_helpers/src/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). 
+ """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + 
sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/extensions/deforum/scripts/deforum_helpers/src/midas/vit.py b/extensions/deforum/scripts/deforum_helpers/src/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = 
pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + 
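# Descriptive note on the surrounding code: each forward hook registered here stores its block's output in the shared `activations` dict (keys "1"-"4"), which forward_vit() reads back to assemble the multi-scale features. +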
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + 
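# In the hybrid ViT/ResNet path (use_vit_only == False) the first two activations come from the ResNet stages and are already spatial feature maps, so their post-processing is left as identity. +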
pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/extensions/deforum/scripts/deforum_helpers/src/model_io.py b/extensions/deforum/scripts/deforum_helpers/src/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..3427be8176f178c4c3ef09664a3f28d9fbaab4c3 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/model_io.py @@ -0,0 +1,74 @@ +import os + +import torch + + +def save_weights(model, filename, path="./saved_models"): + if not os.path.isdir(path): + os.makedirs(path) + + fpath = os.path.join(path, filename) + torch.save(model.state_dict(), fpath) + return + + +def save_checkpoint(model, optimizer, epoch, filename, root="./checkpoints"): + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch + } + , fpath) + + +def load_weights(model, filename, path="./saved_models"): + fpath = os.path.join(path, filename) + state_dict = torch.load(fpath) + model.load_state_dict(state_dict) + return model + + +def load_checkpoint(fpath, model, optimizer=None): + ckpt = torch.load(fpath, map_location='cpu') + if ckpt is None: + raise Exception(f"\nERROR Loading AdaBins_nyu.pt. 
Read this for a fix:\nhttps://github.com/deforum-art/deforum-for-automatic1111-webui/wiki/FAQ-&-Troubleshooting#3d-animation-mode-is-not-working-only-2d-works") + if optimizer is None: + optimizer = ckpt.get('optimizer', None) + else: + optimizer.load_state_dict(ckpt['optimizer']) + epoch = ckpt['epoch'] + + if 'model' in ckpt: + ckpt = ckpt['model'] + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + modified = {} # backward compatibility to older naming of architecture blocks + for k, v in load_dict.items(): + if k.startswith('adaptive_bins_layer.embedding_conv.'): + k_ = k.replace('adaptive_bins_layer.embedding_conv.', + 'adaptive_bins_layer.conv3x3.') + modified[k_] = v + # del load_dict[k] + + elif k.startswith('adaptive_bins_layer.patch_transformer.embedding_encoder'): + + k_ = k.replace('adaptive_bins_layer.patch_transformer.embedding_encoder', + 'adaptive_bins_layer.patch_transformer.embedding_convPxP') + modified[k_] = v + # del load_dict[k] + else: + modified[k] = v # else keep the original + + model.load_state_dict(modified) + return model, optimizer, epoch diff --git a/extensions/deforum/scripts/deforum_helpers/src/py3d_tools.py b/extensions/deforum/scripts/deforum_helpers/src/py3d_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb958607c4fd405a06bb67e33963e744fd2306f --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/py3d_tools.py @@ -0,0 +1,1801 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import math +import warnings +from typing import List, Optional, Sequence, Tuple, Union, Any + +import numpy as np +import torch +import torch.nn.functional as F + +import copy +import inspect +import torch.nn as nn + +Device = Union[str, torch.device] + +# Default values for rotation and translation matrices. +_R = torch.eye(3)[None] # (1, 3, 3) +_T = torch.zeros(1, 3) # (1, 3) + + +# Provide get_origin and get_args even in Python 3.7. + +if sys.version_info >= (3, 8, 0): + from typing import get_args, get_origin +elif sys.version_info >= (3, 7, 0): + + def get_origin(cls): # pragma: no cover + return getattr(cls, "__origin__", None) + + def get_args(cls): # pragma: no cover + return getattr(cls, "__args__", None) + + +else: + raise ImportError("This module requires Python 3.7+") + +################################################################ +## ██████╗██╗ █████╗ ███████╗███████╗███████╗███████╗ ## +## ██╔════╝██║ ██╔══██╗██╔════╝██╔════╝██╔════╝██╔════╝ ## +## ██║ ██║ ███████║███████╗███████╗█████╗ ███████╗ ## +## ██║ ██║ ██╔══██║╚════██║╚════██║██╔══╝ ╚════██║ ## +## ╚██████╗███████╗██║ ██║███████║███████║███████╗███████║ ## +## ╚═════╝╚══════╝╚═╝ ╚═╝╚══════╝╚══════╝╚══════╝╚══════╝ ## +################################################################ + +class Transform3d: + """ + A Transform3d object encapsulates a batch of N 3D transformations, and knows + how to transform points and normal vectors. Suppose that t is a Transform3d; + then we can do the following: + + .. 
code-block:: python + + N = len(t) + points = torch.randn(N, P, 3) + normals = torch.randn(N, P, 3) + points_transformed = t.transform_points(points) # => (N, P, 3) + normals_transformed = t.transform_normals(normals) # => (N, P, 3) + + + BROADCASTING + Transform3d objects supports broadcasting. Suppose that t1 and tN are + Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we + can broadcast transforms like this: + + .. code-block:: python + + t1.transform_points(torch.randn(P, 3)) # => (P, 3) + t1.transform_points(torch.randn(1, P, 3)) # => (1, P, 3) + t1.transform_points(torch.randn(M, P, 3)) # => (M, P, 3) + tN.transform_points(torch.randn(P, 3)) # => (N, P, 3) + tN.transform_points(torch.randn(1, P, 3)) # => (N, P, 3) + + + COMBINING TRANSFORMS + Transform3d objects can be combined in two ways: composing and stacking. + Composing is function composition. Given Transform3d objects t1, t2, t3, + the following all compute the same thing: + + .. code-block:: python + + y1 = t3.transform_points(t2.transform_points(t1.transform_points(x))) + y2 = t1.compose(t2).compose(t3).transform_points(x) + y3 = t1.compose(t2, t3).transform_points(x) + + + Composing transforms should broadcast. + + .. code-block:: python + + if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N. + + We can also stack a sequence of Transform3d objects, which represents + composition along the batch dimension; then the following should compute the + same thing. + + .. code-block:: python + + N, M = len(tN), len(tM) + xN = torch.randn(N, P, 3) + xM = torch.randn(M, P, 3) + y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0) + y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0)) + + BUILDING TRANSFORMS + We provide convenience methods for easily building Transform3d objects + as compositions of basic transforms. + + .. code-block:: python + + # Scale by 0.5, then translate by (1, 2, 3) + t1 = Transform3d().scale(0.5).translate(1, 2, 3) + + # Scale each axis by a different amount, then translate, then scale + t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0) + + t3 = t1.compose(t2) + tN = t1.stack(t3, t3) + + + BACKPROP THROUGH TRANSFORMS + When building transforms, we can also parameterize them by Torch tensors; + in this case we can backprop through the construction and application of + Transform objects, so they could be learned via gradient descent or + predicted by a neural network. + + .. code-block:: python + + s1_params = torch.randn(N, requires_grad=True) + t_params = torch.randn(N, 3, requires_grad=True) + s2_params = torch.randn(N, 3, requires_grad=True) + + t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params) + x = torch.randn(N, 3) + y = t.transform_points(x) + loss = compute_loss(y) + loss.backward() + + with torch.no_grad(): + s1_params -= lr * s1_params.grad + t_params -= lr * t_params.grad + s2_params -= lr * s2_params.grad + + CONVENTIONS + We adopt a right-hand coordinate system, meaning that rotation about an axis + with a positive angle results in a counter clockwise rotation. + + This class assumes that transformations are applied on inputs which + are row vectors. The internal representation of the Nx4x4 transformation + matrix is of the form: + + .. code-block:: python + + M = [ + [Rxx, Ryx, Rzx, 0], + [Rxy, Ryy, Rzy, 0], + [Rxz, Ryz, Rzz, 0], + [Tx, Ty, Tz, 1], + ] + + To apply the transformation to points which are row vectors, the M matrix + can be pre multiplied by the points: + + .. 
code-block:: python + + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * M + + """ + + def __init__( + self, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", + matrix: Optional[torch.Tensor] = None, + ) -> None: + """ + Args: + dtype: The data type of the transformation matrix. + to be used if `matrix = None`. + device: The device for storing the implemented transformation. + If `matrix != None`, uses the device of input `matrix`. + matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4) + representing the 4x4 3D transformation matrix. + If `None`, initializes with identity using + the specified `device` and `dtype`. + """ + + if matrix is None: + self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) + else: + if matrix.ndim not in (2, 3): + raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.') + if matrix.shape[-2] != 4 or matrix.shape[-1] != 4: + raise ValueError( + '"matrix" has to be a tensor of shape (minibatch, 4, 4)' + ) + # set dtype and device from matrix + dtype = matrix.dtype + device = matrix.device + self._matrix = matrix.view(-1, 4, 4) + + self._transforms = [] # store transforms to compose + self._lu = None + self.device = make_device(device) + self.dtype = dtype + + def __len__(self) -> int: + return self.get_matrix().shape[0] + + def __getitem__( + self, index: Union[int, List[int], slice, torch.Tensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(matrix=self.get_matrix()[index]) + + def compose(self, *others: "Transform3d") -> "Transform3d": + """ + Return a new Transform3d representing the composition of self with the + given other transforms, which will be stored as an internal list. + + Args: + *others: Any number of Transform3d objects + + Returns: + A new Transform3d with the stored transforms + """ + out = Transform3d(dtype=self.dtype, device=self.device) + out._matrix = self._matrix.clone() + for other in others: + if not isinstance(other, Transform3d): + msg = "Only possible to compose Transform3d objects; got %s" + raise ValueError(msg % type(other)) + out._transforms = self._transforms + list(others) + return out + + def get_matrix(self) -> torch.Tensor: + """ + Return a matrix which is the result of composing this transform + with others stored in self.transforms. Where necessary transforms + are broadcast against each other. + For example, if self.transforms contains transforms t1, t2, and t3, and + given a set of points x, the following should be true: + + .. code-block:: python + + y1 = t1.compose(t2, t3).transform(x) + y2 = t3.transform(t2.transform(t1.transform(x))) + y1.get_matrix() == y2.get_matrix() + + Returns: + A transformation matrix representing the composed inputs. + """ + composed_matrix = self._matrix.clone() + if len(self._transforms) > 0: + for other in self._transforms: + other_matrix = other.get_matrix() + composed_matrix = _broadcast_bmm(composed_matrix, other_matrix) + return composed_matrix + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. 
+ """ + return torch.inverse(self._matrix) + + def inverse(self, invert_composed: bool = False) -> "Transform3d": + """ + Returns a new Transform3d object that represents an inverse of the + current transformation. + + Args: + invert_composed: + - True: First compose the list of stored transformations + and then apply inverse to the result. This is + potentially slower for classes of transformations + with inverses that can be computed efficiently + (e.g. rotations and translations). + - False: Invert the individual stored transformations + independently without composing them. + + Returns: + A new Transform3d object containing the inverse of the original + transformation. + """ + + tinv = Transform3d(dtype=self.dtype, device=self.device) + + if invert_composed: + # first compose then invert + tinv._matrix = torch.inverse(self.get_matrix()) + else: + # self._get_matrix_inverse() implements efficient inverse + # of self._matrix + i_matrix = self._get_matrix_inverse() + + # 2 cases: + if len(self._transforms) > 0: + # a) Either we have a non-empty list of transforms: + # Here we take self._matrix and append its inverse at the + # end of the reverted _transforms list. After composing + # the transformations with get_matrix(), this correctly + # right-multiplies by the inverse of self._matrix + # at the end of the composition. + tinv._transforms = [t.inverse() for t in reversed(self._transforms)] + last = Transform3d(dtype=self.dtype, device=self.device) + last._matrix = i_matrix + tinv._transforms.append(last) + else: + # b) Or there are no stored transformations + # we just set inverted matrix + tinv._matrix = i_matrix + + return tinv + + def stack(self, *others: "Transform3d") -> "Transform3d": + """ + Return a new batched Transform3d representing the batch elements from + self and all the given other transforms all batched together. + + Args: + *others: Any number of Transform3d objects + + Returns: + A new Transform3d. + """ + transforms = [self] + list(others) + matrix = torch.cat([t.get_matrix() for t in transforms], dim=0) + out = Transform3d(dtype=self.dtype, device=self.device) + out._matrix = matrix + return out + + def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor: + """ + Use this transform to transform a set of 3D points. Assumes row major + ordering of the input points. + + Args: + points: Tensor of shape (P, 3) or (N, P, 3) + eps: If eps!=None, the argument is used to clamp the + last coordinate before performing the final division. + The clamping corresponds to: + last_coord := (last_coord.sign() + (last_coord==0)) * + torch.clamp(last_coord.abs(), eps), + i.e. the last coordinates that are exactly 0 will + be clamped to +eps. 
+ + Returns: + points_out: points of shape (N, P, 3) or (P, 3) depending + on the dimensions of the transform + """ + points_batch = points.clone() + if points_batch.dim() == 2: + points_batch = points_batch[None] # (P, 3) -> (1, P, 3) + if points_batch.dim() != 3: + msg = "Expected points to have dim = 2 or dim = 3: got shape %r" + raise ValueError(msg % repr(points.shape)) + + N, P, _3 = points_batch.shape + ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device) + points_batch = torch.cat([points_batch, ones], dim=2) + + composed_matrix = self.get_matrix() + points_out = _broadcast_bmm(points_batch, composed_matrix) + denom = points_out[..., 3:] # denominator + if eps is not None: + denom_sign = denom.sign() + (denom == 0.0).type_as(denom) + denom = denom_sign * torch.clamp(denom.abs(), eps) + points_out = points_out[..., :3] / denom + + # When transform is (1, 4, 4) and points is (P, 3) return + # points_out of shape (P, 3) + if points_out.shape[0] == 1 and points.dim() == 2: + points_out = points_out.reshape(points.shape) + + return points_out + + def transform_normals(self, normals) -> torch.Tensor: + """ + Use this transform to transform a set of normal vectors. + + Args: + normals: Tensor of shape (P, 3) or (N, P, 3) + + Returns: + normals_out: Tensor of shape (P, 3) or (N, P, 3) depending + on the dimensions of the transform + """ + if normals.dim() not in [2, 3]: + msg = "Expected normals to have dim = 2 or dim = 3: got shape %r" + raise ValueError(msg % (normals.shape,)) + composed_matrix = self.get_matrix() + + # TODO: inverse is bad! Solve a linear system instead + mat = composed_matrix[:, :3, :3] + normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse()) + + # This doesn't pass unit tests. TODO investigate further + # if self._lu is None: + # self._lu = self._matrix[:, :3, :3].transpose(1, 2).lu() + # normals_out = normals.lu_solve(*self._lu) + + # When transform is (1, 4, 4) and normals is (P, 3) return + # normals_out of shape (P, 3) + if normals_out.shape[0] == 1 and normals.dim() == 2: + normals_out = normals_out.reshape(normals.shape) + + return normals_out + + def translate(self, *args, **kwargs) -> "Transform3d": + return self.compose( + Translate(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def scale(self, *args, **kwargs) -> "Transform3d": + return self.compose( + Scale(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def rotate(self, *args, **kwargs) -> "Transform3d": + return self.compose( + Rotate(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d": + return self.compose( + RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def clone(self) -> "Transform3d": + """ + Deep copy of Transforms object. All internal tensors are cloned + individually. + + Returns: + new Transforms object. + """ + other = Transform3d(dtype=self.dtype, device=self.device) + if self._lu is not None: + other._lu = [elem.clone() for elem in self._lu] + other._matrix = self._matrix.clone() + other._transforms = [t.clone() for t in self._transforms] + return other + + def to( + self, + device: Device, + copy: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> "Transform3d": + """ + Match functionality of torch.Tensor.to() + If copy = True or the self Tensor is on a different device, the + returned tensor is a copy of self with the desired torch.device. 
+ If copy = False and the self Tensor already has the correct torch.device, + then self is returned. + + Args: + device: Device (as str or torch.device) for the new tensor. + copy: Boolean indicator whether or not to clone self. Default False. + dtype: If not None, casts the internal tensor variables + to a given torch.dtype. + + Returns: + Transform3d object. + """ + device_ = make_device(device) + dtype_ = self.dtype if dtype is None else dtype + skip_to = self.device == device_ and self.dtype == dtype_ + + if not copy and skip_to: + return self + + other = self.clone() + + if skip_to: + return other + + other.device = device_ + other.dtype = dtype_ + other._matrix = other._matrix.to(device=device_, dtype=dtype_) + other._transforms = [ + t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms + ] + return other + + def cpu(self) -> "Transform3d": + return self.to("cpu") + + def cuda(self) -> "Transform3d": + return self.to("cuda") + +class Translate(Transform3d): + def __init__( + self, + x, + y=None, + z=None, + dtype: torch.dtype = torch.float32, + device: Optional[Device] = None, + ) -> None: + """ + Create a new Transform3d representing 3D translations. + + Option I: Translate(xyz, dtype=torch.float32, device='cpu') + xyz should be a tensor of shape (N, 3) + + Option II: Translate(x, y, z, dtype=torch.float32, device='cpu') + Here x, y, and z will be broadcast against each other and + concatenated to form the translation. Each can be: + - A python scalar + - A torch scalar + - A 1D torch tensor + """ + xyz = _handle_input(x, y, z, dtype, device, "Translate") + super().__init__(device=xyz.device, dtype=dtype) + N = xyz.shape[0] + + mat = torch.eye(4, dtype=dtype, device=self.device) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, 3, :3] = xyz + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + inv_mask = self._matrix.new_ones([1, 4, 4]) + inv_mask[0, 3, :3] = -1.0 + i_matrix = self._matrix * inv_mask + return i_matrix + +class Rotate(Transform3d): + def __init__( + self, + R: torch.Tensor, + dtype: torch.dtype = torch.float32, + device: Optional[Device] = None, + orthogonal_tol: float = 1e-5, + ) -> None: + """ + Create a new Transform3d representing 3D rotation using a rotation + matrix as the input. + + Args: + R: a tensor of shape (3, 3) or (N, 3, 3) + orthogonal_tol: tolerance for the test of the orthogonality of R + + """ + device_ = get_device(R, device) + super().__init__(device=device_, dtype=dtype) + if R.dim() == 2: + R = R[None] + if R.shape[-2:] != (3, 3): + msg = "R must have shape (3, 3) or (N, 3, 3); got %s" + raise ValueError(msg % repr(R.shape)) + R = R.to(device=device_, dtype=dtype) + _check_valid_rotation_matrix(R, tol=orthogonal_tol) + N = R.shape[0] + mat = torch.eye(4, dtype=dtype, device=device_) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, :3, :3] = R + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + return self._matrix.permute(0, 2, 1).contiguous() + +class TensorAccessor(nn.Module): + """ + A helper class to be used with the __getitem__ method. This can be used for + getting/setting the values for an attribute of a class at one particular + index. This is useful when the attributes of a class are batched tensors + and one element in the batch needs to be modified. 
+ """ + + def __init__(self, class_object, index: Union[int, slice]) -> None: + """ + Args: + class_object: this should be an instance of a class which has + attributes which are tensors representing a batch of + values. + index: int/slice, an index indicating the position in the batch. + In __setattr__ and __getattr__ only the value of class + attributes at this index will be accessed. + """ + self.__dict__["class_object"] = class_object + self.__dict__["index"] = index + + def __setattr__(self, name: str, value: Any): + """ + Update the attribute given by `name` to the value given by `value` + at the index specified by `self.index`. + Args: + name: str, name of the attribute. + value: value to set the attribute to. + """ + v = getattr(self.class_object, name) + if not torch.is_tensor(v): + msg = "Can only set values on attributes which are tensors; got %r" + raise AttributeError(msg % type(v)) + + # Convert the attribute to a tensor if it is not a tensor. + if not torch.is_tensor(value): + value = torch.tensor( + value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad + ) + + # Check the shapes match the existing shape and the shape of the index. + if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]: + msg = "Expected value to have shape %r; got %r" + raise ValueError(msg % (v.shape, value.shape)) + if ( + v.dim() == 0 + and isinstance(self.index, slice) + and len(value) != len(self.index) + ): + msg = "Expected value to have len %r; got %r" + raise ValueError(msg % (len(self.index), len(value))) + self.class_object.__dict__[name][self.index] = value + + def __getattr__(self, name: str): + """ + Return the value of the attribute given by "name" on self.class_object + at the index specified in self.index. + Args: + name: string of the attribute name + """ + if hasattr(self.class_object, name): + return self.class_object.__dict__[name][self.index] + else: + msg = "Attribute %s not found on %r" + return AttributeError(msg % (name, self.class_object.__name__)) + +BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray) + +class TensorProperties(nn.Module): + """ + A mix-in class for storing tensors as properties with helper methods. + """ + + def __init__( + self, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", + **kwargs, + ) -> None: + """ + Args: + dtype: data type to set for the inputs + device: Device (as str or torch.device) + kwargs: any number of keyword arguments. Any arguments which are + of type (float/int/list/tuple/tensor/array) are broadcasted and + other keyword arguments are set as attributes. + """ + super().__init__() + self.device = make_device(device) + self._N = 0 + if kwargs is not None: + + # broadcast all inputs which are float/int/list/tuple/tensor/array + # set as attributes anything else e.g. strings, bools + args_to_broadcast = {} + for k, v in kwargs.items(): + if v is None or isinstance(v, (str, bool)): + setattr(self, k, v) + elif isinstance(v, BROADCAST_TYPES): + args_to_broadcast[k] = v + else: + msg = "Arg %s with type %r is not broadcastable" + warnings.warn(msg % (k, type(v))) + + names = args_to_broadcast.keys() + # convert from type dict.values to tuple + values = tuple(v for v in args_to_broadcast.values()) + + if len(values) > 0: + broadcasted_values = convert_to_tensors_and_broadcast( + *values, device=device + ) + + # Set broadcasted values as attributes on self. 
+ for i, n in enumerate(names): + setattr(self, n, broadcasted_values[i]) + if self._N == 0: + self._N = broadcasted_values[i].shape[0] + + def __len__(self) -> int: + return self._N + + def isempty(self) -> bool: + return self._N == 0 + + def __getitem__(self, index: Union[int, slice]) -> TensorAccessor: + """ + Args: + index: an int or slice used to index all the fields. + Returns: + if `index` is an index int/slice return a TensorAccessor class + with getattribute/setattribute methods which return/update the value + at the index in the original class. + """ + if isinstance(index, (int, slice)): + return TensorAccessor(class_object=self, index=index) + + msg = "Expected index of type int or slice; got %r" + raise ValueError(msg % type(index)) + + # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. + def to(self, device: Device = "cpu") -> "TensorProperties": + """ + In place operation to move class properties which are tensors to a + specified device. If self has a property "device", update this as well. + """ + device_ = make_device(device) + for k in dir(self): + v = getattr(self, k) + if k == "device": + setattr(self, k, device_) + if torch.is_tensor(v) and v.device != device_: + setattr(self, k, v.to(device_)) + return self + + def cpu(self) -> "TensorProperties": + return self.to("cpu") + + # pyre-fixme[14]: `cuda` overrides method defined in `Module` inconsistently. + def cuda(self, device: Optional[int] = None) -> "TensorProperties": + return self.to(f"cuda:{device}" if device is not None else "cuda") + + def clone(self, other) -> "TensorProperties": + """ + Update the tensor properties of other with the cloned properties of self. + """ + for k in dir(self): + v = getattr(self, k) + if inspect.ismethod(v) or k.startswith("__"): + continue + if torch.is_tensor(v): + v_clone = v.clone() + else: + v_clone = copy.deepcopy(v) + setattr(other, k, v_clone) + return other + + def gather_props(self, batch_idx) -> "TensorProperties": + """ + This is an in place operation to reformat all tensor class attributes + based on a set of given indices using torch.gather. This is useful when + attributes which are batched tensors e.g. shape (N, 3) need to be + multiplied with another tensor which has a different first dimension + e.g. packed vertices of shape (V, 3). + Example + .. code-block:: python + self.specular_color = (N, 3) tensor of specular colors for each mesh + A lighting calculation may use + .. code-block:: python + verts_packed = meshes.verts_packed() # (V, 3) + To multiply these two tensors the batch dimension needs to be the same. + To achieve this we can do + .. code-block:: python + batch_idx = meshes.verts_packed_to_mesh_idx() # (V) + This gives index of the mesh for each vertex in verts_packed. + .. code-block:: python + self.gather_props(batch_idx) + self.specular_color = (V, 3) tensor with the specular color for + each packed vertex. + torch.gather requires the index tensor to have the same shape as the + input tensor so this method takes care of the reshaping of the index + tensor to use with class attributes with arbitrary dimensions. + Args: + batch_idx: shape (B, ...) where `...` represents an arbitrary + number of dimensions + Returns: + self with all properties reshaped. e.g. a property with shape (N, 3) + is transformed to shape (B, 3). + """ + # Iterate through the attributes of the class which are tensors. 
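+        # For example (illustrative shapes): an attribute of shape (N, 3) indexed
+        # with batch_idx of shape (V,) has batch_idx viewed as (V, 1) and expanded
+        # to (V, 3), so torch.gather picks one full row per index.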
+ for k in dir(self): + v = getattr(self, k) + if torch.is_tensor(v): + if v.shape[0] > 1: + # There are different values for each batch element + # so gather these using the batch_idx. + # First clone the input batch_idx tensor before + # modifying it. + _batch_idx = batch_idx.clone() + idx_dims = _batch_idx.shape + tensor_dims = v.shape + if len(idx_dims) > len(tensor_dims): + msg = "batch_idx cannot have more dimensions than %s. " + msg += "got shape %r and %s has shape %r" + raise ValueError(msg % (k, idx_dims, k, tensor_dims)) + if idx_dims != tensor_dims: + # To use torch.gather the index tensor (_batch_idx) has + # to have the same shape as the input tensor. + new_dims = len(tensor_dims) - len(idx_dims) + new_shape = idx_dims + (1,) * new_dims + expand_dims = (-1,) + tensor_dims[1:] + _batch_idx = _batch_idx.view(*new_shape) + _batch_idx = _batch_idx.expand(*expand_dims) + + v = v.gather(0, _batch_idx) + setattr(self, k, v) + return self + +class CamerasBase(TensorProperties): + """ + `CamerasBase` implements a base class for all cameras. + For cameras, there are four different coordinate systems (or spaces) + - World coordinate system: This is the system the object lives - the world. + - Camera view coordinate system: This is the system that has its origin on the camera + and the and the Z-axis perpendicular to the image plane. + In PyTorch3D, we assume that +X points left, and +Y points up and + +Z points out from the image plane. + The transformation from world --> view happens after applying a rotation (R) + and translation (T) + - NDC coordinate system: This is the normalized coordinate system that confines + in a volume the rendered part of the object or scene. Also known as view volume. + For square images, given the PyTorch3D convention, (+1, +1, znear) + is the top left near corner, and (-1, -1, zfar) is the bottom right far + corner of the volume. + The transformation from view --> NDC happens after applying the camera + projection matrix (P) if defined in NDC space. + For non square images, we scale the points such that smallest side + has range [-1, 1] and the largest side has range [-u, u], with u > 1. + - Screen coordinate system: This is another representation of the view volume with + the XY coordinates defined in image space instead of a normalized space. + A better illustration of the coordinate systems can be found in + pytorch3d/docs/notes/cameras.md. + It defines methods that are common to all camera models: + - `get_camera_center` that returns the optical center of the camera in + world coordinates + - `get_world_to_view_transform` which returns a 3D transform from + world coordinates to the camera view coordinates (R, T) + - `get_full_projection_transform` which composes the projection + transform (P) with the world-to-view transform (R, T) + - `transform_points` which takes a set of input points in world coordinates and + projects to the space the camera is defined in (NDC or screen) + - `get_ndc_camera_transform` which defines the transform from screen/NDC to + PyTorch3D's NDC space + - `transform_points_ndc` which takes a set of points in world coordinates and + projects them to PyTorch3D's NDC space + - `transform_points_screen` which takes a set of points in world coordinates and + projects them to screen space + For each new camera, one should implement the `get_projection_transform` + routine that returns the mapping from camera view coordinates to camera + coordinates (NDC or screen). 
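+    As an illustrative sketch (R and T are assumed extrinsics of shape (N, 3, 3)
+    and (N, 3), and xyz is an (N, P, 3) tensor of world points):
+    .. code-block:: python
+        cameras = FoVPerspectiveCameras(R=R, T=T)
+        world_to_proj = cameras.get_full_projection_transform()  # Transform3d
+        pts_proj = cameras.transform_points(xyz, eps=1e-6)       # world -> NDC here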
+ Another useful function that is specific to each camera model is + `unproject_points` which sends points from camera coordinates (NDC or screen) + back to camera view or world coordinates depending on the `world_coordinates` + boolean argument of the function. + """ + + # Used in __getitem__ to index the relevant fields + # When creating a new camera, this should be set in the __init__ + _FIELDS: Tuple[str, ...] = () + + # Names of fields which are a constant property of the whole batch, rather + # than themselves a batch of data. + # When joining objects into a batch, they will have to agree. + _SHARED_FIELDS: Tuple[str, ...] = () + + def get_projection_transform(self): + """ + Calculate the projective transformation matrix. + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in `__init__`. + Return: + a `Transform3d` object which represents a batch of projection + matrices of shape (N, 3, 3) + """ + raise NotImplementedError() + + def unproject_points(self, xy_depth: torch.Tensor, **kwargs): + """ + Transform input points from camera coodinates (NDC or screen) + to the world / camera coordinates. + Each of the input points `xy_depth` of shape (..., 3) is + a concatenation of the x, y location and its depth. + For instance, for an input 2D tensor of shape `(num_points, 3)` + `xy_depth` takes the following form: + `xy_depth[i] = [x[i], y[i], depth[i]]`, + for a each point at an index `i`. + The following example demonstrates the relationship between + `transform_points` and `unproject_points`: + .. code-block:: python + cameras = # camera object derived from CamerasBase + xyz = # 3D points of shape (batch_size, num_points, 3) + # transform xyz to the camera view coordinates + xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz) + # extract the depth of each point as the 3rd coord of xyz_cam + depth = xyz_cam[:, :, 2:] + # project the points xyz to the camera + xy = cameras.transform_points(xyz)[:, :, :2] + # append depth to xy + xy_depth = torch.cat((xy, depth), dim=2) + # unproject to the world coordinates + xyz_unproj_world = cameras.unproject_points(xy_depth, world_coordinates=True) + print(torch.allclose(xyz, xyz_unproj_world)) # True + # unproject to the camera coordinates + xyz_unproj = cameras.unproject_points(xy_depth, world_coordinates=False) + print(torch.allclose(xyz_cam, xyz_unproj)) # True + Args: + xy_depth: torch tensor of shape (..., 3). + world_coordinates: If `True`, unprojects the points back to world + coordinates using the camera extrinsics `R` and `T`. + `False` ignores `R` and `T` and unprojects to + the camera view coordinates. + from_ndc: If `False` (default), assumes xy part of input is in + NDC space if self.in_ndc(), otherwise in screen space. If + `True`, assumes xy is in NDC space even if the camera + is defined in screen space. + Returns + new_points: unprojected points with the same shape as `xy_depth`. + """ + raise NotImplementedError() + + def get_camera_center(self, **kwargs) -> torch.Tensor: + """ + Return the 3D location of the camera optical center + in the world coordinates. + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + Setting T here will update the values set in init as this + value may be needed later on in the rendering pipeline e.g. for + lighting calculations. 
+ Returns: + C: a batch of 3D locations of shape (N, 3) denoting + the locations of the center of each camera in the batch. + """ + w2v_trans = self.get_world_to_view_transform(**kwargs) + P = w2v_trans.inverse().get_matrix() + # the camera center is the translation component (the first 3 elements + # of the last row) of the inverted world-to-view + # transform (4x4 RT matrix) + C = P[:, 3, :3] + return C + + def get_world_to_view_transform(self, **kwargs) -> Transform3d: + """ + Return the world-to-view transform. + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + Setting R and T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + Returns: + A Transform3d object which represents a batch of transforms + of shape (N, 3, 3) + """ + R: torch.Tensor = kwargs.get("R", self.R) + T: torch.Tensor = kwargs.get("T", self.T) + self.R = R # pyre-ignore[16] + self.T = T # pyre-ignore[16] + world_to_view_transform = get_world_to_view_transform(R=R, T=T) + return world_to_view_transform + + def get_full_projection_transform(self, **kwargs) -> Transform3d: + """ + Return the full world-to-camera transform composing the + world-to-view and view-to-camera transforms. + If camera is defined in NDC space, the projected points are in NDC space. + If camera is defined in screen space, the projected points are in screen space. + Args: + **kwargs: parameters for the projection transforms can be passed in + as keyword arguments to override the default values + set in __init__. + Setting R and T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + Returns: + a Transform3d object which represents a batch of transforms + of shape (N, 3, 3) + """ + self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16] + self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16] + world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) + view_to_proj_transform = self.get_projection_transform(**kwargs) + return world_to_view_transform.compose(view_to_proj_transform) + + def transform_points( + self, points, eps: Optional[float] = None, **kwargs + ) -> torch.Tensor: + """ + Transform input points from world to camera space with the + projection matrix defined by the camera. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the camera plane. + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + Returns + new_points: transformed points with the same shape as the input. + """ + world_to_proj_transform = self.get_full_projection_transform(**kwargs) + return world_to_proj_transform.transform_points(points, eps=eps) + + def get_ndc_camera_transform(self, **kwargs) -> Transform3d: + """ + Returns the transform from camera projection space (screen or NDC) to NDC space. 
+ For cameras that can be specified in screen space, this transform + allows points to be converted from screen to NDC space. + The default transform scales the points from [0, W]x[0, H] + to [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image. + This function should be modified per camera definitions if need be, + e.g. for Perspective/Orthographic cameras we provide a custom implementation. + This transform assumes PyTorch3D coordinate system conventions for + both the NDC space and the input points. + This transform interfaces with the PyTorch3D renderer which assumes + input points to the renderer to be in NDC space. + """ + if self.in_ndc(): + return Transform3d(device=self.device, dtype=torch.float32) + else: + # For custom cameras which can be defined in screen space, + # users might might have to implement the screen to NDC transform based + # on the definition of the camera parameters. + # See PerspectiveCameras/OrthographicCameras for an example. + # We don't flip xy because we assume that world points are in + # PyTorch3D coordinates, and thus conversion from screen to ndc + # is a mere scaling from image to [-1, 1] scale. + image_size = kwargs.get("image_size", self.get_image_size()) + return get_screen_to_ndc_transform( + self, with_xyflip=False, image_size=image_size + ) + + def transform_points_ndc( + self, points, eps: Optional[float] = None, **kwargs + ) -> torch.Tensor: + """ + Transforms points from PyTorch3D world/camera space to NDC space. + Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up. + Output points are in NDC space: +X left, +Y up, origin at image center. + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + Returns + new_points: transformed points with the same shape as the input. + """ + world_to_ndc_transform = self.get_full_projection_transform(**kwargs) + if not self.in_ndc(): + to_ndc_transform = self.get_ndc_camera_transform(**kwargs) + world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform) + + return world_to_ndc_transform.transform_points(points, eps=eps) + + def transform_points_screen( + self, points, eps: Optional[float] = None, **kwargs + ) -> torch.Tensor: + """ + Transforms points from PyTorch3D world/camera space to screen space. + Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up. + Output points are in screen space: +X right, +Y down, origin at top left corner. + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + Returns + new_points: transformed points with the same shape as the input. 
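+        A minimal sketch (the image size is illustrative):
+        .. code-block:: python
+            pts_screen = cameras.transform_points_screen(
+                xyz, image_size=((720, 1280),)
+            )  # +X right, +Y down, origin at the top left corner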
+ """ + points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs) + image_size = kwargs.get("image_size", self.get_image_size()) + return get_ndc_to_screen_transform( + self, with_xyflip=True, image_size=image_size + ).transform_points(points_ndc, eps=eps) + + def clone(self): + """ + Returns a copy of `self`. + """ + cam_type = type(self) + other = cam_type(device=self.device) + return super().clone(other) + + def is_perspective(self): + raise NotImplementedError() + + def in_ndc(self): + """ + Specifies whether the camera is defined in NDC space + or in screen (image) space + """ + raise NotImplementedError() + + def get_znear(self): + return self.znear if hasattr(self, "znear") else None + + def get_image_size(self): + """ + Returns the image size, if provided, expected in the form of (height, width) + The image size is used for conversion of projected points to screen coordinates. + """ + return self.image_size if hasattr(self, "image_size") else None + + def __getitem__( + self, index: Union[int, List[int], torch.LongTensor] + ) -> "CamerasBase": + """ + Override for the __getitem__ method in TensorProperties which needs to be + refactored. + Args: + index: an int/list/long tensor used to index all the fields in the cameras given by + self._FIELDS. + Returns: + if `index` is an index int/list/long tensor return an instance of the current + cameras class with only the values at the selected index. + """ + + kwargs = {} + + if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)): + msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r" + raise ValueError(msg % type(index)) + + if isinstance(index, int): + index = [index] + + if max(index) >= len(self): + raise ValueError(f"Index {max(index)} is out of bounds for select cameras") + + for field in self._FIELDS: + val = getattr(self, field, None) + if val is None: + continue + + # e.g. "in_ndc" is set as attribute "_in_ndc" on the class + # but provided as "in_ndc" on initialization + if field.startswith("_"): + field = field[1:] + + if isinstance(val, (str, bool)): + kwargs[field] = val + elif isinstance(val, torch.Tensor): + # In the init, all inputs will be converted to + # tensors before setting as attributes + kwargs[field] = val[index] + else: + raise ValueError(f"Field {field} type is not supported for indexing") + + kwargs["device"] = self.device + return self.__class__(**kwargs) + +class FoVPerspectiveCameras(CamerasBase): + """ + A class which stores a batch of parameters to generate a batch of + projection matrices by specifying the field of view. + The definition of the parameters follow the OpenGL perspective camera. + + The extrinsics of the camera (R and T matrices) can also be set in the + initializer or passed in to `get_full_projection_transform` to get + the full transformation from world -> ndc. + + The `transform_points` method calculates the full world -> ndc transform + and then applies it to the input points. + + The transforms can also be returned separately as Transform3d objects. + + * Setting the Aspect Ratio for Non Square Images * + + If the desired output image size is non square (i.e. a tuple of (H, W) where H != W) + the aspect ratio needs special consideration: There are two aspect ratios + to be aware of: + - the aspect ratio of each pixel + - the aspect ratio of the output image + The `aspect_ratio` setting in the FoVPerspectiveCameras sets the + pixel aspect ratio. 
When using this camera with the differentiable rasterizer + be aware that in the rasterizer we assume square pixels, but allow + variable image aspect ratio (i.e rectangle images). + + In most cases you will want to set the camera `aspect_ratio=1.0` + (i.e. square pixels) and only vary the output image dimensions in pixels + for rasterization. + """ + + # For __getitem__ + _FIELDS = ( + "K", + "znear", + "zfar", + "aspect_ratio", + "fov", + "R", + "T", + "degrees", + ) + + _SHARED_FIELDS = ("degrees",) + + def __init__( + self, + znear=1.0, + zfar=100.0, + aspect_ratio=1.0, + fov=60.0, + degrees: bool = True, + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + ) -> None: + """ + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + aspect_ratio: aspect ratio of the image pixels. + 1.0 indicates square pixels. + fov: field of view angle of the camera. + degrees: bool, set to True if fov is specified in degrees. + R: Rotation matrix of shape (N, 3, 3) + T: Translation matrix of shape (N, 3) + K: (optional) A calibration matrix of shape (N, 4, 4) + If provided, don't need znear, zfar, fov, aspect_ratio, degrees + device: Device (as str or torch.device) + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + super().__init__( + device=device, + znear=znear, + zfar=zfar, + aspect_ratio=aspect_ratio, + fov=fov, + R=R, + T=T, + K=K, + ) + + # No need to convert to tensor or broadcast. + self.degrees = degrees + + def compute_projection_matrix( + self, znear, zfar, fov, aspect_ratio, degrees: bool + ) -> torch.Tensor: + """ + Compute the calibration matrix K of shape (N, 4, 4) + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + fov: field of view angle of the camera. + aspect_ratio: aspect ratio of the image pixels. + 1.0 indicates square pixels. + degrees: bool, set to True if fov is specified in degrees. + + Returns: + torch.FloatTensor of the calibration matrix with shape (N, 4, 4) + """ + K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32) + ones = torch.ones((self._N), dtype=torch.float32, device=self.device) + if degrees: + fov = (np.pi / 180) * fov + + if not torch.is_tensor(fov): + fov = torch.tensor(fov, device=self.device) + tanHalfFov = torch.tan((fov / 2)) + max_y = tanHalfFov * znear + min_y = -max_y + max_x = max_y * aspect_ratio + min_x = -max_x + + # NOTE: In OpenGL the projection matrix changes the handedness of the + # coordinate frame. i.e the NDC space positive z direction is the + # camera space negative z direction. This is because the sign of the z + # in the projection matrix is set to -1.0. + # In pytorch3d we maintain a right handed coordinate system throughout + # so the so the z sign is 1.0. + z_sign = 1.0 + + K[:, 0, 0] = 2.0 * znear / (max_x - min_x) + K[:, 1, 1] = 2.0 * znear / (max_y - min_y) + K[:, 0, 2] = (max_x + min_x) / (max_x - min_x) + K[:, 1, 2] = (max_y + min_y) / (max_y - min_y) + K[:, 3, 2] = z_sign * ones + + # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point + # is at the near clipping plane and z = 1 when the point is at the far + # clipping plane. 
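+        # With f1 = zfar / (zfar - znear) and f2 = -(zfar * znear) / (zfar - znear),
+        # the projected depth is z_ndc = f1 + f2 / z, which evaluates to 0 at
+        # z = znear and to 1 at z = zfar.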
+ K[:, 2, 2] = z_sign * zfar / (zfar - znear) + K[:, 2, 3] = -(zfar * znear) / (zfar - znear) + + return K + + def get_projection_transform(self, **kwargs) -> Transform3d: + """ + Calculate the perspective projection matrix with a symmetric + viewing frustrum. Use column major order. + The viewing frustrum will be projected into ndc, s.t. + (max_x, max_y) -> (+1, +1) + (min_x, min_y) -> (-1, -1) + + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in `__init__`. + + Return: + a Transform3d object which represents a batch of projection + matrices of shape (N, 4, 4) + + .. code-block:: python + + h1 = (max_y + min_y)/(max_y - min_y) + w1 = (max_x + min_x)/(max_x - min_x) + tanhalffov = tan((fov/2)) + s1 = 1/tanhalffov + s2 = 1/(tanhalffov * (aspect_ratio)) + + # To map z to the range [0, 1] use: + f1 = far / (far - near) + f2 = -(far * near) / (far - near) + + # Projection matrix + K = [ + [s1, 0, w1, 0], + [0, s2, h1, 0], + [0, 0, f1, f2], + [0, 0, 1, 0], + ] + """ + K = kwargs.get("K", self.K) + if K is not None: + if K.shape != (self._N, 4, 4): + msg = "Expected K to have shape of (%r, 4, 4)" + raise ValueError(msg % (self._N)) + else: + K = self.compute_projection_matrix( + kwargs.get("znear", self.znear), + kwargs.get("zfar", self.zfar), + kwargs.get("fov", self.fov), + kwargs.get("aspect_ratio", self.aspect_ratio), + kwargs.get("degrees", self.degrees), + ) + + # Transpose the projection matrix as PyTorch3D transforms use row vectors. + transform = Transform3d( + matrix=K.transpose(1, 2).contiguous(), device=self.device + ) + return transform + + def unproject_points( + self, + xy_depth: torch.Tensor, + world_coordinates: bool = True, + scaled_depth_input: bool = False, + **kwargs, + ) -> torch.Tensor: + """>! + FoV cameras further allow for passing depth in world units + (`scaled_depth_input=False`) or in the [0, 1]-normalized units + (`scaled_depth_input=True`) + + Args: + scaled_depth_input: If `True`, assumes the input depth is in + the [0, 1]-normalized units. If `False` the input depth is in + the world units. 
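+        A minimal sketch (the xy_depth tensors are illustrative):
+        .. code-block:: python
+            # depth given in world units
+            xyz_world = cameras.unproject_points(xy_depth, scaled_depth_input=False)
+            # depth already mapped to the [0, 1] range by the projection
+            xyz_world = cameras.unproject_points(xy_depth_scaled, scaled_depth_input=True)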
+ """ + + # obtain the relevant transformation to ndc + if world_coordinates: + to_ndc_transform = self.get_full_projection_transform() + else: + to_ndc_transform = self.get_projection_transform() + + if scaled_depth_input: + # the input is scaled depth, so we don't have to do anything + xy_sdepth = xy_depth + else: + # parse out important values from the projection matrix + K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix() + # parse out f1, f2 from K_matrix + unsqueeze_shape = [1] * xy_depth.dim() + unsqueeze_shape[0] = K_matrix.shape[0] + f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape) + f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape) + # get the scaled depth + sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3] + # concatenate xy + scaled depth + xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1) + + # unproject with inverse of the projection + unprojection_transform = to_ndc_transform.inverse() + return unprojection_transform.transform_points(xy_sdepth) + + def is_perspective(self): + return True + + def in_ndc(self): + return True + +####################################################################################### +## ██████╗ ███████╗███████╗██╗███╗ ██╗██╗████████╗██╗ ██████╗ ███╗ ██╗███████╗ ## +## ██╔══██╗██╔════╝██╔════╝██║████╗ ██║██║╚══██╔══╝██║██╔═══██╗████╗ ██║██╔════╝ ## +## ██║ ██║█████╗ █████╗ ██║██╔██╗ ██║██║ ██║ ██║██║ ██║██╔██╗ ██║███████╗ ## +## ██║ ██║██╔══╝ ██╔══╝ ██║██║╚██╗██║██║ ██║ ██║██║ ██║██║╚██╗██║╚════██║ ## +## ██████╔╝███████╗██║ ██║██║ ╚████║██║ ██║ ██║╚██████╔╝██║ ╚████║███████║ ## +## ╚═════╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚══════╝ ## +####################################################################################### + +def make_device(device: Device) -> torch.device: + """ + Makes an actual torch.device object from the device specified as + either a string or torch.device object. If the device is `cuda` without + a specific index, the index of the current device is assigned. + Args: + device: Device (as str or torch.device) + Returns: + A matching torch.device object + """ + device = torch.device(device) if isinstance(device, str) else device + if device.type == "cuda" and device.index is None: # pyre-ignore[16] + # If cuda but with no index, then the current cuda device is indicated. + # In that case, we fix to that device + device = torch.device(f"cuda:{torch.cuda.current_device()}") + return device + +def get_device(x, device: Optional[Device] = None) -> torch.device: + """ + Gets the device of the specified variable x if it is a tensor, or + falls back to a default CPU device otherwise. Allows overriding by + providing an explicit device. + Args: + x: a torch.Tensor to get the device from or another type + device: Device (as str or torch.device) to fall back to + Returns: + A matching torch.device object + """ + + # User overrides device + if device is not None: + return make_device(device) + + # Set device based on input tensor + if torch.is_tensor(x): + return x.device + + # Default device is cpu + return torch.device("cpu") + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + +def _broadcast_bmm(a, b) -> torch.Tensor: + """ + Batch multiply two matrices and broadcast if necessary. + + Args: + a: torch tensor of shape (P, K) or (M, P, K) + b: torch tensor of shape (N, K, K) + + Returns: + a and b broadcast multiplied. The output batch dimension is max(N, M). + + To broadcast transforms across a batch dimension if M != N then + expect that either M = 1 or N = 1. The tensor with batch dimension 1 is + expanded to have shape N or M. + """ + if a.dim() == 2: + a = a[None] + if len(a) != len(b): + if not ((len(a) == 1) or (len(b) == 1)): + msg = "Expected batch dim for bmm to be equal or 1; got %r, %r" + raise ValueError(msg % (a.shape, b.shape)) + if len(a) == 1: + a = a.expand(len(b), -1, -1) + if len(b) == 1: + b = b.expand(len(a), -1, -1) + return a.bmm(b) + +def _safe_det_3x3(t: torch.Tensor): + """ + Fast determinant calculation for a batch of 3x3 matrices. + Note, result of this function might not be the same as `torch.det()`. + The differences might be in the last significant digit. + Args: + t: Tensor of shape (N, 3, 3). + Returns: + Tensor of shape (N) with determinants. + """ + + det = ( + t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1]) + - t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2]) + + t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1]) + ) + + return det + +def get_world_to_view_transform( + R: torch.Tensor = _R, T: torch.Tensor = _T +) -> Transform3d: + """ + This function returns a Transform3d representing the transformation + matrix to go from world space to view space by applying a rotation and + a translation. + PyTorch3D uses the same convention as Hartley & Zisserman. 
+ I.e., for camera extrinsic parameters R (rotation) and T (translation), + we map a 3D point `X_world` in world coordinates to + a point `X_cam` in camera coordinates with: + `X_cam = X_world R + T` + Args: + R: (N, 3, 3) matrix representing the rotation. + T: (N, 3) matrix representing the translation. + Returns: + a Transform3d object which represents the composed RT transformation. + """ + # TODO: also support the case where RT is specified as one matrix + # of shape (N, 4, 4). + + if T.shape[0] != R.shape[0]: + msg = "Expected R, T to have the same batch dimension; got %r, %r" + raise ValueError(msg % (R.shape[0], T.shape[0])) + if T.dim() != 2 or T.shape[1:] != (3,): + msg = "Expected T to have shape (N, 3); got %r" + raise ValueError(msg % repr(T.shape)) + if R.dim() != 3 or R.shape[1:] != (3, 3): + msg = "Expected R to have shape (N, 3, 3); got %r" + raise ValueError(msg % repr(R.shape)) + + # Create a Transform3d object + T_ = Translate(T, device=T.device) + R_ = Rotate(R, device=R.device) + return R_.compose(T_) + +def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None: + """ + Determine if R is a valid rotation matrix by checking it satisfies the + following conditions: + + ``RR^T = I and det(R) = 1`` + + Args: + R: an (N, 3, 3) matrix + + Returns: + None + + Emits a warning if R is an invalid rotation matrix. + """ + N = R.shape[0] + eye = torch.eye(3, dtype=R.dtype, device=R.device) + eye = eye.view(1, 3, 3).expand(N, -1, -1) + orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol) + det_R = _safe_det_3x3(R) + no_distortion = torch.allclose(det_R, torch.ones_like(det_R)) + if not (orthogonal and no_distortion): + msg = "R is not a valid rotation matrix" + warnings.warn(msg) + return + +def format_tensor( + input, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", +) -> torch.Tensor: + """ + Helper function for converting a scalar value to a tensor. + Args: + input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor + dtype: data type for the input + device: Device (as str or torch.device) on which the tensor should be placed. + Returns: + input_vec: torch tensor with optional added batch dimension. + """ + device_ = make_device(device) + if not torch.is_tensor(input): + input = torch.tensor(input, dtype=dtype, device=device_) + elif not input.device.type.startswith('mps'): + input = torch.tensor(input, dtype=torch.float32,device=device_) + + if input.dim() == 0: + input = input.view(1) + + if input.device == device_: + return input + + input = input.to(device=device) + return input + +def convert_to_tensors_and_broadcast( + *args, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", +): + """ + Helper function to handle parsing an arbitrary number of inputs (*args) + which all need to have the same batch dimension. + The output is a list of tensors. + Args: + *args: an arbitrary number of inputs + Each of the values in `args` can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are + an arbitrary number of dimensions which can vary for each + value in args. In this case each input is broadcast to a + tensor of shape (N, K_i) + dtype: data type to use when creating new tensors. + device: torch device on which the tensors should be placed. 
+ Output: + args: A list of tensors of shape (N, K_i) + """ + # Convert all inputs to tensors with a batch dimension + args_1d = [format_tensor(c, dtype, device) for c in args] + + # Find broadcast size + sizes = [c.shape[0] for c in args_1d] + N = max(sizes) + + args_Nd = [] + for c in args_1d: + if c.shape[0] != 1 and c.shape[0] != N: + msg = "Got non-broadcastable sizes %r" % sizes + raise ValueError(msg) + + # Expand broadcast dim and keep non broadcast dims the same size + expand_sizes = (N,) + (-1,) * len(c.shape[1:]) + args_Nd.append(c.expand(*expand_sizes)) + + return args_Nd + +def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + """ + Helper function for _handle_input. + + Args: + c: Python scalar, torch scalar, or 1D torch tensor + + Returns: + c_vec: 1D torch tensor + """ + if not torch.is_tensor(c): + c = torch.tensor(c, dtype=dtype, device=device) + if c.dim() == 0: + c = c.view(1) + if c.device != device or c.dtype != dtype: + c = c.to(device=device, dtype=dtype) + return c + +def _handle_input( + x, + y, + z, + dtype: torch.dtype, + device: Optional[Device], + name: str, + allow_singleton: bool = False, +) -> torch.Tensor: + """ + Helper function to handle parsing logic for building transforms. The output + is always a tensor of shape (N, 3), but there are several types of allowed + input. + + Case I: Single Matrix + In this case x is a tensor of shape (N, 3), and y and z are None. Here just + return x. + + Case II: Vectors and Scalars + In this case each of x, y, and z can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N, 1) or (1, 1) + In this case x, y and z are broadcast to tensors of shape (N, 1) + and concatenated to a tensor of shape (N, 3) + + Case III: Singleton (only if allow_singleton=True) + In this case y and z are None, and x can be one of the following: + - Python scalar + - Torch scalar + - Torch tensor of shape (N, 1) or (1, 1) + Here x will be duplicated 3 times, and we return a tensor of shape (N, 3) + + Returns: + xyz: Tensor of shape (N, 3) + """ + device_ = get_device(x, device) + # If x is actually a tensor of shape (N, 3) then just return it + if torch.is_tensor(x) and x.dim() == 2: + if x.shape[1] != 3: + msg = "Expected tensor of shape (N, 3); got %r (in %s)" + raise ValueError(msg % (x.shape, name)) + if y is not None or z is not None: + msg = "Expected y and z to be None (in %s)" % name + raise ValueError(msg) + return x.to(device=device_, dtype=dtype) + + if allow_singleton and y is None and z is None: + y = x + z = x + + # Convert all to 1D tensors + xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]] + + # Broadcast and concatenate + sizes = [c.shape[0] for c in xyz] + N = max(sizes) + for c in xyz: + if c.shape[0] != 1 and c.shape[0] != N: + msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name) + raise ValueError(msg) + xyz = [c.expand(N) for c in xyz] + xyz = torch.stack(xyz, dim=1) + return xyz diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/inference_video.py b/extensions/deforum/scripts/deforum_helpers/src/rife/inference_video.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2358610802582f9681de236ea29b4a37186685 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/inference_video.py @@ -0,0 +1,282 @@ +# thanks to https://github.com/n00mkrad for the inspiration and a bit of code. 
Also thanks for https://github.com/XmYx for the initial reorganization of this script +import os, sys +from types import SimpleNamespace +import cv2 +import torch +import argparse +import shutil +import numpy as np +from tqdm import tqdm +from torch.nn import functional as F +import warnings +import _thread +from queue import Queue, Empty +import subprocess +import time +from .model.pytorch_msssim import ssim_matlab + +sys.path.append('../../') +from deforum_helpers.video_audio_utilities import ffmpeg_stitch_video +from deforum_helpers.general_utils import duplicate_pngs_from_folder + +warnings.filterwarnings("ignore") + +def run_rife_new_video_infer( + output=None, + model=None, + fp16=False, + UHD=False, # *Will be received as *True* if imgs/vid resolution is 2K or higher* + scale=1.0, + fps=None, + deforum_models_path=None, + raw_output_imgs_path=None, + img_batch_id=None, + ffmpeg_location=None, + audio_track=None, + interp_x_amount=2, + slow_mo_enabled=False, + slow_mo_x_amount=2, + ffmpeg_crf=17, + ffmpeg_preset='veryslow', + keep_imgs=False, + orig_vid_name = None): + + args = SimpleNamespace() + args.output = output + args.modelDir = model + args.fp16 = fp16 + args.UHD = UHD + args.scale = scale + args.fps = fps + args.deforum_models_path = deforum_models_path + args.raw_output_imgs_path = raw_output_imgs_path + args.img_batch_id = img_batch_id + args.ffmpeg_location = ffmpeg_location + args.audio_track = audio_track + args.interp_x_amount = interp_x_amount + args.slow_mo_enabled = slow_mo_enabled + args.slow_mo_x_amount = slow_mo_x_amount + args.ffmpeg_crf = ffmpeg_crf + args.ffmpeg_preset = ffmpeg_preset + args.keep_imgs = keep_imgs + args.orig_vid_name = orig_vid_name + + if args.UHD and args.scale == 1.0: + args.scale = 0.5 + + start_time = time.time() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_grad_enabled(False) + if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # TODO: Can/ need to handle this? currently it's always False and give errors if True but faster speeds on tensortcore equipped gpus? + if (args.fp16): + torch.set_default_tensor_type(torch.cuda.HalfTensor) + if args.modelDir is not None: + try: + from .rife_new_gen.RIFE_HDv3 import Model + except ImportError as e: + raise ValueError(f"{args.modelDir} could not be found. Please contact deforum support {e}") + except Exception as e: + raise ValueError(f"An error occured while trying to import {args.modelDir}: {e}") + else: + print("Got a request to frame-interpolate but no valid frame interpolation engine value provided. Doing... 
nothing") + return + + model = Model() + if not hasattr(model, 'version'): + model.version = 0 + model.load_model(args.modelDir, -1, deforum_models_path) + model.eval() + model.device() + + print(f"{args.modelDir}.pkl model successfully loaded into memory") + print("Interpolation progress (it's OK if it finishes before 100%):") + + interpolated_path = os.path.join(args.raw_output_imgs_path, 'interpolated_frames_rife') + # set custom name depending on if we interpolate after a run, or interpolate a video (related/unrelated to deforum, we don't know) directly from within the RIFE tab + if args.orig_vid_name is not None: # interpolating a video (deforum or unrelated) + custom_interp_path = "{}_{}".format(interpolated_path, args.orig_vid_name) + else: # interpolating after a deforum run: + custom_interp_path = "{}_{}".format(interpolated_path, args.img_batch_id) + + # In this folder we temporarily keep the original frames (converted/ copy-pasted and img format depends on scenario) + # the convertion case is done to avert a problem with 24 and 32 mixed outputs from the same animation run + temp_convert_raw_png_path = os.path.join(args.raw_output_imgs_path, "tmp_rife_folder") + + duplicate_pngs_from_folder(args.raw_output_imgs_path, temp_convert_raw_png_path, args.img_batch_id, args.orig_vid_name) + + videogen = [] + for f in os.listdir(temp_convert_raw_png_path): + # double check for old _depth_ files, not really needed probably but keeping it for now + if '_depth_' not in f: + videogen.append(f) + tot_frame = len(videogen) + videogen.sort(key= lambda x:int(x.split('.')[0])) + img_path = os.path.join(temp_convert_raw_png_path, videogen[0]) + lastframe = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + videogen = videogen[1:] + h, w, _ = lastframe.shape + vid_out = None + + if not os.path.exists(custom_interp_path): + os.mkdir(custom_interp_path) + + tmp = max(128, int(128 / args.scale)) + ph = ((h - 1) // tmp + 1) * tmp + pw = ((w - 1) // tmp + 1) * tmp + padding = (0, pw - w, 0, ph - h) + pbar = tqdm(total=tot_frame) + + write_buffer = Queue(maxsize=500) + read_buffer = Queue(maxsize=500) + + _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen, temp_convert_raw_png_path)) + _thread.start_new_thread(clear_write_buffer, (args, write_buffer, custom_interp_path)) + + I1 = torch.from_numpy(np.transpose(lastframe, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. + I1 = pad_image(I1, args.fp16, padding) + temp = None # save lastframe when processing static frame + + while True: + if temp is not None: + frame = temp + temp = None + else: + frame = read_buffer.get() + if frame is None: + break + I0 = I1 + I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. + I1 = pad_image(I1, args.fp16, padding) + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + + break_flag = False + if ssim > 0.996: + frame = read_buffer.get() # read a new frame + if frame is None: + break_flag = True + frame = lastframe + else: + temp = frame + I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. 
+ I1 = pad_image(I1, args.fp16, padding) + I1 = model.inference(I0, I1, args.scale) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w] + + if ssim < 0.2: + output = [] + for i in range(args.interp_x_amount - 1): + output.append(I0) + else: + output = make_inference(model, I0, I1, args.interp_x_amount - 1, scale) + + write_buffer.put(lastframe) + for mid in output: + mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) + write_buffer.put(mid[:h, :w]) + pbar.update(1) + lastframe = frame + if break_flag: + break + + write_buffer.put(lastframe) + + while (not write_buffer.empty()): + time.sleep(0.1) + pbar.close() + shutil.rmtree(temp_convert_raw_png_path) + + print(f"Interpolation \033[0;32mdone\033[0m in {time.time()-start_time:.2f} seconds!") + # stitch video from interpolated frames, and add audio if needed + try: + print (f"*Passing interpolated frames to ffmpeg...*") + vid_out_path = stitch_video(args.img_batch_id, args.fps, custom_interp_path, args.audio_track, args.ffmpeg_location, args.interp_x_amount, args.slow_mo_enabled, args.slow_mo_x_amount, args.ffmpeg_crf, args.ffmpeg_preset, args.keep_imgs, args.orig_vid_name) + # remove folder with raw (non-interpolated) vid input frames in case of input VID and not PNGs + if orig_vid_name is not None: + shutil.rmtree(raw_output_imgs_path) + except Exception as e: + print(f'Video stitching gone wrong. *Interpolated frames were saved to HD as backup!*. Actual error: {e}') + +def clear_write_buffer(user_args, write_buffer, custom_interp_path): + cnt = 0 + + while True: + item = write_buffer.get() + if item is None: + break + filename = '{}/{:0>7d}.png'.format(custom_interp_path, cnt) + + cv2.imwrite(filename, item[:, :, ::-1]) + + cnt += 1 + +def build_read_buffer(user_args, read_buffer, videogen, temp_convert_raw_png_path): + for frame in videogen: + if not temp_convert_raw_png_path is None: + img_path = os.path.join(temp_convert_raw_png_path, frame) + frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() + read_buffer.put(frame) + read_buffer.put(None) + +def make_inference(model, I0, I1, n, scale): + if model.version >= 3.9: + res = [] + for i in range(n): + res.append(model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale)) + return res + else: + middle = model.inference(I0, I1, scale) + if n == 1: + return [middle] + first_half = make_inference(model, I0, middle, n=n // 2, scale=scale) + second_half = make_inference(model, middle, I1, n=n // 2, scale=scale) + if n % 2: + return [*first_half, middle, *second_half] + else: + return [*first_half, *second_half] + +def pad_image(img, fp16, padding): + if (fp16): + return F.pad(img, padding).half() + else: + return F.pad(img, padding) + +# TODO: move to fream_interpolation and add FILM to it! 
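+# A minimal usage sketch of the helper below (paths and values are illustrative):
+#
+#   mp4 = stitch_video("20230101120000", 60, "/out/interpolated_frames_rife_20230101120000",
+#                      None, "ffmpeg", 2, False, 2, 17, "veryslow", False, None)
+#   # -> "/out/20230101120000_RIFE_x2.mp4"; the frames folder is removed
+#   #    afterwards unless keep_imgs is True or stitching fails.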
+def stitch_video(img_batch_id, fps, img_folder_path, audio_path, ffmpeg_location, interp_x_amount, slow_mo_enabled, slow_mo_x_amount, f_crf, f_preset, keep_imgs, orig_vid_name): + parent_folder = os.path.dirname(img_folder_path) + grandparent_folder = os.path.dirname(parent_folder) + if orig_vid_name is not None: + mp4_path = os.path.join(grandparent_folder, str(orig_vid_name) +'_RIFE_' + 'x' + str(interp_x_amount)) + else: + mp4_path = os.path.join(parent_folder, str(img_batch_id) +'_RIFE_' + 'x' + str(interp_x_amount)) + + if slow_mo_enabled: + mp4_path = mp4_path + '_slomo_x' + str(slow_mo_x_amount) + mp4_path = mp4_path + '.mp4' + + t = os.path.join(img_folder_path, "%07d.png") + add_soundtrack = 'None' + if not audio_path is None: + add_soundtrack = 'File' + + exception_raised = False + try: + ffmpeg_stitch_video(ffmpeg_location=ffmpeg_location, fps=fps, outmp4_path=mp4_path, stitch_from_frame=0, stitch_to_frame=1000000, imgs_path=t, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=f_crf, preset=f_preset) + except Exception as e: + exception_raised = True + print(f"An error occurred while stitching the video: {e}") + + if not exception_raised and not keep_imgs: + shutil.rmtree(img_folder_path) + + if (keep_imgs and orig_vid_name is not None) or (orig_vid_name is not None and exception_raised is True): + shutil.move(img_folder_path, grandparent_folder) + + return mp4_path \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/model/loss.py b/extensions/deforum/scripts/deforum_helpers/src/rife/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..72e5de6af050df7d55c2871a69637077970ddfb9 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/model/loss.py @@ -0,0 +1,128 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class EPE(nn.Module): + def __init__(self): + super(EPE, self).__init__() + + def forward(self, flow, gt, loss_mask): + loss_map = (flow - gt.detach()) ** 2 + loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + return (loss_map * loss_mask) + + +class Ternary(nn.Module): + def __init__(self): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape( + (patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) + + +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor([ + [1, 0, -1], + [2, 0, 
-2], + [1, 0, -1], + ]).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat( + [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:] + pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:] + + L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y) + loss = (L1X+L1Y) + return loss + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + self.requires_grad = False + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X, Y, indices=None): + X = self.normalize(X) + Y = self.normalize(Y) + indices = [2, 7, 12, 21, 30] + weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + X = self.vgg_pretrained_features[i](X) + Y = self.vgg_pretrained_features[i](Y) + if (i+1) in indices: + loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 + k += 1 + return loss + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + ternary_loss = Ternary() + print(ternary_loss(img0, img1).shape) diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/model/pytorch_msssim/__init__.py b/extensions/deforum/scripts/deforum_helpers/src/rife/model/pytorch_msssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d30326188cf6afacf2fc84c7ae18efe14dae2e --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/model/pytorch_msssim/__init__.py @@ -0,0 +1,200 @@ +import torch +import torch.nn.functional as F +from math import exp +import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + +def create_window_3d(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ 
(_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + return window + + +def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) + mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq + sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
+ if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d(real_size, channel=1).to(img1.device) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) + mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq + sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq + sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs ** weights + pow2 = mssim ** weights + # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +# Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 3 channel for SSIM + self.channel = 3 + self.window = create_window(window_size, channel=self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + dssim = (1 - _ssim) / 2 + return dssim + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + return msssim(img1, 
img2, window_size=self.window_size, size_average=self.size_average) diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/model/warplayer.py b/extensions/deforum/scripts/deforum_helpers/src/rife/model/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..21b0b904cf71b297fd43813134c57d13a3ae9e4a --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/model/warplayer.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/IFNet_HDv3.py b/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/IFNet_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..2360c9e7d15ad4c73e8bb34112999e3d46aeb8c2 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/IFNet_HDv3.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..model.warplayer import warp +# from train_log.refine import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.2, True) + ) + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True) + ) + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1\ +) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c//2, 3, 2, 1), + conv(c//2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential( + nn.ConvTranspose2d(c, 4*6, 4, 2, 1), + nn.PixelShuffle(2) + ) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor= 1. 
/ scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + return flow, mask + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7, c=192) + self.block1 = IFBlock(8+4, c=128) + self.block2 = IFBlock(8+4, c=96) + self.block3 = IFBlock(8+4, c=64) + # self.contextnet = Contextnet() + # self.unet = Unet() + + def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + else: + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + loss_cons = 0 + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], timestep), 1), None, scale=scale_list[i]) + if ensemble: + f1, m1 = block[i](torch.cat((img1[:, :3], img0[:, :3], 1-timestep), 1), None, scale=scale_list[i]) + flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + mask = (mask + (-m1)) / 2 + else: + f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1), flow, scale=scale_list[i]) + if ensemble: + f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) + f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + m0 = (m0 + (-m1)) / 2 + flow = flow + f0 + mask = mask + m0 + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask_list[3] = torch.sigmoid(mask_list[3]) + merged[3] = merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3]) + if not fastmode: + print('contextnet is removed') + ''' + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[3] = torch.clamp(merged[3] + res, 0, 1) + ''' + return flow_list, mask_list[3], merged diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/RIFE_HDv3.py b/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/RIFE_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..e51408af7e574bc4c8b03739cda0bacd73cd0081 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/RIFE_HDv3.py @@ -0,0 +1,108 @@ +import os, sys +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from ..model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from .IFNet_HDv3 import * +import torch.nn.functional as F +from ..model.loss import * +sys.path.append('../../') +from deforum_helpers.general_utils import checksum + +device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.device() + self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) + self.epe = EPE() + self.version = 3.9 + # self.vgg = VGGPerceptualLoss().to(device) + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + + def eval(self): + self.flownet.eval() + + def device(self): + self.flownet.to(device) + + def load_model(self, path, rank, deforum_models_path): + + download_rife_model(path, deforum_models_path) + + def convert(param): + if rank == -1: + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k + } + else: + return param + if rank <= 0: + if torch.cuda.is_available(): + self.flownet.load_state_dict(convert(torch.load(os.path.join(deforum_models_path,'{}.pkl').format(path))), False) + else: + self.flownet.load_state_dict(convert(torch.load(os.path.join(deforum_models_path,'{}.pkl').format(path), map_location ='cpu')), False) + + def inference(self, img0, img1, timestep=0.5, scale=1.0): + imgs = torch.cat((img0, img1), 1) + scale_list = [8/scale, 4/scale, 2/scale, 1/scale] + flow, mask, merged = self.flownet(imgs, timestep, scale_list) + return merged[3] + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group['lr'] = learning_rate + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if training: + self.train() + else: + self.eval() + scale = [8, 4, 2, 1] + flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training) + loss_l1 = (merged[3] - gt).abs().mean() + loss_smooth = self.sobel(flow[3], flow[3]*0).mean() + # loss_vgg = self.vgg(merged[2], gt) + if training: + self.optimG.zero_grad() + loss_G = loss_l1 + loss_cons + loss_smooth * 0.1 + loss_G.backward() + self.optimG.step() + else: + flow_teacher = flow[2] + return merged[3], { + 'mask': mask, + 'flow': flow[3][:, :2], + 'loss_l1': loss_l1, + 'loss_cons': loss_cons, + 'loss_smooth': loss_smooth, + } + +def download_rife_model(path, deforum_models_path): + options = {'RIFE46': ( + 'af6f0b4bed96dea2c9f0624b449216c7adfaf7f0b722fba0c8f5c6e20b2ec39559cf33f3d238d53b160c22f00c6eaa47dc54a6e4f8aa4f59a6e4a9e90e1a808a', + "https://github.com/hithereai/Practical-RIFE/releases/download/rife46/RIFE46.pkl"), + 'RIFE43': ('ed660f58708ee369a0b3855f64d2d07a6997d949f33067faae51d740123c5ee015901cc57553594f2df8ec08131a1c5f7c883c481eac0f9addd84379acea90c8', + "https://github.com/hithereai/Practical-RIFE/releases/download/rife43/RIFE43.pkl"), + 'RIFE40': ('0baf0bed23597cda402a97a80a7d14c26a9ed797d2fc0790aac93b19ca5b0f50676ba07aa9f8423cf061ed881ece6e67922f001ea402bfced83ef67675142ce7', + "https://github.com/hithereai/Practical-RIFE/releases/download/rife40/RIFE40.pkl")} + if path in options: + target_file = f"{path}.pkl" + target_path = os.path.join(deforum_models_path, target_file) + if not os.path.exists(target_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(options[path][1], deforum_models_path) + if checksum(target_path) != options[path][0]: + raise Exception(f"Error while downloading {target_file}. 
Please download from here: {options[path][1]} and place in: " + deforum_models_path) \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/refine.py b/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3807c636d461862f13200fe0017b62db5c20c5 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/rife/rife_new_gen/refine.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.nn.functional as F + +device = torch.device("cuda") + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + +def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), + nn.PReLU(out_planes) + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +c = 16 +class Contextnet(nn.Module): + def __init__(self): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2*c) + self.conv3 = Conv2(2*c, 4*c) + self.conv4 = Conv2(4*c, 8*c) + + def forward(self, x, flow): + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + +class Unet(nn.Module): + def __init__(self): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.Conv2d(c, 3, 3, 1, 1) + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git 
a/extensions/deforum/scripts/deforum_helpers/src/utils.py b/extensions/deforum/scripts/deforum_helpers/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe08b0b1bd41f2bc59e9f8d188db08423fcf48a --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/src/utils.py @@ -0,0 +1,140 @@ +import base64 +import math +import re +from io import BytesIO + +import matplotlib.cm +import numpy as np +import torch +import torch.nn +from PIL import Image + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x, device='cpu'): + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + return x * std + mean + + +class RunningAverageDict: + def __init__(self): + self._dict = None + + def update(self, new_dict): + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=10, vmax=1000, cmap='magma_r'): + value = value.cpu().numpy()[0, :, :] + invalid_mask = value == -1 + + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + # squeeze last dim if it exists + # value = value.squeeze(axis=0) + cmapper = matplotlib.cm.get_cmap(cmap) + value = cmapper(value, bytes=True) # (nxmx4) + value[invalid_mask] = 255 + img = value[:, :, :3] + + # return img.transpose((2, 0, 1)) + return img + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def compute_errors(gt, pred): + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +##################################### Demo Utilities ############################################ +def b64_to_pil(b64string): + image_data = re.sub('^data:image/.+;base64,', '', b64string) + # image = Image.open(cStringIO.StringIO(image_data)) + return Image.open(BytesIO(base64.b64decode(image_data))) + + +# Compute edge magnitudes +from scipy import ndimage + + +def edges(d): + dx = ndimage.sobel(d, 0) # horizontal derivative + dy = ndimage.sobel(d, 1) # vertical derivative + return np.abs(dx) + np.abs(dy) + + +class PointCloudHelper(): + def __init__(self, width=640, height=480): + self.xx, self.yy = self.worldCoords(width, height) + + def worldCoords(self, width=640, height=480): + hfov_degrees, vfov_degrees = 57, 43 + hFov = math.radians(hfov_degrees) + vFov = 
math.radians(vfov_degrees) + cx, cy = width / 2, height / 2 + fx = width / (2 * math.tan(hFov / 2)) + fy = height / (2 * math.tan(vFov / 2)) + xx, yy = np.tile(range(width), height), np.repeat(range(height), width) + xx = (xx - cx) / fx + yy = (yy - cy) / fy + return xx, yy + + def depth_to_points(self, depth): + depth[edges(depth) > 0.3] = np.nan # Hide depth edges + length = depth.shape[0] * depth.shape[1] + # depth[edges(depth) > 0.3] = 1e6 # Hide depth edges + z = depth.reshape(length) + + return np.dstack((self.xx * z, self.yy * z, z)).reshape((length, 3)) + +##################################################################################################### diff --git a/extensions/deforum/scripts/deforum_helpers/upscaling.py b/extensions/deforum/scripts/deforum_helpers/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b841722b936135bf1f65e86e0c26d7084053b2 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/upscaling.py @@ -0,0 +1,264 @@ +import os +import numpy as np +import cv2 +from pathlib import Path +from tqdm import tqdm +from PIL import Image +from modules.scripts_postprocessing import PostprocessedImage +from modules import devices +import shutil +from queue import Queue, Empty +import modules.scripts as scr +from .frame_interpolation import clean_folder_name +from .general_utils import duplicate_pngs_from_folder, checksum +# TODO: move some funcs to this file? +from .video_audio_utilities import get_quick_vid_info, vid2frames, ffmpeg_stitch_video, extract_number, media_file_has_audio +from basicsr.utils.download_util import load_file_from_url +from .rich import console +import time +import subprocess + +def process_upscale_vid_upload_logic(file, selected_tab, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, vid_file_name, keep_imgs, f_location, f_crf, f_preset): + print("Got a request to *upscale* an existing video.") + + in_vid_fps, _, _ = get_quick_vid_info(file.name) + folder_name = clean_folder_name(Path(vid_file_name).stem) + outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-upscaling', folder_name) + i = 1 + while os.path.exists(outdir_no_tmp): + outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-upscaling', folder_name + '_' + str(i)) + i += 1 + + outdir = os.path.join(outdir_no_tmp, 'tmp_input_frames') + os.makedirs(outdir, exist_ok=True) + + vid2frames(video_path=file.name, video_in_frame_path=outdir, overwrite=True, extract_from_frame=0, extract_to_frame=-1, numeric_files_output=True, out_img_format='png') + + process_video_upscaling(selected_tab, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, orig_vid_fps=in_vid_fps, real_audio_track=file.name, raw_output_imgs_path=outdir, img_batch_id=None, ffmpeg_location=f_location, ffmpeg_crf=f_crf, ffmpeg_preset=f_preset, keep_upscale_imgs=keep_imgs, orig_vid_name=folder_name) + +def process_video_upscaling(resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, orig_vid_fps, real_audio_track, raw_output_imgs_path, img_batch_id, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, keep_upscale_imgs, orig_vid_name): + devices.torch_gc() + + print("Upscaling progress (it's OK if it finishes before 100%):") + + upscaled_path = os.path.join(raw_output_imgs_path, 'upscaled_frames') + if 
orig_vid_name is not None: # upscaling a video (deforum or unrelated) + custom_upscale_path = "{}_{}".format(upscaled_path, orig_vid_name) + else: # upscaling after a deforum run: + custom_upscale_path = "{}_{}".format(upscaled_path, img_batch_id) + + temp_convert_raw_png_path = os.path.join(raw_output_imgs_path, "tmp_upscale_folder") + duplicate_pngs_from_folder(raw_output_imgs_path, temp_convert_raw_png_path, img_batch_id, orig_vid_name) + + videogen = [] + for f in os.listdir(temp_convert_raw_png_path): + # double check for old _depth_ files, not really needed probably but keeping it for now + if '_depth_' not in f: + videogen.append(f) + + videogen.sort(key= lambda x:int(x.split('.')[0])) + vid_out = None + + if not os.path.exists(custom_upscale_path): + os.mkdir(custom_upscale_path) + + # Upscaling is a slow and demanding operation, so we don't need as much parallelization here + for i in tqdm(range(len(videogen)), desc="Upscaling"): + lastframe = videogen[i] + img_path = os.path.join(temp_convert_raw_png_path, lastframe) + image = process_frame(resize_mode, Image.open(img_path).convert("RGB"), upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility) + filename = '{}/{:0>7d}.png'.format(custom_upscale_path, i) + image.save(filename) + + shutil.rmtree(temp_convert_raw_png_path) + # stitch video from upscaled frames, and add audio if needed + try: + print (f"*Passing upsc frames to ffmpeg...*") + vid_out_path = stitch_video(img_batch_id, orig_vid_fps, custom_upscale_path, real_audio_track, ffmpeg_location, resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, ffmpeg_crf, ffmpeg_preset, keep_upscale_imgs, orig_vid_name) + # remove folder with raw (non-upscaled) vid input frames in case of input VID and not PNGs + if orig_vid_name is not None: + shutil.rmtree(raw_output_imgs_path) + except Exception as e: + print(f'Video stitching gone wrong. *Upscaled frames were saved to HD as backup!*. 
Actual error: {e}') + + devices.torch_gc() + +def process_frame(resize_mode, image, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): + pp = PostprocessedImage(image) + postproc = scr.scripts_postproc + upscaler_script = next(s for s in postproc.scripts if s.name == "Upscale") + upscaler_script.process(pp, resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility) + return pp.image + +def stitch_video(img_batch_id, fps, img_folder_path, audio_path, ffmpeg_location, resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, f_crf, f_preset, keep_imgs, orig_vid_name): + parent_folder = os.path.dirname(img_folder_path) + grandparent_folder = os.path.dirname(parent_folder) + if orig_vid_name is not None: + mp4_path = os.path.join(grandparent_folder, str(orig_vid_name) +'_upscaled_' + (('by_' + str(upscaling_resize).replace('.', '-')) if resize_mode == 0 else f"to_{upscaling_resize_w}_{upscaling_resize_h}")) + f"_with_{extras_upscaler_1}" + (f"_then_{extras_upscaler_2}" if extras_upscaler_2_visibility > 0 else "") + else: + mp4_path = os.path.join(parent_folder, str(img_batch_id) +'_upscaled_' + (('by_' + str(upscaling_resize).replace('.', '-')) if resize_mode == 0 else f"to_{upscaling_resize_w}_{upscaling_resize_h}")) + f"_with_{extras_upscaler_1}_then_{extras_upscaler_2}" + + mp4_path = mp4_path + '.mp4' + + t = os.path.join(img_folder_path, "%07d.png") + add_soundtrack = 'None' + if not audio_path is None: + add_soundtrack = 'File' + + exception_raised = False + try: + ffmpeg_stitch_video(ffmpeg_location=ffmpeg_location, fps=fps, outmp4_path=mp4_path, stitch_from_frame=0, stitch_to_frame=1000000, imgs_path=t, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=f_crf, preset=f_preset) + except Exception as e: + exception_raised = True + print(f"An error occurred while stitching the video: {e}") + + if not exception_raised and not keep_imgs: + shutil.rmtree(img_folder_path) + + if (keep_imgs and orig_vid_name is not None) or (orig_vid_name is not None and exception_raised is True): + shutil.move(img_folder_path, grandparent_folder) + + return mp4_path + +# NCNN Upscale section START +def process_ncnn_upscale_vid_upload_logic(vid_path, in_vid_fps, in_vid_res, out_vid_res, models_path, upscale_model, upscale_factor, keep_imgs, f_location, f_crf, f_preset, current_user_os): + print(f"Got a request to *upscale* a video using {upscale_model} at {upscale_factor}") + + folder_name = clean_folder_name(Path(vid_path.orig_name).stem) + outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-upscaling', folder_name) + i = 1 + while os.path.exists(outdir_no_tmp): + outdir_no_tmp = os.path.join(os.getcwd(), 'outputs', 'frame-upscaling', folder_name + '_' + str(i)) + i += 1 + + outdir = os.path.join(outdir_no_tmp, 'tmp_input_frames') + os.makedirs(outdir, exist_ok=True) + + vid2frames(video_path=vid_path.name, video_in_frame_path=outdir, overwrite=True, extract_from_frame=0, extract_to_frame=-1, numeric_files_output=True, out_img_format='png') + + process_ncnn_video_upscaling(vid_path, outdir, in_vid_fps, in_vid_res, out_vid_res, models_path, upscale_model, upscale_factor, keep_imgs, f_location, f_crf, f_preset, current_user_os) + +def process_ncnn_video_upscaling(vid_path, outdir, in_vid_fps, in_vid_res, out_vid_res, models_path, 
upscale_model, upscale_factor, keep_imgs, f_location, f_crf, f_preset, current_user_os): + # get clean number from 'x2, x3' etc + clean_num_r_up_factor = extract_number(upscale_factor) + # set paths + realesrgan_ncnn_location = os.path.join(models_path, 'realesrgan_ncnn', 'realesrgan-ncnn-vulkan' + ('.exe' if current_user_os == 'Windows' else '')) + upscaled_folder_path = os.path.join(os.path.dirname(outdir), 'Upscaled_frames') + # create folder for upscaled imgs to live in. this folder will stay alive if keep_imgs=True, otherwise get deleted at the end + os.makedirs(upscaled_folder_path, exist_ok=True) + out_upscaled_mp4_path = os.path.join(os.path.dirname(outdir), f"{vid_path.orig_name}_Upscaled_{upscale_factor}.mp4") + # download upscaling model if needed + check_and_download_realesrgan_ncnn(models_path, current_user_os) + # set cmd command + cmd = [realesrgan_ncnn_location, '-i', outdir, '-o', upscaled_folder_path, '-s', str(clean_num_r_up_factor), '-n', upscale_model] + # msg to print - need it to hide that text later on (!) + msg_to_print = f"Upscaling raw PNGs using {upscale_model} at {upscale_factor}..." + # blink the msg in the cli until action is done + console.print(msg_to_print, style="blink yellow", end="") + start_time = time.time() + # make call to ncnn upscaling executble + process = subprocess.run(cmd, capture_output=True, check=True, text=True) + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f"\rUpscaling \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) + # set custom path for ffmpeg func below + upscaled_imgs_path_for_ffmpeg = os.path.join(upscaled_folder_path, "%05d.png") + add_soundtrack = 'None' + # don't pass add_soundtrack to ffmpeg if orig video doesn't contain any audio, so we won't get a message saying audio couldn't be added :) + if media_file_has_audio(vid_path.name, f_location): + add_soundtrack = 'File' + # stitch video from upscaled pngs + ffmpeg_stitch_video(ffmpeg_location=f_location, fps=in_vid_fps, outmp4_path=out_upscaled_mp4_path, stitch_from_frame=0, stitch_to_frame=-1, imgs_path=upscaled_imgs_path_for_ffmpeg, add_soundtrack=add_soundtrack, audio_path=vid_path.name, crf=f_crf, preset=f_preset) + # delete the raw video pngs + shutil.rmtree(outdir) + # delete upscaled imgs if user requested + if not keep_imgs: + shutil.rmtree(upscaled_folder_path) + +def check_and_download_realesrgan_ncnn(models_folder, current_user_os): + import zipfile + if current_user_os == 'Windows': + zip_file_name = 'realesrgan-ncnn-windows.zip' + executble_name = 'realesrgan-ncnn-vulkan.exe' + zip_checksum_value = '1d073f520a4a3f6438a500fea88407964da6d4a87489719bedfa7445b76c019fdd95a5c39576ca190d7ac22c906b33d5250a6f48cb7eda2b6af3e86ec5f09dfc' + download_url = 'https://github.com/hithereai/Real-ESRGAN/releases/download/real-esrgan-ncnn-windows/realesrgan-ncnn-windows.zip' + elif current_user_os == 'Linux': + zip_file_name = 'realesrgan-ncnn-linux.zip' + executble_name = 'realesrgan-ncnn-vulkan' + zip_checksum_value = 'df44c4e9a1ff66331079795f018a67fbad8ce37c4472929a56b5a38440cf96982d6e164a086b438c3d26d269025290dd6498bd50846bda8691521ecf8f0fafdf' + download_url = 'https://github.com/hithereai/Real-ESRGAN/releases/download/real-esrgan-ncnn-linux/realesrgan-ncnn-linux.zip' + elif current_user_os == 'Mac': + zip_file_name = 'realesrgan-ncnn-mac.zip' + executble_name = 'realesrgan-ncnn-vulkan' + zip_checksum_value = 
'65f09472025b55b18cf6ba64149ede8cded90c20e18d35a9edb1ab60715b383a6ffbf1be90d973fc2075cf99d4cc1411fbdc459411af5c904f544b8656111469' + download_url = 'https://github.com/hithereai/Real-ESRGAN/releases/download/real-esrgan-ncnn-mac/realesrgan-ncnn-mac.zip' + else: # who are you then? + raise Exception(f"No support for OS type: {current_user_os}") + + # set paths + realesrgan_ncnn_folder = os.path.join(models_folder, 'realesrgan_ncnn') + realesrgan_exec_path = os.path.join(realesrgan_ncnn_folder, executble_name) + realesrgan_zip_path = os.path.join(realesrgan_ncnn_folder, zip_file_name) + # return if exec file already exist + if os.path.exists(realesrgan_exec_path): + return + try: + os.makedirs(realesrgan_ncnn_folder, exist_ok=True) + # download exec and model files from url + load_file_from_url(download_url, realesrgan_ncnn_folder) + # check downloaded zip's hash + with open(realesrgan_zip_path, 'rb') as f: + file_hash = checksum(realesrgan_zip_path) + # wrong hash, file is probably broken/ download interrupted + if file_hash != zip_checksum_value: + raise Exception(f"Error while downloading {realesrgan_zip_path}. Please download from: {download_url}, and extract its contents into: {models_folder}/realesrgan_ncnn") + # hash ok, extract zip contents into our folder + with zipfile.ZipFile(realesrgan_zip_path, 'r') as zip_ref: + zip_ref.extractall(realesrgan_ncnn_folder) + # delete the zip file + os.remove(realesrgan_zip_path) + # chmod 755 the exec if we're in a linux machine, otherwise we'd get permission errors + if current_user_os in ('Linux', 'Mac'): + os.chmod(realesrgan_exec_path, 0o755) + # enable running the exec for mac users + if current_user_os == 'Mac': + os.system(f'xattr -d com.apple.quarantine "{realesrgan_exec_path}"') + + except Exception as e: + raise Exception(f"Error while downloading {realesrgan_zip_path}. Please download from: {download_url}, and extract its contents into: {models_folder}/realesrgan_ncnn") + +def make_upscale_v2(upscale_factor, upscale_model, keep_imgs, imgs_raw_path, imgs_batch_id, deforum_models_path, current_user_os, ffmpeg_location, ffmpeg_crf, ffmpeg_preset, fps, stitch_from_frame, stitch_to_frame, audio_path, add_soundtrack): + # get clean number from 'x2, x3' etc + clean_num_r_up_factor = extract_number(upscale_factor) + # set paths + realesrgan_ncnn_location = os.path.join(deforum_models_path, 'realesrgan_ncnn', 'realesrgan-ncnn-vulkan' + ('.exe' if current_user_os == 'Windows' else '')) + upscaled_folder_path = os.path.join(imgs_raw_path, f"{imgs_batch_id}_upscaled") + temp_folder_to_keep_raw_ims = os.path.join(upscaled_folder_path, 'temp_raw_imgs_to_upscale') + out_upscaled_mp4_path = os.path.join(imgs_raw_path, f"{imgs_batch_id}_Upscaled_{upscale_factor}.mp4") + # download upscaling model if needed + check_and_download_realesrgan_ncnn(deforum_models_path, current_user_os) + # make a folder with only the imgs we need to duplicate so we can call the ncnn with the folder syntax (quicker!) + duplicate_pngs_from_folder(from_folder=imgs_raw_path, to_folder=temp_folder_to_keep_raw_ims, img_batch_id=imgs_batch_id, orig_vid_name='Dummy') + # set dynamic cmd command + cmd = [realesrgan_ncnn_location, '-i', temp_folder_to_keep_raw_ims, '-o', upscaled_folder_path, '-s', str(clean_num_r_up_factor), '-n', upscale_model] + # msg to print - need it to hide that text later on (!) + msg_to_print = f"Upscaling raw output PNGs using {upscale_model} at {upscale_factor}..." 
+ # blink the msg in the cli until action is done + console.print(msg_to_print, style="blink yellow", end="") + start_time = time.time() + # make call to ncnn upscaling executble + process = subprocess.run(cmd, capture_output=True, check=True, text=True, cwd=(os.path.join(deforum_models_path, 'realesrgan_ncnn') if current_user_os == 'Mac' else None)) + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f"\rUpscaling \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) + # set custom path for ffmpeg func below + upscaled_imgs_path_for_ffmpeg = os.path.join(upscaled_folder_path, f"{imgs_batch_id}_%05d.png") + # stitch video from upscaled pngs + ffmpeg_stitch_video(ffmpeg_location=ffmpeg_location, fps=fps, outmp4_path=out_upscaled_mp4_path, stitch_from_frame=stitch_from_frame, stitch_to_frame=stitch_to_frame, imgs_path=upscaled_imgs_path_for_ffmpeg, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=ffmpeg_crf, preset=ffmpeg_preset) + + # delete the duplicated raw imgs + shutil.rmtree(temp_folder_to_keep_raw_ims) + + if not keep_imgs: + shutil.rmtree(upscaled_folder_path) +# NCNN Upscale section END \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/video_audio_utilities.py b/extensions/deforum/scripts/deforum_helpers/video_audio_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcf90119de1c6603abc3c905d0a902db499ca98 --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/video_audio_utilities.py @@ -0,0 +1,305 @@ +import os +import cv2 +import shutil +import math +import requests +import subprocess +import time +from pkg_resources import resource_filename +from modules.shared import state +from .general_utils import checksum, duplicate_pngs_from_folder +from basicsr.utils.download_util import load_file_from_url +from .rich import console + +# e.g gets 'x2' returns just 2 as int +def extract_number(string): + return int(string[1:]) if len(string) > 1 and string[1:].isdigit() else -1 + +def vid2frames(video_path, video_in_frame_path, n=1, overwrite=True, extract_from_frame=0, extract_to_frame=-1, out_img_format='jpg', numeric_files_output = False): + if (extract_to_frame <= extract_from_frame) and extract_to_frame != -1: + raise RuntimeError('Error: extract_to_frame can not be higher than extract_from_frame') + + if n < 1: n = 1 #HACK Gradio interface does not currently allow min/max in gr.Number(...) + + # check vid path using a function and only enter if we get True + if is_vid_path_valid(video_path): + + name = get_frame_name(video_path) + + vidcap = cv2.VideoCapture(video_path) + video_fps = vidcap.get(cv2.CAP_PROP_FPS) + + input_content = [] + if os.path.exists(video_in_frame_path) : + input_content = os.listdir(video_in_frame_path) + + # check if existing frame is the same video, if not we need to erase it and repopulate + if len(input_content) > 0: + #get the name of the existing frame + content_name = get_frame_name(input_content[0]) + if not content_name.startswith(name): + overwrite = True + + # grab the frame count to check against existing directory len + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # raise error if the user wants to skip more frames than exist + if n >= frame_count : + raise RuntimeError('Skipping more frames than input video contains. 
extract_nth_frames larger than input frames') + + expected_frame_count = math.ceil(frame_count / n) + # Check to see if the frame count is matches the number of files in path + if overwrite or expected_frame_count != len(input_content): + shutil.rmtree(video_in_frame_path) + os.makedirs(video_in_frame_path, exist_ok=True) # just deleted the folder so we need to make it again + input_content = os.listdir(video_in_frame_path) + + print(f"Trying to extract frames from video with input FPS of {video_fps}. Please wait patiently.") + if len(input_content) == 0: + vidcap.set(cv2.CAP_PROP_POS_FRAMES, extract_from_frame) # Set the starting frame + success,image = vidcap.read() + count = extract_from_frame + t=1 + success = True + while success: + if state.interrupted: + return + if (count <= extract_to_frame or extract_to_frame == -1) and count % n == 0: + if numeric_files_output == True: + cv2.imwrite(video_in_frame_path + os.path.sep + f"{t:05}.{out_img_format}" , image) # save frame as file + else: + cv2.imwrite(video_in_frame_path + os.path.sep + name + f"{t:05}.{out_img_format}" , image) # save frame as file + t += 1 + success,image = vidcap.read() + count += 1 + print(f"Successfully extracted {count} frames from video.") + else: + print("Frames already unpacked") + vidcap.release() + return video_fps + +# make sure the video_path provided is an existing local file or a web URL with a supported file extension +def is_vid_path_valid(video_path): + # make sure file format is supported! + file_formats = ["mov", "mpeg", "mp4", "m4v", "avi", "mpg", "webm"] + extension = video_path.rsplit('.', 1)[-1].lower() + # vid path is actually a URL, check it + if video_path.startswith('http://') or video_path.startswith('https://'): + response = requests.head(video_path, allow_redirects=True) + if response.status_code == 404: + raise ConnectionError("Video URL is not valid. Response status code: {}".format(response.status_code)) + elif response.status_code == 302: + response = requests.head(response.headers['location'], allow_redirects=True) + if response.status_code != 200: + raise ConnectionError("Video URL is not valid. Response status code: {}".format(response.status_code)) + if extension not in file_formats: + raise ValueError("Video file format '{}' not supported. Supported formats are: {}".format(extension, file_formats)) + else: + if not os.path.exists(video_path): + raise RuntimeError("Video path does not exist.") + if extension not in file_formats: + raise ValueError("Video file format '{}' not supported. 
Supported formats are: {}".format(extension, file_formats)) + return True + +# quick-retreive frame count, FPS and H/W dimensions of a video (local or URL-based) +def get_quick_vid_info(vid_path): + vidcap = cv2.VideoCapture(vid_path) + video_fps = vidcap.get(cv2.CAP_PROP_FPS) + video_frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + video_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + video_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + vidcap.release() + if video_fps.is_integer(): + video_fps = int(video_fps) + + return video_fps, video_frame_count, (video_width, video_height) + +# Stitch images to a h264 mp4 video using ffmpeg +def ffmpeg_stitch_video(ffmpeg_location=None, fps=None, outmp4_path=None, stitch_from_frame=0, stitch_to_frame=None, imgs_path=None, add_soundtrack=None, audio_path=None, crf=17, preset='veryslow'): + start_time = time.time() + + print(f"Got a request to stitch frames to video using FFmpeg.\nFrames:\n{imgs_path}\nTo Video:\n{outmp4_path}") + msg_to_print = f"Stitching *video*..." + console.print(msg_to_print, style="blink yellow", end="") + if stitch_to_frame == -1: + stitch_to_frame = 9999999 + try: + cmd = [ + ffmpeg_location, + '-y', + '-vcodec', 'png', + '-r', str(float(fps)), + '-start_number', str(stitch_from_frame), + '-i', imgs_path, + '-frames:v', str(stitch_to_frame), + '-c:v', 'libx264', + '-vf', + f'fps={float(fps)}', + '-pix_fmt', 'yuv420p', + '-crf', str(crf), + '-preset', preset, + '-pattern_type', 'sequence', + outmp4_path + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + except FileNotFoundError: + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + raise FileNotFoundError("FFmpeg not found. Please make sure you have a working ffmpeg path under 'ffmpeg_location' parameter.") + except Exception as e: + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + raise Exception(f'Error stitching frames to video. Actual runtime error:{e}') + + if add_soundtrack != 'None': + audio_add_start_time = time.time() + try: + cmd = [ + ffmpeg_location, + '-i', + outmp4_path, + '-i', + audio_path, + '-map', '0:v', + '-map', '1:a', + '-c:v', 'copy', + '-shortest', + outmp4_path+'.temp.mp4' + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + if process.returncode != 0: + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + raise RuntimeError(stderr) + os.replace(outmp4_path+'.temp.mp4', outmp4_path) + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f"\rFFmpeg Video+Audio stitching \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) + except Exception as e: + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f'\rError adding audio to video. 
Actual error: {e}', flush=True) + print(f"FFMPEG Video (sorry, no audio) stitching \033[33mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) + else: + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f"\rVideo stitching \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!", flush=True) + +def get_frame_name(path): + name = os.path.basename(path) + name = os.path.splitext(name)[0] + return name + +def get_next_frame(outdir, video_path, frame_idx, mask=False): + frame_path = 'inputframes' + if (mask): frame_path = 'maskframes' + return os.path.join(outdir, frame_path, get_frame_name(video_path) + f"{frame_idx+1:05}.jpg") + +def find_ffmpeg_binary(): + try: + import google.colab + return 'ffmpeg' + except: + pass + for package in ['imageio_ffmpeg', 'imageio-ffmpeg']: + try: + package_path = resource_filename(package, 'binaries') + files = [os.path.join(package_path, f) for f in os.listdir(package_path) if f.startswith("ffmpeg-")] + files.sort(key=lambda x: os.path.getmtime(x), reverse=True) + return files[0] if files else 'ffmpeg' + except: + return 'ffmpeg' + +# These 2 functions belong to "stitch frames to video" in Output tab +def get_manual_frame_to_vid_output_path(input_path): + root, ext = os.path.splitext(input_path) + base, _ = root.rsplit("_", 1) + output_path = f"{base}.mp4" + i = 1 + while os.path.exists(output_path): + output_path = f"{base}_{i}.mp4" + i += 1 + return output_path + +def direct_stitch_vid_from_frames(image_path, fps, f_location, f_crf, f_preset, add_soundtrack, audio_path): + import re + # checking - do we actually have at least 4 matched files for the provided pattern + file_list = [image_path % i for i in range(4)] + exists_list = [os.path.isfile(file) for file in file_list] + # we got 4 files, moving on + if all(exists_list): + out_mp4_path = get_manual_frame_to_vid_output_path(image_path) + ffmpeg_stitch_video(ffmpeg_location=f_location, fps=fps, outmp4_path=out_mp4_path, stitch_from_frame=0, stitch_to_frame=-1, imgs_path=image_path, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=f_crf, preset=f_preset) + else: + print("Couldn't find images that match the provided path/ pattern. 
At least 2 matched images are required.") +# end of 2 stitch frame to video funcs + +# returns True if filename (could be also media URL) contains an audio stream, othehrwise False +def media_file_has_audio(filename, ffmpeg_location): + result = subprocess.run([ffmpeg_location, "-i", filename, "-af", "volumedetect", "-f", "null", "-"], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) + output = result.stderr.decode() + return True if "Stream #0:1: Audio: " in output or "Stream #0:1(und): Audio" in output else False + +# download gifski binaries if needed - linux and windows only atm (apple users won't even see the option) +def check_and_download_gifski(models_folder, current_user_os): + if current_user_os == 'Windows': + file_name = 'gifski.exe' + checksum_value = 'b0dd261ad021c31c7fdb99db761b45165e6b2a7e8e09c5d070a2b8064b575d7a4976c364d8508b28a6940343119b16a23e9f7d76f1f3d5ff02289d3068b469cf' + download_url = 'https://github.com/hithereai/d/releases/download/giski-windows-bin/gifski.exe' + elif current_user_os == 'Linux': + file_name = 'gifski' + checksum_value = 'e65bf9502bca520a7fd373397e41078d5c73db12ec3e9b47458c282d076c04fa697adecb5debb5d37fc9cbbee0673bb95e78d92c1cf813b4f5cc1cabe96880ff' + download_url = 'https://github.com/hithereai/d/releases/download/gifski-linux-bin/gifski' + elif current_user_os == 'Mac': + file_name = 'gifski' + checksum_value = '622a65d25609677169ed2c1c53fd9aa496a98b357cf84d0c3627ae99c85a565d61ca42cdc4d24ed6d60403bb79b6866ce24f3c4b6fff58c4d27632264a96353c' + download_url = 'https://github.com/hithereai/d/releases/download/gifski-mac-bin/gifski' + else: # who are you then? + raise Exception(f"No support for OS type: {current_user_os}") + + file_path = os.path.join(models_folder, file_name) + + if not os.path.exists(file_path): + load_file_from_url(download_url, models_folder) + if current_user_os in ['Linux','Mac']: + os.chmod(file_path, 0o755) + if current_user_os == 'Mac': + # enable running the exec for mac users + os.system(f'xattr -d com.apple.quarantine "{file_path}"') + if checksum(file_path) != checksum_value: + raise Exception(f"Error while downloading {file_name}. Please download from: {download_url} and place in: {models_folder}") + +# create a gif using gifski - limited to up to 30 fps (from the ui; if users wanna try to hack it, results are not good, but possible up to 100 fps theoretically) +def make_gifski_gif(imgs_raw_path, imgs_batch_id, fps, models_folder, current_user_os): + import glob + msg_to_print = f"Stitching *gif* from frames using Gifski..." 
+ # blink the msg in the cli until action is done + console.print(msg_to_print, style="blink yellow", end="") + start_time = time.time() + gifski_location = os.path.join(models_folder, 'gifski' + ('.exe' if current_user_os == 'Windows' else '')) + final_gif_path = os.path.join(imgs_raw_path, imgs_batch_id + '.gif') + if current_user_os == "Linux": + input_img_pattern = imgs_batch_id + '_0*.png' + input_img_files = [os.path.join(imgs_raw_path, file) for file in sorted(glob.glob(os.path.join(imgs_raw_path, input_img_pattern)))] + cmd = [gifski_location, '-o', final_gif_path] + input_img_files + ['--fps', str(fps), '--quality', str(95)] + elif current_user_os == "Windows": + input_img_pattern_for_gifski = os.path.join(imgs_raw_path, imgs_batch_id + '_0*.png') + cmd = [gifski_location, '-o', final_gif_path, input_img_pattern_for_gifski, '--fps', str(fps), '--quality', str(95)] + else: # should never this else as we check before, but just in case + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + raise Exception(f"No support for OS type: {current_user_os}") + + check_and_download_gifski(models_folder, current_user_os) + + try: + process = subprocess.run(cmd, capture_output=True, check=True, text=True, cwd=(models_folder if current_user_os == 'Mac' else None)) + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f"GIF stitching \033[0;32mdone\033[0m in {time.time() - start_time:.2f} seconds!") + except Exception as e: + print("\r" + " " * len(msg_to_print), end="", flush=True) + print(f"\r{msg_to_print}", flush=True) + print(f"GIF stitching *failed* with error:\n{e}") \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/webui_sd_pipeline.py b/extensions/deforum/scripts/deforum_helpers/webui_sd_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd1baab3461db6a4ee896dae2ddfc31cf22973f --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/webui_sd_pipeline.py @@ -0,0 +1,49 @@ +from modules.processing import StableDiffusionProcessingImg2Img +from modules.shared import opts, sd_model +import os + +def get_webui_sd_pipeline(args, root, frame): + import re + assert args.prompt is not None + + # Setup the pipeline + p = StableDiffusionProcessingImg2Img( + sd_model=sd_model, + outpath_samples = opts.outdir_samples or opts.outdir_img2img_samples, + #we'll setup the rest later + ) + + os.makedirs(args.outdir, exist_ok=True) + p.width, p.height = map(lambda x: x - x % 64, (args.W, args.H)) + p.steps = args.steps + p.seed = args.seed + p.sampler_name = args.sampler + p.batch_size = args.n_batch + p.tiling = args.tiling + p.restore_faces = args.restore_faces + p.subseed = args.subseed + p.subseed_strength = args.subseed_strength + p.seed_resize_from_w = args.seed_resize_from_w + p.seed_resize_from_h = args.seed_resize_from_h + p.fill = args.fill + p.ddim_eta = args.ddim_eta + p.batch_size = args.n_samples + p.width = args.W + p.height = args.H + p.seed = args.seed + p.do_not_save_samples = not args.save_sample_per_step + p.sampler_name = args.sampler + p.mask_blur = args.mask_overlay_blur + p.extra_generation_params["Mask blur"] = args.mask_overlay_blur + p.n_iter = 1 + p.steps = args.steps + if opts.img2img_fix_steps: + p.denoising_strength = 1 / (1 - args.strength + 1.0/args.steps) #see https://github.com/deforum-art/deforum-for-automatic1111-webui/issues/3 + else: + p.denoising_strength = 1 - args.strength + p.cfg_scale = 
args.scale + p.image_cfg_scale = args.pix2pix_img_cfg_scale + p.outpath_samples = root.outpath_samples + + + return p \ No newline at end of file diff --git a/extensions/deforum/scripts/deforum_helpers/word_masking.py b/extensions/deforum/scripts/deforum_helpers/word_masking.py new file mode 100644 index 0000000000000000000000000000000000000000..9b16ba64bfe798266f2de2436bd6802285d62c6e --- /dev/null +++ b/extensions/deforum/scripts/deforum_helpers/word_masking.py @@ -0,0 +1,41 @@ +import os +import torch +from PIL import Image +from torchvision import transforms +from clipseg.models.clipseg import CLIPDensePredT + +preclipseg_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + transforms.Resize((512, 512)), #TODO: check if the size is hardcoded +]) + +def find_clipseg(root): + src_basedirs = [] + for basedir in root.basedirs: + src_basedirs.append(basedir + '/scripts/deforum_helpers/src') + src_basedirs.append(basedir + '/extensions/deforum/scripts/deforum_helpers/src') + src_basedirs.append(basedir + '/extensions/deforum-for-automatic1111-webui/scripts/deforum_helpers/src') + + for basedir in src_basedirs: + pth = os.path.join(basedir, './clipseg/weights/rd64-uni.pth') + if os.path.exists(pth): + return pth + raise Exception('CLIPseg weights not found!') + +def setup_clipseg(root): + model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64) + model.eval() + model.load_state_dict(torch.load(find_clipseg(root), map_location=root.device), strict=False) + + model.to(root.device) + root.clipseg_model = model + +def get_word_mask(root, frame, word_mask): + if root.clipseg_model is None: + setup_clipseg(root) + img = preclipseg_transform(frame).to(root.device, dtype=torch.float32) + word_masks = [word_mask] + with torch.no_grad(): + preds = root.clipseg_model(img.repeat(len(word_masks),1,1,1), word_masks)[0] + return Image.fromarray(torch.sigmoid(preds[0][0]).multiply(255).to(dtype=torch.uint8,device='cpu').numpy()) diff --git a/extensions/deforum/style.css b/extensions/deforum/style.css new file mode 100644 index 0000000000000000000000000000000000000000..5b3615d207357b2b00c1ba32a737e213e1bdd5ce --- /dev/null +++ b/extensions/deforum/style.css @@ -0,0 +1,36 @@ +#vid_to_interpolate_chosen_file .w-full, #vid_to_upscale_chosen_file .w-full, #controlnet_input_video_chosen_file .w-full, #controlnet_input_video_mask_chosen_file .w-full { + display: flex !important; + align-items: flex-start !important; + justify-content: center !important; +} + +#vid_to_interpolate_chosen_file, #vid_to_upscale_chosen_file, #controlnet_input_video_chosen_file, #controlnet_input_video_mask_chosen_file { + height: 85px !important; +} + +#save_zip_deforum, #save_deforum { + display: none; +} + +#extra_schedules::before { + content: "Schedules:"; + font-size: 10px !important; +} + +#hybrid_msg_html { + color: Tomato !important; + margin-top: 5px !important; + text-align: center !important; + font-size: 20px !important; + font-weight: bold !important; +} + +#deforum_results .flex #image_buttons_deforum #img2img_tab, +#deforum_results .flex #image_buttons_deforum #inpaint_tab, +#deforum_results .flex #image_buttons_deforum #extras_tab { + display: none !important; +} + +#controlnet_not_found_html_msg { + color: Tomato; +} diff --git a/extensions/put extensions here.txt b/extensions/put extensions here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/html/card-no-preview.png b/html/card-no-preview.png new file mode 100644 index 0000000000000000000000000000000000000000..f135fc4ec1e9a9f0850a310fbfde4b8811232437 --- /dev/null +++ b/html/card-no-preview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab5740a5ae4494cd483daa1c5ba577b62260acd19cf94198c527ae05650f32dd +size 84440 diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html new file mode 100644 index 0000000000000000000000000000000000000000..8a5e2fbd223e71abacca9a602bd1be154f5fb520 --- /dev/null +++ b/html/extra-networks-card.html @@ -0,0 +1,12 @@ +
    +
    +
    + + +
    + {name} +
    +
    + diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html new file mode 100644 index 0000000000000000000000000000000000000000..389358d6c4b383fdc3c5686e029e7b3b1ae9a493 --- /dev/null +++ b/html/extra-networks-no-cards.html @@ -0,0 +1,8 @@ +
    +

    Nothing here. Add some content to the following directories:

    + +
      +{dirs} +
    +
    + diff --git a/html/footer.html b/html/footer.html new file mode 100644 index 0000000000000000000000000000000000000000..bad87ff619df2488b9732017a1e5658e240adb2e --- /dev/null +++ b/html/footer.html @@ -0,0 +1,13 @@ +
    + API +  •  + Github +  •  + Gradio +  •  + Reload UI +
    +
    +
    +{versions} +
    diff --git a/html/image-update.svg b/html/image-update.svg new file mode 100644 index 0000000000000000000000000000000000000000..3abf12df0f7774c13203e3c49ec3544649df42f4 --- /dev/null +++ b/html/image-update.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/html/licenses.html b/html/licenses.html new file mode 100644 index 0000000000000000000000000000000000000000..570630eb4ada6511ac5c8cd6c06f405a1b55c2bd --- /dev/null +++ b/html/licenses.html @@ -0,0 +1,419 @@ + + +

    CodeFormer

    +Parts of CodeFormer code had to be copied to be compatible with GFPGAN. +
    +S-Lab License 1.0
    +
    +Copyright 2022 S-Lab
    +
    +Redistribution and use for non-commercial purpose in source and
    +binary forms, with or without modification, are permitted provided
    +that the following conditions are met:
    +
    +1. Redistributions of source code must retain the above copyright
    +   notice, this list of conditions and the following disclaimer.
    +
    +2. Redistributions in binary form must reproduce the above copyright
    +   notice, this list of conditions and the following disclaimer in
    +   the documentation and/or other materials provided with the
    +   distribution.
    +
    +3. Neither the name of the copyright holder nor the names of its
    +   contributors may be used to endorse or promote products derived
    +   from this software without specific prior written permission.
    +
    +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    +
    +In the event that redistribution and/or use for commercial purpose in
    +source or binary forms, with or without modification is required,
    +please contact the contributor(s) of the work.
    +
    + + +

    ESRGAN

    +Code for architecture and reading models copied. +
    +MIT License
    +
    +Copyright (c) 2021 victorca25
    +
    +Permission is hereby granted, free of charge, to any person obtaining a copy
    +of this software and associated documentation files (the "Software"), to deal
    +in the Software without restriction, including without limitation the rights
    +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    +copies of the Software, and to permit persons to whom the Software is
    +furnished to do so, subject to the following conditions:
    +
    +The above copyright notice and this permission notice shall be included in all
    +copies or substantial portions of the Software.
    +
    +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    +SOFTWARE.
    +
    + +

    Real-ESRGAN

    +Some code is copied to support ESRGAN models. +
    +BSD 3-Clause License
    +
    +Copyright (c) 2021, Xintao Wang
    +All rights reserved.
    +
    +Redistribution and use in source and binary forms, with or without
    +modification, are permitted provided that the following conditions are met:
    +
    +1. Redistributions of source code must retain the above copyright notice, this
    +   list of conditions and the following disclaimer.
    +
    +2. Redistributions in binary form must reproduce the above copyright notice,
    +   this list of conditions and the following disclaimer in the documentation
    +   and/or other materials provided with the distribution.
    +
    +3. Neither the name of the copyright holder nor the names of its
    +   contributors may be used to endorse or promote products derived from
    +   this software without specific prior written permission.
    +
    +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
    +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
    +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    +
    + +

    InvokeAI

    +Some code for compatibility with OSX is taken from lstein's repository. +
    +MIT License
    +
    +Copyright (c) 2022 InvokeAI Team
    +
    +Permission is hereby granted, free of charge, to any person obtaining a copy
    +of this software and associated documentation files (the "Software"), to deal
    +in the Software without restriction, including without limitation the rights
    +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    +copies of the Software, and to permit persons to whom the Software is
    +furnished to do so, subject to the following conditions:
    +
    +The above copyright notice and this permission notice shall be included in all
    +copies or substantial portions of the Software.
    +
    +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    +SOFTWARE.
    +
    + +

    LDSR

    +Code added by contributors, most likely copied from this repository. +
    +MIT License
    +
    +Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
    +
    +Permission is hereby granted, free of charge, to any person obtaining a copy
    +of this software and associated documentation files (the "Software"), to deal
    +in the Software without restriction, including without limitation the rights
    +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    +copies of the Software, and to permit persons to whom the Software is
    +furnished to do so, subject to the following conditions:
    +
    +The above copyright notice and this permission notice shall be included in all
    +copies or substantial portions of the Software.
    +
    +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    +SOFTWARE.
    +
    + +

    CLIP Interrogator

    +Some small amounts of code borrowed and reworked. +
    +MIT License
    +
    +Copyright (c) 2022 pharmapsychotic
    +
    +Permission is hereby granted, free of charge, to any person obtaining a copy
    +of this software and associated documentation files (the "Software"), to deal
    +in the Software without restriction, including without limitation the rights
    +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    +copies of the Software, and to permit persons to whom the Software is
    +furnished to do so, subject to the following conditions:
    +
    +The above copyright notice and this permission notice shall be included in all
    +copies or substantial portions of the Software.
    +
    +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    +SOFTWARE.
    +
    + +

    SwinIR

    +Code added by contributors, most likely copied from this repository. + +
    +                                 Apache License
    +                           Version 2.0, January 2004
    +                        http://www.apache.org/licenses/
    +
    +   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
    +
    +   1. Definitions.
    +
    +      "License" shall mean the terms and conditions for use, reproduction,
    +      and distribution as defined by Sections 1 through 9 of this document.
    +
    +      "Licensor" shall mean the copyright owner or entity authorized by
    +      the copyright owner that is granting the License.
    +
    +      "Legal Entity" shall mean the union of the acting entity and all
    +      other entities that control, are controlled by, or are under common
    +      control with that entity. For the purposes of this definition,
    +      "control" means (i) the power, direct or indirect, to cause the
    +      direction or management of such entity, whether by contract or
    +      otherwise, or (ii) ownership of fifty percent (50%) or more of the
    +      outstanding shares, or (iii) beneficial ownership of such entity.
    +
    +      "You" (or "Your") shall mean an individual or Legal Entity
    +      exercising permissions granted by this License.
    +
    +      "Source" form shall mean the preferred form for making modifications,
    +      including but not limited to software source code, documentation
    +      source, and configuration files.
    +
    +      "Object" form shall mean any form resulting from mechanical
    +      transformation or translation of a Source form, including but
    +      not limited to compiled object code, generated documentation,
    +      and conversions to other media types.
    +
    +      "Work" shall mean the work of authorship, whether in Source or
    +      Object form, made available under the License, as indicated by a
    +      copyright notice that is included in or attached to the work
    +      (an example is provided in the Appendix below).
    +
    +      "Derivative Works" shall mean any work, whether in Source or Object
    +      form, that is based on (or derived from) the Work and for which the
    +      editorial revisions, annotations, elaborations, or other modifications
    +      represent, as a whole, an original work of authorship. For the purposes
    +      of this License, Derivative Works shall not include works that remain
    +      separable from, or merely link (or bind by name) to the interfaces of,
    +      the Work and Derivative Works thereof.
    +
    +      "Contribution" shall mean any work of authorship, including
    +      the original version of the Work and any modifications or additions
    +      to that Work or Derivative Works thereof, that is intentionally
    +      submitted to Licensor for inclusion in the Work by the copyright owner
    +      or by an individual or Legal Entity authorized to submit on behalf of
    +      the copyright owner. For the purposes of this definition, "submitted"
    +      means any form of electronic, verbal, or written communication sent
    +      to the Licensor or its representatives, including but not limited to
    +      communication on electronic mailing lists, source code control systems,
    +      and issue tracking systems that are managed by, or on behalf of, the
    +      Licensor for the purpose of discussing and improving the Work, but
    +      excluding communication that is conspicuously marked or otherwise
    +      designated in writing by the copyright owner as "Not a Contribution."
    +
    +      "Contributor" shall mean Licensor and any individual or Legal Entity
    +      on behalf of whom a Contribution has been received by Licensor and
    +      subsequently incorporated within the Work.
    +
    +   2. Grant of Copyright License. Subject to the terms and conditions of
    +      this License, each Contributor hereby grants to You a perpetual,
    +      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
    +      copyright license to reproduce, prepare Derivative Works of,
    +      publicly display, publicly perform, sublicense, and distribute the
    +      Work and such Derivative Works in Source or Object form.
    +
    +   3. Grant of Patent License. Subject to the terms and conditions of
    +      this License, each Contributor hereby grants to You a perpetual,
    +      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
    +      (except as stated in this section) patent license to make, have made,
    +      use, offer to sell, sell, import, and otherwise transfer the Work,
    +      where such license applies only to those patent claims licensable
    +      by such Contributor that are necessarily infringed by their
    +      Contribution(s) alone or by combination of their Contribution(s)
    +      with the Work to which such Contribution(s) was submitted. If You
    +      institute patent litigation against any entity (including a
    +      cross-claim or counterclaim in a lawsuit) alleging that the Work
    +      or a Contribution incorporated within the Work constitutes direct
    +      or contributory patent infringement, then any patent licenses
    +      granted to You under this License for that Work shall terminate
    +      as of the date such litigation is filed.
    +
    +   4. Redistribution. You may reproduce and distribute copies of the
    +      Work or Derivative Works thereof in any medium, with or without
    +      modifications, and in Source or Object form, provided that You
    +      meet the following conditions:
    +
    +      (a) You must give any other recipients of the Work or
    +          Derivative Works a copy of this License; and
    +
    +      (b) You must cause any modified files to carry prominent notices
    +          stating that You changed the files; and
    +
    +      (c) You must retain, in the Source form of any Derivative Works
    +          that You distribute, all copyright, patent, trademark, and
    +          attribution notices from the Source form of the Work,
    +          excluding those notices that do not pertain to any part of
    +          the Derivative Works; and
    +
    +      (d) If the Work includes a "NOTICE" text file as part of its
    +          distribution, then any Derivative Works that You distribute must
    +          include a readable copy of the attribution notices contained
    +          within such NOTICE file, excluding those notices that do not
    +          pertain to any part of the Derivative Works, in at least one
    +          of the following places: within a NOTICE text file distributed
    +          as part of the Derivative Works; within the Source form or
    +          documentation, if provided along with the Derivative Works; or,
    +          within a display generated by the Derivative Works, if and
    +          wherever such third-party notices normally appear. The contents
    +          of the NOTICE file are for informational purposes only and
    +          do not modify the License. You may add Your own attribution
    +          notices within Derivative Works that You distribute, alongside
    +          or as an addendum to the NOTICE text from the Work, provided
    +          that such additional attribution notices cannot be construed
    +          as modifying the License.
    +
    +      You may add Your own copyright statement to Your modifications and
    +      may provide additional or different license terms and conditions
    +      for use, reproduction, or distribution of Your modifications, or
    +      for any such Derivative Works as a whole, provided Your use,
    +      reproduction, and distribution of the Work otherwise complies with
    +      the conditions stated in this License.
    +
    +   5. Submission of Contributions. Unless You explicitly state otherwise,
    +      any Contribution intentionally submitted for inclusion in the Work
    +      by You to the Licensor shall be under the terms and conditions of
    +      this License, without any additional terms or conditions.
    +      Notwithstanding the above, nothing herein shall supersede or modify
    +      the terms of any separate license agreement you may have executed
    +      with Licensor regarding such Contributions.
    +
    +   6. Trademarks. This License does not grant permission to use the trade
    +      names, trademarks, service marks, or product names of the Licensor,
    +      except as required for reasonable and customary use in describing the
    +      origin of the Work and reproducing the content of the NOTICE file.
    +
    +   7. Disclaimer of Warranty. Unless required by applicable law or
    +      agreed to in writing, Licensor provides the Work (and each
    +      Contributor provides its Contributions) on an "AS IS" BASIS,
    +      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
    +      implied, including, without limitation, any warranties or conditions
    +      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
    +      PARTICULAR PURPOSE. You are solely responsible for determining the
    +      appropriateness of using or redistributing the Work and assume any
    +      risks associated with Your exercise of permissions under this License.
    +
    +   8. Limitation of Liability. In no event and under no legal theory,
    +      whether in tort (including negligence), contract, or otherwise,
    +      unless required by applicable law (such as deliberate and grossly
    +      negligent acts) or agreed to in writing, shall any Contributor be
    +      liable to You for damages, including any direct, indirect, special,
    +      incidental, or consequential damages of any character arising as a
    +      result of this License or out of the use or inability to use the
    +      Work (including but not limited to damages for loss of goodwill,
    +      work stoppage, computer failure or malfunction, or any and all
    +      other commercial damages or losses), even if such Contributor
    +      has been advised of the possibility of such damages.
    +
    +   9. Accepting Warranty or Additional Liability. While redistributing
    +      the Work or Derivative Works thereof, You may choose to offer,
    +      and charge a fee for, acceptance of support, warranty, indemnity,
    +      or other liability obligations and/or rights consistent with this
    +      License. However, in accepting such obligations, You may act only
    +      on Your own behalf and on Your sole responsibility, not on behalf
    +      of any other Contributor, and only if You agree to indemnify,
    +      defend, and hold each Contributor harmless for any liability
    +      incurred by, or claims asserted against, such Contributor by reason
    +      of your accepting any such warranty or additional liability.
    +
    +   END OF TERMS AND CONDITIONS
    +
    +   APPENDIX: How to apply the Apache License to your work.
    +
    +      To apply the Apache License to your work, attach the following
    +      boilerplate notice, with the fields enclosed by brackets "[]"
    +      replaced with your own identifying information. (Don't include
    +      the brackets!)  The text should be enclosed in the appropriate
    +      comment syntax for the file format. We also recommend that a
    +      file or class name and description of purpose be included on the
    +      same "printed page" as the copyright notice for easier
    +      identification within third-party archives.
    +
    +   Copyright [2021] [SwinIR Authors]
    +
    +   Licensed under the Apache License, Version 2.0 (the "License");
    +   you may not use this file except in compliance with the License.
    +   You may obtain a copy of the License at
    +
    +       http://www.apache.org/licenses/LICENSE-2.0
    +
    +   Unless required by applicable law or agreed to in writing, software
    +   distributed under the License is distributed on an "AS IS" BASIS,
    +   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +   See the License for the specific language governing permissions and
    +   limitations under the License.
    +
    + +

    Memory Efficient Attention

    +The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that. +
    +MIT License
    +
    +Copyright (c) 2023 Alex Birch
    +Copyright (c) 2023 Amin Rezaei
    +
    +Permission is hereby granted, free of charge, to any person obtaining a copy
    +of this software and associated documentation files (the "Software"), to deal
    +in the Software without restriction, including without limitation the rights
    +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    +copies of the Software, and to permit persons to whom the Software is
    +furnished to do so, subject to the following conditions:
    +
    +The above copyright notice and this permission notice shall be included in all
    +copies or substantial portions of the Software.
    +
    +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    +SOFTWARE.
    +
    + diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js new file mode 100644 index 0000000000000000000000000000000000000000..0f164b82c1a1f9dd2ad0e6a745bcdd7e652a53e6 --- /dev/null +++ b/javascript/aspectRatioOverlay.js @@ -0,0 +1,113 @@ + +let currentWidth = null; +let currentHeight = null; +let arFrameTimeout = setTimeout(function(){},0); + +function dimensionChange(e, is_width, is_height){ + + if(is_width){ + currentWidth = e.target.value*1.0 + } + if(is_height){ + currentHeight = e.target.value*1.0 + } + + var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) + + if(!inImg2img){ + return; + } + + var targetElement = null; + + var tabIndex = get_tab_index('mode_img2img') + if(tabIndex == 0){ // img2img + targetElement = gradioApp().querySelector('div[data-testid=image] img'); + } else if(tabIndex == 1){ //Sketch + targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); + } else if(tabIndex == 2){ // Inpaint + targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); + } else if(tabIndex == 3){ // Inpaint sketch + targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); + } + + + if(targetElement){ + + var arPreviewRect = gradioApp().querySelector('#imageARPreview'); + if(!arPreviewRect){ + arPreviewRect = document.createElement('div') + arPreviewRect.id = "imageARPreview"; + gradioApp().getRootNode().appendChild(arPreviewRect) + } + + + + var viewportOffset = targetElement.getBoundingClientRect(); + + viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) + + scaledx = targetElement.naturalWidth*viewportscale + scaledy = targetElement.naturalHeight*viewportscale + + cleintRectTop = (viewportOffset.top+window.scrollY) + cleintRectLeft = (viewportOffset.left+window.scrollX) + cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) + cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) + + viewRectTop = cleintRectCentreY-(scaledy/2) + viewRectLeft = cleintRectCentreX-(scaledx/2) + arRectWidth = scaledx + arRectHeight = scaledy + + arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight ) + arscaledx = currentWidth*arscale + arscaledy = currentHeight*arscale + + arRectTop = cleintRectCentreY-(arscaledy/2) + arRectLeft = cleintRectCentreX-(arscaledx/2) + arRectWidth = arscaledx + arRectHeight = arscaledy + + arPreviewRect.style.top = arRectTop+'px'; + arPreviewRect.style.left = arRectLeft+'px'; + arPreviewRect.style.width = arRectWidth+'px'; + arPreviewRect.style.height = arRectHeight+'px'; + + clearTimeout(arFrameTimeout); + arFrameTimeout = setTimeout(function(){ + arPreviewRect.style.display = 'none'; + },2000); + + arPreviewRect.style.display = 'block'; + + } + +} + + +onUiUpdate(function(){ + var arPreviewRect = gradioApp().querySelector('#imageARPreview'); + if(arPreviewRect){ + arPreviewRect.style.display = 'none'; + } + var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) + if(inImg2img){ + let inputs = gradioApp().querySelectorAll('input'); + inputs.forEach(function(e){ + var is_width = e.parentElement.id == "img2img_width" + var is_height = e.parentElement.id == "img2img_height" + + if((is_width || is_height) && !e.classList.contains('scrollwatch')){ + e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) + e.classList.add('scrollwatch') + 
} + if(is_width){ + currentWidth = e.value*1.0 + } + if(is_height){ + currentHeight = e.value*1.0 + } + }) + } +}); diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js new file mode 100644 index 0000000000000000000000000000000000000000..11bcce1bcbdc0ed5c1004fbd9d971d255645826b --- /dev/null +++ b/javascript/contextMenus.js @@ -0,0 +1,177 @@ + +contextMenuInit = function(){ + let eventListenerApplied=false; + let menuSpecs = new Map(); + + const uid = function(){ + return Date.now().toString(36) + Math.random().toString(36).substr(2); + } + + function showContextMenu(event,element,menuEntries){ + let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; + let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; + + let oldMenu = gradioApp().querySelector('#context-menu') + if(oldMenu){ + oldMenu.remove() + } + + let tabButton = uiCurrentTab + let baseStyle = window.getComputedStyle(tabButton) + + const contextMenu = document.createElement('nav') + contextMenu.id = "context-menu" + contextMenu.style.background = baseStyle.background + contextMenu.style.color = baseStyle.color + contextMenu.style.fontFamily = baseStyle.fontFamily + contextMenu.style.top = posy+'px' + contextMenu.style.left = posx+'px' + + + + const contextMenuList = document.createElement('ul') + contextMenuList.className = 'context-menu-items'; + contextMenu.append(contextMenuList); + + menuEntries.forEach(function(entry){ + let contextMenuEntry = document.createElement('a') + contextMenuEntry.innerHTML = entry['name'] + contextMenuEntry.addEventListener("click", function(e) { + entry['func'](); + }) + contextMenuList.append(contextMenuEntry); + + }) + + gradioApp().getRootNode().appendChild(contextMenu) + + let menuWidth = contextMenu.offsetWidth + 4; + let menuHeight = contextMenu.offsetHeight + 4; + + let windowWidth = window.innerWidth; + let windowHeight = window.innerHeight; + + if ( (windowWidth - posx) < menuWidth ) { + contextMenu.style.left = windowWidth - menuWidth + "px"; + } + + if ( (windowHeight - posy) < menuHeight ) { + contextMenu.style.top = windowHeight - menuHeight + "px"; + } + + } + + function appendContextMenuOption(targetElementSelector,entryName,entryFunction){ + + currentItems = menuSpecs.get(targetElementSelector) + + if(!currentItems){ + currentItems = [] + menuSpecs.set(targetElementSelector,currentItems); + } + let newItem = {'id':targetElementSelector+'_'+uid(), + 'name':entryName, + 'func':entryFunction, + 'isNew':true} + + currentItems.push(newItem) + return newItem['id'] + } + + function removeContextMenuOption(uid){ + menuSpecs.forEach(function(v,k) { + let index = -1 + v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) + if(index>=0){ + v.splice(index, 1); + } + }) + } + + function addContextMenuEventListener(){ + if(eventListenerApplied){ + return; + } + gradioApp().addEventListener("click", function(e) { + let source = e.composedPath()[0] + if(source.id && source.id.indexOf('check_progress')>-1){ + return + } + + let oldMenu = gradioApp().querySelector('#context-menu') + if(oldMenu){ + oldMenu.remove() + } + }); + gradioApp().addEventListener("contextmenu", function(e) { + let oldMenu = gradioApp().querySelector('#context-menu') + if(oldMenu){ + oldMenu.remove() + } + menuSpecs.forEach(function(v,k) { + if(e.composedPath()[0].matches(k)){ + showContextMenu(e,e.composedPath()[0],v) + e.preventDefault() + return + } + }) + }); + eventListenerApplied=true + + } + + return [appendContextMenuOption, 
removeContextMenuOption, addContextMenuEventListener] +} + +initResponse = contextMenuInit(); +appendContextMenuOption = initResponse[0]; +removeContextMenuOption = initResponse[1]; +addContextMenuEventListener = initResponse[2]; + +(function(){ + //Start example Context Menu Items + let generateOnRepeat = function(genbuttonid,interruptbuttonid){ + let genbutton = gradioApp().querySelector(genbuttonid); + let interruptbutton = gradioApp().querySelector(interruptbuttonid); + if(!interruptbutton.offsetParent){ + genbutton.click(); + } + clearInterval(window.generateOnRepeatInterval) + window.generateOnRepeatInterval = setInterval(function(){ + if(!interruptbutton.offsetParent){ + genbutton.click(); + } + }, + 500) + } + + appendContextMenuOption('#txt2img_generate','Generate forever',function(){ + generateOnRepeat('#txt2img_generate','#txt2img_interrupt'); + }) + appendContextMenuOption('#img2img_generate','Generate forever',function(){ + generateOnRepeat('#img2img_generate','#img2img_interrupt'); + }) + + let cancelGenerateForever = function(){ + clearInterval(window.generateOnRepeatInterval) + } + + appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) + + appendContextMenuOption('#roll','Roll three', + function(){ + let rollbutton = get_uiCurrentTabContent().querySelector('#roll'); + setTimeout(function(){rollbutton.click()},100) + setTimeout(function(){rollbutton.click()},200) + setTimeout(function(){rollbutton.click()},300) + } + ) +})(); +//End example Context Menu Items + +onUiUpdate(function(){ + addContextMenuEventListener() +}); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js new file mode 100644 index 0000000000000000000000000000000000000000..fe00892481a8ff8b997759b21cb36eb364db788b --- /dev/null +++ b/javascript/dragdrop.js @@ -0,0 +1,97 @@ +// allows drag-dropping files into gradio image elements, and also pasting images from clipboard + +function isValidImageList( files ) { + return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type); +} + +function dropReplaceImage( imgWrap, files ) { + if ( ! 
isValidImageList( files ) ) { + return; + } + + const tmpFile = files[0]; + + imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); + const callback = () => { + const fileInput = imgWrap.querySelector('input[type="file"]'); + if ( fileInput ) { + if ( files.length === 0 ) { + files = new DataTransfer(); + files.items.add(tmpFile); + fileInput.files = files.files; + } else { + fileInput.files = files; + } + fileInput.dispatchEvent(new Event('change')); + } + }; + + if ( imgWrap.closest('#pnginfo_image') ) { + // special treatment for PNG Info tab, wait for fetch request to finish + const oldFetch = window.fetch; + window.fetch = async (input, options) => { + const response = await oldFetch(input, options); + if ( 'api/predict/' === input ) { + const content = await response.text(); + window.fetch = oldFetch; + window.requestAnimationFrame( () => callback() ); + return new Response(content, { + status: response.status, + statusText: response.statusText, + headers: response.headers + }) + } + return response; + }; + } else { + window.requestAnimationFrame( () => callback() ); + } +} + +window.document.addEventListener('dragover', e => { + const target = e.composedPath()[0]; + const imgWrap = target.closest('[data-testid="image"]'); + if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { + return; + } + e.stopPropagation(); + e.preventDefault(); + e.dataTransfer.dropEffect = 'copy'; +}); + +window.document.addEventListener('drop', e => { + const target = e.composedPath()[0]; + if (target.placeholder.indexOf("Prompt") == -1) { + return; + } + const imgWrap = target.closest('[data-testid="image"]'); + if ( !imgWrap ) { + return; + } + e.stopPropagation(); + e.preventDefault(); + const files = e.dataTransfer.files; + dropReplaceImage( imgWrap, files ); +}); + +window.addEventListener('paste', e => { + const files = e.clipboardData.files; + if ( ! isValidImageList( files ) ) { + return; + } + + const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')] + .filter(el => uiElementIsVisible(el)); + if ( ! visibleImageFields.length ) { + return; + } + + const firstFreeImageField = visibleImageFields + .filter(el => el.querySelector('input[type=file]'))?.[0]; + + dropReplaceImage( + firstFreeImageField ? + firstFreeImageField : + visibleImageFields[visibleImageFields.length - 1] + , files ); +}); diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js new file mode 100644 index 0000000000000000000000000000000000000000..619bb1fa3d7a9aef3bc61a8d72b202b939039757 --- /dev/null +++ b/javascript/edit-attention.js @@ -0,0 +1,96 @@ +function keyupEditAttention(event){ + let target = event.originalTarget || event.composedPath()[0]; + if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; + if (! 
(event.metaKey || event.ctrlKey)) return; + + let isPlus = event.key == "ArrowUp" + let isMinus = event.key == "ArrowDown" + if (!isPlus && !isMinus) return; + + let selectionStart = target.selectionStart; + let selectionEnd = target.selectionEnd; + let text = target.value; + + function selectCurrentParenthesisBlock(OPEN, CLOSE){ + if (selectionStart !== selectionEnd) return false; + + // Find opening parenthesis around current cursor + const before = text.substring(0, selectionStart); + let beforeParen = before.lastIndexOf(OPEN); + if (beforeParen == -1) return false; + let beforeParenClose = before.lastIndexOf(CLOSE); + while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { + beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); + beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); + } + + // Find closing parenthesis around current cursor + const after = text.substring(selectionStart); + let afterParen = after.indexOf(CLOSE); + if (afterParen == -1) return false; + let afterParenOpen = after.indexOf(OPEN); + while (afterParenOpen !== -1 && afterParen > afterParenOpen) { + afterParen = after.indexOf(CLOSE, afterParen + 1); + afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); + } + if (beforeParen === -1 || afterParen === -1) return false; + + // Set the selection to the text between the parenthesis + const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); + const lastColon = parenContent.lastIndexOf(":"); + selectionStart = beforeParen + 1; + selectionEnd = selectionStart + lastColon; + target.setSelectionRange(selectionStart, selectionEnd); + return true; + } + + // If the user hasn't selected anything, let's select their current parenthesis block + if(! selectCurrentParenthesisBlock('<', '>')){ + selectCurrentParenthesisBlock('(', ')') + } + + event.preventDefault(); + + closeCharacter = ')' + delta = opts.keyedit_precision_attention + + if (selectionStart > 0 && text[selectionStart - 1] == '<'){ + closeCharacter = '>' + delta = opts.keyedit_precision_extra + } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { + + // do not include spaces at the end + while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){ + selectionEnd -= 1; + } + if(selectionStart == selectionEnd){ + return + } + + text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); + + selectionStart += 1; + selectionEnd += 1; + } + + end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + if (isNaN(weight)) return; + + weight += isPlus ? 
delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if(String(weight).length == 1) weight += ".0" + + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + + target.focus(); + target.value = text; + target.selectionStart = selectionStart; + target.selectionEnd = selectionEnd; + + updateInput(target) +} + +addEventListener('keydown', (event) => { + keyupEditAttention(event); +}); \ No newline at end of file diff --git a/javascript/extensions.js b/javascript/extensions.js new file mode 100644 index 0000000000000000000000000000000000000000..c593cd2e5701db5a89f9b890bc952722ed5c3bbf --- /dev/null +++ b/javascript/extensions.js @@ -0,0 +1,49 @@ + +function extensions_apply(_, _){ + var disable = [] + var update = [] + + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + + if(x.name.startsWith("update_") && x.checked) + update.push(x.name.substr(7)) + }) + + restart_reload() + + return [JSON.stringify(disable), JSON.stringify(update)] +} + +function extensions_check(){ + var disable = [] + + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + }) + + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ + x.innerHTML = "Loading..." + }) + + + var id = randomId() + requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){ + + }) + + return [id, JSON.stringify(disable)] +} + +function install_extension_from_index(button, url){ + button.disabled = "disabled" + button.value = "Installing..." + + textarea = gradioApp().querySelector('#extension_to_install textarea') + textarea.value = url + updateInput(textarea) + + gradioApp().querySelector('#install_extension_button').click() +} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js new file mode 100644 index 0000000000000000000000000000000000000000..17bf200047dfa881ddf7f705788271fd52c29ae8 --- /dev/null +++ b/javascript/extraNetworks.js @@ -0,0 +1,107 @@ + +function setupExtraNetworksForTab(tabname){ + gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') + + var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div') + var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') + var refresh = gradioApp().getElementById(tabname+'_extra_refresh') + var close = gradioApp().getElementById(tabname+'_extra_close') + + search.classList.add('search') + tabs.appendChild(search) + tabs.appendChild(refresh) + tabs.appendChild(close) + + search.addEventListener("input", function(evt){ + searchTerm = search.value.toLowerCase() + + gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ + text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() + elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" + }) + }); +} + +var activePromptTextarea = {}; + +function setupExtraNetworks(){ + setupExtraNetworksForTab('txt2img') + setupExtraNetworksForTab('img2img') + + function registerPrompt(tabname, id){ + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if (! 
activePromptTextarea[tabname]){ + activePromptTextarea[tabname] = textarea + } + + textarea.addEventListener("focus", function(){ + activePromptTextarea[tabname] = textarea; + }); + } + + registerPrompt('txt2img', 'txt2img_prompt') + registerPrompt('txt2img', 'txt2img_neg_prompt') + registerPrompt('img2img', 'img2img_prompt') + registerPrompt('img2img', 'img2img_neg_prompt') +} + +onUiLoaded(setupExtraNetworks) + +var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; +var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; + +function tryToRemoveExtraNetworkFromPrompt(textarea, text){ + var m = text.match(re_extranet) + if(! m) return false + + var partToSearch = m[1] + var replaced = false + var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){ + m = found.match(re_extranet); + if(m[1] == partToSearch){ + replaced = true; + return "" + } + return found; + }) + + if(replaced){ + textarea.value = newTextareaText + return true; + } + + return false +} + +function cardClicked(tabname, textToAdd, allowNegativePrompt){ + var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea") + + if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){ + textarea.value = textarea.value + " " + textToAdd + } + + updateInput(textarea) +} + +function saveCardPreview(event, tabname, filename){ + var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + var button = gradioApp().getElementById(tabname + '_save_preview') + + textarea.value = filename + updateInput(textarea) + + button.click() + + event.stopPropagation() + event.preventDefault() +} + +function extraNetworksSearchButton(tabs_id, event){ + searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea') + button = event.target + text = button.classList.contains("search-all") ? 
"" : button.textContent.trim() + + searchTextarea.value = text + updateInput(searchTextarea) +} \ No newline at end of file diff --git a/javascript/generationParams.js b/javascript/generationParams.js new file mode 100644 index 0000000000000000000000000000000000000000..95f050939b72a8d09d62de8d725caf1e7d15d3c0 --- /dev/null +++ b/javascript/generationParams.js @@ -0,0 +1,33 @@ +// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes + +let txt2img_gallery, img2img_gallery, modal = undefined; +onUiUpdate(function(){ + if (!txt2img_gallery) { + txt2img_gallery = attachGalleryListeners("txt2img") + } + if (!img2img_gallery) { + img2img_gallery = attachGalleryListeners("img2img") + } + if (!modal) { + modal = gradioApp().getElementById('lightboxModal') + modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] }); + } +}); + +let modalObserver = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText + if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img') + gradioApp().getElementById(selectedTab+"_generation_info_button").click() + }); +}); + +function attachGalleryListeners(tab_name) { + gallery = gradioApp().querySelector('#'+tab_name+'_gallery') + gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); + gallery?.addEventListener('keydown', (e) => { + if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow + gradioApp().getElementById(tab_name+"_generation_info_button").click() + }); + return gallery; +} diff --git a/javascript/hints.js b/javascript/hints.js new file mode 100644 index 0000000000000000000000000000000000000000..f1199009b181b83713bb1d136e41d4b29e183634 --- /dev/null +++ b/javascript/hints.js @@ -0,0 +1,146 @@ +// mouseover tooltips for various UI elements + +titles = { + "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results", + "Sampling method": "Which algorithm to use to produce the image", + "GFPGAN": "Restore low quality faces using GFPGAN neural network", + "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help", + "DDIM": "Denoising Diffusion Implicit Models - best at inpainting", + "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution", + + "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)", + "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)", + "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results", + "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", + "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", + "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", + "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", + "\u{1f4c2}": "Open images 
output directory", + "\u{1f4be}": "Save style", + "\u{1f5d1}": "Clear prompt", + "\u{1f4cb}": "Apply selected styles to current prompt", + "\u{1f4d2}": "Paste available values into the field", + "\u{1f3b4}": "Show extra networks", + + + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", + "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", + + "Just resize": "Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.", + "Crop and resize": "Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.", + "Resize and fill": "Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.", + + "Mask blur": "How much to blur the mask before processing, in pixels.", + "Masked content": "What to put inside the masked area before processing it with Stable Diffusion.", + "fill": "fill it with colors of the image", + "original": "keep whatever was there originally", + "latent noise": "fill it with latent space noise", + "latent nothing": "fill it with latent space zeroes", + "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image", + + "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", + "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.", + + "Skip": "Stop processing current image and continue processing.", + "Interrupt": "Stop processing images and return any results accumulated so far.", + "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", + + "X values": "Separate values for X axis using commas.", + "Y values": "Separate values for Y axis using commas.", + + "None": "Do not do anything special", + "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)", + "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows", + "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", + + "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", + "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order", + + "Tiling": "Produce an image that can be tiled.", + "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. 
Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", + + "Variation seed": "Seed of a different picture to be mixed into the generation.", + "Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).", + "Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution", + "Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution", + + "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", + + "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime], [datetime
    ', None) + + self.token_mults = {} + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.bos_token_id + self.id_end = self.wrapped.tokenizer.eos_token_id + self.id_pad = self.id_end + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.transformer.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9fbbe6ca1edfe420c8d5cbc737cdfee4a73622 --- /dev/null +++ b/modules/sd_hijack_clip_old.py @@ -0,0 +1,81 @@ +from modules import sd_hijack_clip +from modules import shared + + +def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + 
[id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + +def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..55a2ce4d19200acafd79e6fce7e017c4abc50a73 --- /dev/null +++ b/modules/sd_hijack_inpainting.py @@ -0,0 +1,103 @@ +import os +import torch + +from einops import repeat +from omegaconf import ListConfig + +import ldm.models.diffusion.ddpm +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms + +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ddim import DDIMSampler, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +@torch.no_grad() +def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered 
timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t + + +def do_inpainting_hijack(): + # p_sample_plms is needed because PLMS can't work with dicts as conditionings + + ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py new file mode 100644 index 0000000000000000000000000000000000000000..3c727d3b75332508629458d23f7fb86cc9ede44b --- /dev/null +++ b/modules/sd_hijack_ip2p.py @@ -0,0 +1,13 @@ +import collections +import os.path +import sys +import gc +import time + +def should_hijack_ip2p(checkpoint_info): + from modules import sd_models_config + + ckpt_basename = os.path.basename(checkpoint_info.filename).lower() + cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() + + return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f733e8529fb6cd68d97b2f255bc705d0cd949fbc --- /dev/null +++ b/modules/sd_hijack_open_clip.py @@ -0,0 +1,37 @@ +import open_clip.tokenizer +import torch + +from modules import sd_hijack_clip, devices +from modules.shared import opts + +tokenizer = open_clip.tokenizer._tokenizer + + +class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] + self.id_start = tokenizer.encoder[""] + self.id_end = tokenizer.encoder[""] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + # set 
self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers + z = self.wrapped.encode_with_transformer(tokens) + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py new file mode 100644 index 0000000000000000000000000000000000000000..c02d954c7ab040be940a3650a64d1d1978409fb5 --- /dev/null +++ b/modules/sd_hijack_optimizations.py @@ -0,0 +1,444 @@ +import math +import sys +import traceback +import psutil + +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + +from modules import shared, errors, devices +from modules.hypernetworks import hypernetwork + +from .sub_quadratic_attention import efficient_dot_product_attention + + +if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: + try: + import xformers.ops + shared.xformers_available = True + except Exception: + print("Cannot import xformers", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + del context, context_k, context_v, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v + + r1 = r1.to(dtype) + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion and modified +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 
'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r1 = r1.to(dtype) + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- +mem_total_gb = psutil.virtual_memory().total // (1 << 30) + +def einsum_op_compvis(q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + +def einsum_op_slice_0(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + +def einsum_op_slice_1(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) + return r + +def einsum_op_mps_v1(q, k, v): + if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 + return einsum_op_compvis(q, k, v) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + if slice_size % 4096 == 0: + slice_size -= 1 + return einsum_op_slice_1(q, k, v, slice_size) + +def einsum_op_mps_v2(q, k, v): + if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: + return einsum_op_compvis(q, k, v) + else: + return einsum_op_slice_0(q, k, v, 1) + +def einsum_op_tensor_mem(q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return einsum_op_slice_0(q, k, v, q.shape[0] // div) + return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + +def einsum_op_cuda(q, k, v): + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + # Divide factor of safety as there's copying and fragmentation + return 
einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + +def einsum_op(q, k, v): + if q.device.type == 'cuda': + return einsum_op_cuda(q, k, v) + + if q.device.type == 'mps': + if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: + return einsum_op_mps_v1(q, k, v) + return einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return einsum_op_tensor_mem(q, k, v, 32) + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + +# -- End of code from https://github.com/invoke-ai/InvokeAI -- + + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." + + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.to(dtype) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: + chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) + + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + 
kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + +def get_xformers_flash_attention_op(q, k, v): + if not shared.cmd_opts.xformers_flash_attention: + return None + + try: + flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = flash_attention_op + if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + return flash_attention_op + except Exception as e: + errors.display_once(e, "enabling flash attention") + + return None + + +def xformers_attention_forward(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + + out = out.to(dtype) + + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + return self.to_out(out) + +def cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + mem_free_total = get_available_vram() + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 + +def xformers_attnblock_forward(self, x): + try: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c 
h w', h=h) + out = self.proj_out(out) + return x + out + except NotImplementedError: + return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..843ab66cfbd07e2b757a226584cc51656ff3f448 --- /dev/null +++ b/modules/sd_hijack_unet.py @@ -0,0 +1,79 @@ +import torch +from packaging import version + +from modules import devices +from modules.sd_hijack_utils import CondFunc + + +class TorchHijackForUnet: + """ + This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; + this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 + """ + + def __getattr__(self, item): + if item == 'cat': + return self.cat + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + def cat(self, tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + if a.shape[-2:] != b.shape[-2:]: + a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") + + tensors = (a, b) + + return torch.cat(tensors, *args, **kwargs) + + +th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + + if isinstance(cond, dict): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + + with devices.autocast(): + return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + + +ddpm_edit_hijack = None +def hijack_ddpm_edit(): + global ddpm_edit_hijack + if not ddpm_edit_hijack: + CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) + CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) + + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +if 
version.parse(torch.__version__) <= version.parse("1.13.1"): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) + +first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8684475ec5c85d7a6f5aa18238a3a5003b17234 --- /dev/null +++ b/modules/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac51c386fdb72610053e472c55104038ab4e6ba --- /dev/null +++ b/modules/sd_hijack_xlmr.py @@ -0,0 +1,34 @@ +import open_clip.tokenizer +import torch + +from modules import sd_hijack_clip, devices +from modules.shared import opts + + +class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.id_start = wrapped.config.bos_token_id + self.id_end = wrapped.config.eos_token_id + self.id_pad = wrapped.config.pad_token_id + + self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma + + def encode_with_transformers(self, tokens): + # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a + # trained layer to transform 
those 1024 into 768 for unet; so you can't choose which transformer + # layer to work with - you have to use the last + + attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64) + features = self.wrapped(input_ids=tokens, attention_mask=attention_mask) + z = features['projection_state'] + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.roberta.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + + return embedded diff --git a/modules/sd_models.py b/modules/sd_models.py new file mode 100644 index 0000000000000000000000000000000000000000..93959f55f325875f46bd6850b138a7742af107d7 --- /dev/null +++ b/modules/sd_models.py @@ -0,0 +1,495 @@ +import collections +import os.path +import sys +import gc +import torch +import re +import safetensors.torch +from omegaconf import OmegaConf +from os import mkdir +from urllib import request +import ldm.modules.midas as midas + +from ldm.util import instantiate_from_config + +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules.paths import models_path +from modules.sd_hijack_inpainting import do_inpainting_hijack +from modules.timer import Timer + +model_dir = "Stable-diffusion" +model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) + +checkpoints_list = {} +checkpoint_alisases = {} +checkpoints_loaded = collections.OrderedDict() + + +class CheckpointInfo: + def __init__(self, filename): + self.filename = filename + abspath = os.path.abspath(filename) + + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') + else: + name = os.path.basename(filename) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + self.name = name + self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + self.hash = model_hash(filename) + + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) + self.shorthash = self.sha256[0:10] if self.sha256 else None + + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) + + def register(self): + checkpoints_list[self.title] = self + for id in self.ids: + checkpoint_alisases[id] = self + + def calculate_shorthash(self): + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) + if self.sha256 is None: + return + + self.shorthash = self.sha256[0:10] + + if self.shorthash not in self.ids: + self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] + + checkpoints_list.pop(self.title) + self.title = f'{self.name} [{self.shorthash}]' + self.register() + + return self.shorthash + + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
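+    # logging.set_verbosity_error() below drops the transformers logger to ERROR, so this (and any other transformers warning) stays hidden for the whole run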
+ + from transformers import logging, CLIPModel + + logging.set_verbosity_error() +except Exception: + pass + + +def setup_model(): + if not os.path.exists(model_path): + os.makedirs(model_path) + + list_models() + enable_midas_autodownload() + + +def checkpoint_tiles(): + def convert(name): + return int(name) if name.isdigit() else name.lower() + + def alphanumeric_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) + + +def list_models(): + checkpoints_list.clear() + checkpoint_alisases.clear() + + cmd_ckpt = shared.cmd_opts.ckpt + if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt): + model_url = None + else: + model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors" + + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + + if os.path.exists(cmd_ckpt): + checkpoint_info = CheckpointInfo(cmd_ckpt) + checkpoint_info.register() + + shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title + elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: + print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) + + for filename in model_list: + checkpoint_info = CheckpointInfo(filename) + checkpoint_info.register() + + +def get_closet_checkpoint_match(search_string): + checkpoint_info = checkpoint_alisases.get(search_string, None) + if checkpoint_info is not None: + return checkpoint_info + + found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + + return None + + +def model_hash(filename): + """old hash that only looks at a small part of the file and is prone to collisions""" + + try: + with open(filename, "rb") as file: + import hashlib + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return 'NOFILE' + + +def select_checkpoint(): + model_checkpoint = shared.opts.sd_model_checkpoint + + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) + if checkpoint_info is not None: + return checkpoint_info + + if len(checkpoints_list) == 0: + print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) + if shared.cmd_opts.ckpt is not None: + print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) + print(f" - directory {model_path}", file=sys.stderr) + if shared.cmd_opts.ckpt_dir is not None: + print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) + print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. 
The program will exit.", file=sys.stderr) + exit(1) + + checkpoint_info = next(iter(checkpoints_list.values())) + if model_checkpoint is not None: + print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr) + + return checkpoint_info + + +chckpoint_dict_replacements = { + 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', + 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', + 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', +} + + +def transform_checkpoint_dict_key(k): + for text, replacement in chckpoint_dict_replacements.items(): + if k.startswith(text): + k = replacement + k[len(text):] + + return k + + +def get_state_dict_from_checkpoint(pl_sd): + pl_sd = pl_sd.pop("state_dict", pl_sd) + pl_sd.pop("state_dict", None) + + sd = {} + for k, v in pl_sd.items(): + new_key = transform_checkpoint_dict_key(k) + + if new_key is not None: + sd[new_key] = v + + pl_sd.clear() + pl_sd.update(sd) + + return pl_sd + + +def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".safetensors": + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) + else: + pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) + + if print_global_state and "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + + sd = get_state_dict_from_checkpoint(pl_sd) + return sd + + +def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + if checkpoint_info in checkpoints_loaded: + # use checkpoint cache + print(f"Loading weights [{sd_model_hash}] from cache") + return checkpoints_loaded[checkpoint_info] + + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + res = read_state_dict(checkpoint_info.filename) + timer.record("load weights from disk") + + return res + + +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + model.load_state_dict(state_dict, strict=False) + del state_dict + timer.record("apply weights to model") + + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) + timer.record("apply channels_last") + + if not shared.cmd_opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) + + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None + + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model + + timer.record("apply 
half()") + + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 + + model.first_stage_model.to(devices.dtype_vae) + timer.record("apply dtype to VAE") + + # clean up cache if limit is reached + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) + + model.sd_model_hash = sd_model_hash + model.sd_model_checkpoint = checkpoint_info.filename + model.sd_checkpoint_info = checkpoint_info + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 + + model.logvar = model.logvar.to(devices.device) # fix for training + + sd_vae.delete_base_vae() + sd_vae.clear_loaded_vae() + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) + timer.record("load VAE") + + +def enable_midas_autodownload(): + """ + Gives the ldm.modules.midas.api.load_model function automatic downloading. + + When the 512-depth-ema model, and other future models like it, is loaded, + it calls midas.api.load_model to load the associated midas depth model. + This function applies a wrapper to download the model to the correct + location automatically. + """ + + midas_path = os.path.join(paths.models_path, 'midas') + + # stable-diffusion-stability-ai hard-codes the midas model path to + # a location that differs from where other scripts using this model look. + # HACK: Overriding the path here. + for k, v in midas.api.ISL_PATHS.items(): + file_name = os.path.basename(v) + midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name) + + midas_urls = { + "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt", + "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt", + } + + midas.api.load_model_inner = midas.api.load_model + + def load_model_wrapper(model_type): + path = midas.api.ISL_PATHS[model_type] + if not os.path.exists(path): + if not os.path.exists(midas_path): + mkdir(midas_path) + + print(f"Downloading midas model weights for {model_type} to {path}") + request.urlretrieve(midas_urls[model_type], path) + print(f"{model_type} downloaded") + + return midas.api.load_model_inner(model_type) + + midas.api.load_model = load_model_wrapper + + +def repair_config(sd_config): + + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True + + +sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' +sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): + from modules import lowvram, sd_hijack + checkpoint_info = checkpoint_info or select_checkpoint() + + if shared.sd_model: + 
sd_hijack.model_hijack.undo_hijack(shared.sd_model) + shared.sd_model = None + gc.collect() + devices.torch_gc() + + do_inpainting_hijack() + + timer = Timer() + + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict + + timer.record("find config") + + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + + timer.record("load config") + + print(f"Creating model from config: {checkpoint_config}") + + sd_model = None + try: + with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): + sd_model = instantiate_from_config(sd_config.model) + except Exception as e: + pass + + if sd_model is None: + print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) + sd_model = instantiate_from_config(sd_config.model) + + sd_model.used_config = checkpoint_config + + timer.record("create model") + + load_model_weights(sd_model, checkpoint_info, state_dict, timer) + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) + else: + sd_model.to(shared.device) + + timer.record("move model to device") + + sd_hijack.model_hijack.hijack(sd_model) + + timer.record("hijack") + + sd_model.eval() + shared.sd_model = sd_model + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + + timer.record("load textual inversion embeddings") + + script_callbacks.model_loaded_callback(sd_model) + + timer.record("scripts callbacks") + + print(f"Model loaded in {timer.summary()}.") + + return sd_model + + +def reload_model_weights(sd_model=None, info=None): + from modules import lowvram, devices, sd_hijack + checkpoint_info = info or select_checkpoint() + + if not sd_model: + sd_model = shared.sd_model + + if sd_model is None: # previous model load failed + current_checkpoint_info = None + else: + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: + return + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) + + sd_hijack.model_hijack.undo_hijack(sd_model) + + timer = Timer() + + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + del sd_model + checkpoints_loaded.clear() + load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) + return shared.sd_model + + try: + load_model_weights(sd_model, checkpoint_info, state_dict, timer) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info, None, timer) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) + timer.record("move 
model to device") + + print(f"Weights loaded in {timer.summary()}.") + + return sd_model diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py new file mode 100644 index 0000000000000000000000000000000000000000..91c21700417f615e8d14e96c4f7a3e89368b6381 --- /dev/null +++ b/modules/sd_models_config.py @@ -0,0 +1,112 @@ +import re +import os + +import torch + +from modules import shared, paths, sd_disable_initialization + +sd_configs_path = shared.sd_configs_path +sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") + + +config_default = shared.sd_default_config +config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") +config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") + + +def is_using_v_parameterization_for_sd2(state_dict): + """ + Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. + """ + + import ldm.modules.diffusionmodules.openaimodel + from modules import devices + + device = devices.cpu + + with sd_disable_initialization.DisableInitialization(): + unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( + use_checkpoint=True, + use_fp16=False, + image_size=32, + in_channels=4, + out_channels=4, + model_channels=320, + attention_resolutions=[4, 2, 1], + num_res_blocks=2, + channel_mult=[1, 2, 4, 4], + num_head_channels=64, + use_spatial_transformer=True, + use_linear_in_transformer=True, + transformer_depth=1, + context_dim=1024, + legacy=False + ) + unet.eval() + + with torch.no_grad(): + unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} + unet.load_state_dict(unet_sd, strict=True) + unet.to(device=device, dtype=torch.float) + + test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 + x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 + + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + + return out < -1 + + +def guess_model_config_from_state_dict(sd, filename): + sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) + diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + + if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + return config_depth_model + + if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: + if diffusion_model_input.shape[1] == 9: + return config_sd2_inpainting + elif is_using_v_parameterization_for_sd2(sd): + return config_sd2v + else: + return config_sd2 + + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return config_inpainting + if diffusion_model_input.shape[1] == 8: + return config_instruct_pix2pix + + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: + return config_alt_diffusion + + return config_default + + +def find_checkpoint_config(state_dict, info): + if info is None: + return guess_model_config_from_state_dict(state_dict, "") + + config = find_checkpoint_config_near_filename(info) + if config is not None: + return config + + return guess_model_config_from_state_dict(state_dict, info.filename) + + +def find_checkpoint_config_near_filename(info): + if info is None: + return None + + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return None + diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..28c2136fe73ac2c093a02cbdee54dd0afd024d21 --- /dev/null +++ b/modules/sd_samplers.py @@ -0,0 +1,47 @@ +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared + +# imports for functions that previously were here and are used by other modules +from modules.sd_samplers_common import samples_to_image_grid, sample_to_image + +all_samplers = [ + *sd_samplers_kdiffusion.samplers_data_k_diffusion, + *sd_samplers_compvis.samplers_data_compvis, +] +all_samplers_map = {x.name: x for x in all_samplers} + +samplers = [] +samplers_for_img2img = [] +samplers_map = {} + + +def create_sampler(name, model): + if name is not None: + config = all_samplers_map.get(name, None) + else: + config = all_samplers[0] + + assert config is not None, f'bad sampler name: {name}' + + sampler = config.constructor(model) + sampler.config = config + + return sampler + + +def set_samplers(): + global samplers, samplers_for_img2img + + hidden = set(shared.opts.hide_samplers) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) + + samplers = [x for x in all_samplers if x.name not in hidden] + samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + + samplers_map.clear() + for sampler in all_samplers: + samplers_map[sampler.name.lower()] = sampler.name + for alias in sampler.aliases: + samplers_map[alias.lower()] = sampler.name + + +set_samplers() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a1aac7cf0aaf25375dcbc2b1ac0f5486ee683a6c --- /dev/null +++ 
b/modules/sd_samplers_common.py @@ -0,0 +1,62 @@ +from collections import namedtuple +import numpy as np +import torch +from PIL import Image +from modules import devices, processing, images, sd_vae_approx + +from modules.shared import opts, state +import modules.shared as shared + +SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) + + +def setup_img2img_steps(p, steps=None): + if opts.img2img_fix_steps or steps is not None: + requested_steps = (steps or p.steps) + steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 + t_enc = requested_steps - 1 + else: + steps = p.steps + t_enc = int(min(p.denoising_strength, 0.999) * steps) + + return steps, t_enc + + +approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} + + +def single_sample_to_image(sample, approximation=None): + if approximation is None: + approximation = approximation_indexes.get(opts.show_progress_type, 0) + + if approximation == 2: + x_sample = sd_vae_approx.cheap_approximation(sample) + elif approximation == 1: + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + else: + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] + + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + return Image.fromarray(x_sample) + + +def sample_to_image(samples, index=0, approximation=None): + return single_sample_to_image(samples[index], approximation) + + +def samples_to_image_grid(samples, approximation=None): + return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) + + +def store_latent(decoded): + state.current_latent = decoded + + if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: + if not shared.parallel_processing_allowed: + shared.state.assign_current_image(sample_to_image(decoded)) + + +class InterruptedException(BaseException): + pass diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py new file mode 100644 index 0000000000000000000000000000000000000000..d03131cd49c27374e84e854777730541ed71061a --- /dev/null +++ b/modules/sd_samplers_compvis.py @@ -0,0 +1,160 @@ +import math +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms + +import numpy as np +import torch + +from modules.shared import state +from modules import sd_samplers_common, prompt_parser, shared + + +samplers_data_compvis = [ + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), +] + + +class VanillaStableDiffusionSampler: + def __init__(self, constructor, sd_model): + self.sampler = constructor(sd_model) + self.is_plms = hasattr(self.sampler, 'p_sample_plms') + self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim + self.mask = None + self.nmask = None + self.init_latent = None + self.sampler_noises = None + self.step = 0 + self.stop_at = None + self.eta = None + self.config = None + self.last_latent = None + + self.conditioning_key = sd_model.model.conditioning_key + + def number_of_needed_noises(self, p): + return 0 + + def 
launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except sd_samplers_common.InterruptedException: + return self.last_latent + + def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + if self.stop_at is not None and self.step > self.stop_at: + raise sd_samplers_common.InterruptedException + + # Have to unwrap the inpainting conditioning here to perform pre-processing + image_conditioning = None + if isinstance(cond, dict): + image_conditioning = cond["c_concat"][0] + cond = cond["c_crossattn"][0] + unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + + assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' + cond = tensor + + # for DDIM, shapes must match, we can't just process cond and uncond independently; + # filling unconditional_conditioning with repeats of the last vector to match length is + # not 100% correct but should work well enough + if unconditional_conditioning.shape[1] < cond.shape[1]: + last_vector = unconditional_conditioning[:, -1:] + last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) + unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) + elif unconditional_conditioning.shape[1] > cond.shape[1]: + unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] + + if self.mask is not None: + img_orig = self.sampler.model.q_sample(self.init_latent, ts) + x_dec = img_orig * self.mask + self.nmask * x_dec + + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. 
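+        # For PLMS this dict is consumed by the patched p_sample_plms from sd_hijack_inpainting, which concatenates each list element with its unconditional counterpart.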
+ if image_conditioning is not None: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + + if self.mask is not None: + self.last_latent = self.init_latent * self.mask + self.nmask * res[1] + else: + self.last_latent = res[1] + + sd_samplers_common.store_latent(self.last_latent) + + self.step += 1 + state.sampling_step = self.step + shared.total_tqdm.update() + + return res + + def initialize(self, p): + self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim + if self.eta != 0.0: + p.extra_generation_params["Eta DDIM"] = self.eta + + for fieldname in ['p_sample_ddim', 'p_sample_plms']: + if hasattr(self.sampler, fieldname): + setattr(self.sampler, fieldname, self.p_sample_ddim_hook) + + self.mask = p.mask if hasattr(p, 'mask') else None + self.nmask = p.nmask if hasattr(p, 'nmask') else None + + def adjust_steps_if_invalid(self, p, num_steps): + if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + valid_step = 999 / (1000 // num_steps) + if valid_step == math.floor(valid_step): + return int(valid_step) + 1 + + return num_steps + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + steps = self.adjust_steps_if_invalid(p, steps) + self.initialize(p) + + self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) + x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) + + self.init_latent = x + self.last_latent = x + self.step = 0 + + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + self.initialize(p) + + self.init_latent = None + self.last_latent = x + self.step = 0 + + steps = self.adjust_steps_if_invalid(p, steps or p.steps) + + # Wrap the conditioning models with additional image conditioning for inpainting model + # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape + if image_conditioning is not None: + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} + + samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) + + return samples_ddim diff --git a/modules/sd_samplers_kdiffusion.py 
b/modules/sd_samplers_kdiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..528f513fe5a7dc3227ef238e7c3a5091b2f868b6 --- /dev/null +++ b/modules/sd_samplers_kdiffusion.py @@ -0,0 +1,357 @@ +from collections import deque +import torch +import inspect +import einops +import k_diffusion.sampling +from modules import prompt_parser, devices, sd_samplers_common + +from modules.shared import opts, state +import modules.shared as shared +from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback + +samplers_k_diffusion = [ + ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), + ('Euler', 'sample_euler', ['k_euler'], {}), + ('LMS', 'sample_lms', ['k_lms'], {}), + ('Heun', 'sample_heun', ['k_heun'], {}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), + ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), + ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), + ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), + ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), + ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), + ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), +] + +samplers_data_k_diffusion = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_k_diffusion + if hasattr(k_diffusion.sampling, funcname) +] + +sampler_extra_params = { + 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], +} + + +class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. 
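+    The guidances are combined in combine_denoised below as uncond + cond_scale * weight * (cond - uncond), summed over every prompt joined with AND.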
+ """ + + def __init__(self, model): + super().__init__() + self.inner_model = model + self.mask = None + self.nmask = None + self.init_latent = None + self.step = 0 + self.image_cfg_scale = None + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) + + return denoised + + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. + is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + + if tensor.shape[1] == uncond.shape[1]: + if not is_edit_model: + cond_in = torch.cat([tensor, uncond]) + else: + cond_in = torch.cat([tensor, uncond, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) + else: + x_out = torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + + if not is_edit_model: + c_crossattn = [tensor[a:b]] + else: + 
c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]}) + + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) + cfg_denoised_callback(denoised_params) + + devices.test_for_nans(x_out, "unet") + + if opts.live_preview_content == "Prompt": + sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) + elif opts.live_preview_content == "Negative prompt": + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) + + if not is_edit_model: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + else: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) + + if self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + self.step += 1 + + return denoised + + +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. + self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + if x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + else: + return torch.randn_like(x) + + +class KDiffusionSampler: + def __init__(self, funcname, sd_model): + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.funcname = funcname + self.func = getattr(k_diffusion.sampling, self.funcname) + self.extra_params = sampler_extra_params.get(funcname, []) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None + self.last_latent = None + + self.conditioning_key = sd_model.model.conditioning_key + + def callback_state(self, d): + step = d['i'] + latent = d["denoised"] + if opts.live_preview_content == "Combined": + sd_samplers_common.store_latent(latent) + self.last_latent = latent + + if self.stop_at is not None and step > self.stop_at: + raise sd_samplers_common.InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except sd_samplers_common.InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p): + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else opts.eta_ancestral + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs 
= {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params["Eta"] = self.eta + + extra_params_kwargs['eta'] = self.eta + + return extra_params_kwargs + + def get_sigmas(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + if p.sampler_noise_scheduler_override: + sigmas = p.sampler_noise_scheduler_override(steps) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + else: + sigmas = self.model_wrap.get_sigmas(steps) + + if discard_next_to_last_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + return sigmas + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + sigmas = self.get_sigmas(p, steps) + + sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'sigma_min' in parameters: + ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last + extra_params_kwargs['sigma_min'] = sigma_sched[-2] + if 'sigma_max' in parameters: + extra_params_kwargs['sigma_max'] = sigma_sched[0] + if 'n' in parameters: + extra_params_kwargs['n'] = len(sigma_sched) - 1 + if 'sigma_sched' in parameters: + extra_params_kwargs['sigma_sched'] = sigma_sched + if 'sigmas' in parameters: + extra_params_kwargs['sigmas'] = sigma_sched + + if self.funcname == 'sample_dpmpp_sde': + noise_sampler = self.create_noise_sampler(x, sigmas, p) + extra_params_kwargs['noise_sampler'] = noise_sampler + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + + sigmas = 
self.get_sigmas(p, steps) + + x = x * sigmas[0] + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'sigma_min' in parameters: + extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() + extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() + if 'n' in parameters: + extra_params_kwargs['n'] = steps + else: + extra_params_kwargs['sigmas'] = sigmas + + if self.funcname == 'sample_dpmpp_sde': + noise_sampler = self.create_noise_sampler(x, sigmas, p) + extra_params_kwargs['noise_sampler'] = noise_sampler + + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + return samples + diff --git a/modules/sd_vae.py b/modules/sd_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..9b00f76e9c62c794b3a27b36bae0f168ff4f5ab8 --- /dev/null +++ b/modules/sd_vae.py @@ -0,0 +1,216 @@ +import torch +import safetensors.torch +import os +import collections +from collections import namedtuple +from modules import paths, shared, devices, script_callbacks, sd_models +import glob +from copy import deepcopy + + +vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} +vae_dict = {} + + +base_vae = None +loaded_vae_file = None +checkpoint_info = None + +checkpoints_loaded = collections.OrderedDict() + +def get_base_vae(model): + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: + return base_vae + return None + + +def store_base_vae(model): + global base_vae, checkpoint_info + if checkpoint_info != model.sd_checkpoint_info: + assert not loaded_vae_file, "Trying to store non-base VAE!" 
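For orientation, a short usage sketch of this VAE module as it is typically driven (function names come from the diff itself; the checkpoint path is made up for illustration):

```python
from modules import sd_vae, shared

# Populate vae_dict from the configured VAE folders, then pick the override for a checkpoint
# using the precedence implemented by resolve_vae(): --vae-path, a .vae.* file next to the
# checkpoint, the VAE named in settings, or None.
sd_vae.refresh_vae_list()
vae_file, vae_source = sd_vae.resolve_vae("/models/Stable-diffusion/model.safetensors")

# load_vae() swaps in the override; if vae_file is None and an override was active,
# it falls back to restoring the cached base VAE.
sd_vae.load_vae(shared.sd_model, vae_file, vae_source)
```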
+ base_vae = deepcopy(model.first_stage_model.state_dict()) + checkpoint_info = model.sd_checkpoint_info + + +def delete_base_vae(): + global base_vae, checkpoint_info + base_vae = None + checkpoint_info = None + + +def restore_base_vae(model): + global loaded_vae_file + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + print("Restoring base VAE") + _load_vae_dict(model, base_vae) + loaded_vae_file = None + delete_base_vae() + + +def get_filename(filepath): + return os.path.basename(filepath) + + +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), + ] + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir): + paths += [ + os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'), + ] + + candidates = [] + for path in paths: + candidates += glob.iglob(path, recursive=True) + + for filepath in candidates: + name = get_filename(filepath) + vae_dict[name] = filepath + + +def find_vae_near_checkpoint(checkpoint_file): + checkpoint_path = os.path.splitext(checkpoint_file)[0] + for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + if os.path.isfile(vae_location): + return vae_location + + return None + + +def resolve_vae(checkpoint_file): + if shared.cmd_opts.vae_path is not None: + return shared.cmd_opts.vae_path, 'from commandline argument' + + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): + return vae_near_checkpoint, 'found near the checkpoint' + + if shared.opts.sd_vae == "None": + return None, None + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return vae_from_options, 'specified in settings' + + if not is_automatic: + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return None, None + + +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, loaded_vae_file + # save_settings = False + + cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + + if vae_file: + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" + print(f"Loading VAE weights {vae_source}: 
{vae_file}") + store_base_vae(model) + + vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) + _load_vae_dict(model, vae_dict_1) + + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + + # clean up cache if limit is reached + if cache_enabled: + while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model + checkpoints_loaded.popitem(last=False) # LRU + + # If vae used is not in dict, update it + # It will be removed on refresh though + vae_opt = get_filename(vae_file) + if vae_opt not in vae_dict: + vae_dict[vae_opt] = vae_file + + elif loaded_vae_file: + restore_base_vae(model) + + loaded_vae_file = vae_file + + +# don't call this from outside +def _load_vae_dict(model, vae_dict_1): + model.first_stage_model.load_state_dict(vae_dict_1) + model.first_stage_model.to(devices.dtype_vae) + + +def clear_loaded_vae(): + global loaded_vae_file + loaded_vae_file = None + + +unspecified = object() + + +def reload_vae_weights(sd_model=None, vae_file=unspecified): + from modules import lowvram, devices, sd_hijack + + if not sd_model: + sd_model = shared.sd_model + + checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_file = checkpoint_info.filename + + if vae_file == unspecified: + vae_file, vae_source = resolve_vae(checkpoint_file) + else: + vae_source = "from function argument" + + if loaded_vae_file == vae_file: + return + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) + + sd_hijack.model_hijack.undo_hijack(sd_model) + + load_vae(sd_model, vae_file, vae_source) + + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) + + print("VAE weights loaded.") + return sd_model diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..0027343a74500502b79bccd8d4b885ddcd7db568 --- /dev/null +++ b/modules/sd_vae_approx.py @@ -0,0 +1,58 @@ +import os + +import torch +from torch import nn +from modules import devices, paths + +sd_vae_approx_model = None + + +class VAEApprox(nn.Module): + def __init__(self): + super(VAEApprox, self).__init__() + self.conv1 = nn.Conv2d(4, 8, (7, 7)) + self.conv2 = nn.Conv2d(8, 16, (5, 5)) + self.conv3 = nn.Conv2d(16, 32, (3, 3)) + self.conv4 = nn.Conv2d(32, 64, (3, 3)) + self.conv5 = nn.Conv2d(64, 32, (3, 3)) + self.conv6 = nn.Conv2d(32, 16, (3, 3)) + self.conv7 = nn.Conv2d(16, 8, (3, 3)) + self.conv8 = nn.Conv2d(8, 3, (3, 3)) + + def forward(self, x): + extra = 11 + x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) + x = nn.functional.pad(x, (extra, extra, extra, extra)) + + for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]: + x = layer(x) + x = nn.functional.leaky_relu(x, 0.1) + + return x + + +def model(): + global sd_vae_approx_model + + if sd_vae_approx_model is None: + sd_vae_approx_model = VAEApprox() + sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) + sd_vae_approx_model.eval() + sd_vae_approx_model.to(devices.device, devices.dtype) + + return sd_vae_approx_model + + +def cheap_approximation(sample): + # 
https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 + + coefs = torch.tensor([ + [0.298, 0.207, 0.208], + [0.187, 0.286, 0.173], + [-0.158, 0.189, 0.264], + [-0.184, -0.271, -0.473], + ]).to(sample.device) + + x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) + + return x_sample diff --git a/modules/shared.py b/modules/shared.py new file mode 100644 index 0000000000000000000000000000000000000000..805f9cc19cf9e529f19f1f94449a454b5050ec52 --- /dev/null +++ b/modules/shared.py @@ -0,0 +1,720 @@ +import argparse +import datetime +import json +import os +import sys +import time + +from PIL import Image +import gradio as gr +import tqdm + +import modules.interrogate +import modules.memmon +import modules.styles +import modules.devices as devices +from modules import localization, extensions, script_loading, errors, ui_components, shared_items +from modules.paths import models_path, script_path, data_path + + +demo = None + +sd_configs_path = os.path.join(script_path, "configs") +sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") +sd_model_file = os.path.join(script_path, 'model.ckpt') +default_sd_model_file = sd_model_file + +parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) +parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) +parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) +parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") +parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") +parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) +parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) +parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") +parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") +parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") +parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") +parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") +parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") +parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") +parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a 
little speed for low VRM usage") +parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") +parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") +parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") +parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") +parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") +parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") +parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) +parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") +parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options") +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) +parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) +parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) +parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") +parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") +parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. 
By default, it's on for torch cuda.") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) +parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") +parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") +parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") +parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) +parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) +parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) +parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) +parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) +parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") +parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. 
"/path/to/auth/file" same auth format as --gradio-auth', default=None) +parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') +parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") +parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") +parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) +parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) +parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) +parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) +parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) +parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") +parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") +parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") +parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) +parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) +parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) +parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) +parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button") +parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") +parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) +parser.add_argument("--no-download-sd-model", 
action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) + + +script_loading.preload_extensions(extensions.extensions_dir, parser) +script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) + +cmd_opts = parser.parse_args() + +restricted_opts = { + "samples_filename_pattern", + "directories_filename_pattern", + "outdir_samples", + "outdir_txt2img_samples", + "outdir_img2img_samples", + "outdir_extras_samples", + "outdir_grids", + "outdir_txt2img_grids", + "outdir_save", +} + +ui_reorder_categories = [ + "inpaint", + "sampler", + "checkboxes", + "hires_fix", + "dimensions", + "cfg", + "seed", + "batch", + "override_settings", + "scripts", +] + +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access + +devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ + (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) + +device = devices.device +weight_load_location = None if cmd_opts.lowram else "cpu" + +batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) +parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram +xformers_available = False +config_filename = cmd_opts.ui_settings_file + +os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) +hypernetworks = {} +loaded_hypernetworks = [] + + +def reload_hypernetworks(): + from modules.hypernetworks import hypernetwork + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + + +class State: + skipped = False + interrupted = False + job = "" + job_no = 0 + job_count = 0 + processing_has_refined_job_count = False + job_timestamp = '0' + sampling_step = 0 + sampling_steps = 0 + current_latent = None + current_image = None + current_image_sampling_step = 0 + id_live_preview = 0 + textinfo = None + time_start = None + need_restart = False + server_start = None + + def skip(self): + self.skipped = True + + def interrupt(self): + self.interrupted = True + + def nextjob(self): + if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: + self.do_set_current_image() + + self.job_no += 1 + self.sampling_step = 0 + self.current_image_sampling_step = 0 + + def dict(self): + obj = { + "skipped": self.skipped, + "interrupted": self.interrupted, + "job": self.job, + "job_count": self.job_count, + "job_timestamp": self.job_timestamp, + "job_no": self.job_no, + "sampling_step": self.sampling_step, + "sampling_steps": self.sampling_steps, + } + + return obj + + def begin(self): + self.sampling_step = 0 + self.job_count = -1 + self.processing_has_refined_job_count = False + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.id_live_preview = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + self.time_start = time.time() + + devices.torch_gc() + + def end(self): + self.job = "" + self.job_count = 0 + + devices.torch_gc() + + def set_current_image(self): + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" + if not parallel_processing_allowed: + return + + if self.sampling_step 
- self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1: + self.do_set_current_image() + + def do_set_current_image(self): + if self.current_latent is None: + return + + import modules.sd_samplers + if opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + def assign_current_image(self, image): + self.current_image = image + self.id_live_preview += 1 + + +state = State() +state.server_start = time.time() + +styles_filename = cmd_opts.styles_file +prompt_styles = modules.styles.StyleDatabase(styles_filename) + +interrogator = modules.interrogate.InterrogateModels("interrogate") + +face_restorers = [] + +class OptionInfo: + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): + self.default = default + self.label = label + self.component = component + self.component_args = component_args + self.onchange = onchange + self.section = section + self.refresh = refresh + + +def options_section(section_identifier, options_dict): + for k, v in options_dict.items(): + v.section = section_identifier + + return options_dict + + +def list_checkpoint_tiles(): + import modules.sd_models + return modules.sd_models.checkpoint_tiles() + + +def refresh_checkpoints(): + import modules.sd_models + return modules.sd_models.list_models() + + +def list_samplers(): + import modules.sd_samplers + return modules.sd_samplers.all_samplers + + +hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} + +options_templates = {} + +options_templates.update(options_section(('saving-images', "Saving images/grids"), { + "samples_save": OptionInfo(True, "Always save all generated images"), + "samples_format": OptionInfo('png', 'File format for images'), + "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs), + "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), + + "grid_save": OptionInfo(True, "Always save all generated image grids"), + "grid_format": OptionInfo('png', 'File format for grids'), + "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), + "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), + "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), + "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), + + "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), + "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), + "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), + "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), + "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), + "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + 
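A hedged sketch of how these settings tables are wired together: each entry is an `OptionInfo` registered under a section via `options_section`, and the `Options` object defined further down exposes it as an attribute, persists it to the settings JSON, and can fire an `onchange` callback. The option key in this sketch is hypothetical and not part of the patch:

```python
import gradio as gr

options_templates.update(options_section(('saving-images', "Saving images/grids"), {
    # hypothetical option added for illustration only
    "example_webp_quality": OptionInfo(80, "Quality for saved webp images", gr.Slider,
                                       {"minimum": 1, "maximum": 100, "step": 1}),
}))

# After the module-level `opts = Options()` below is created:
#   opts.example_webp_quality              -> 80 (the default, until changed)
#   opts.set("example_webp_quality", 90)   -> True if the stored value changed; runs onchange if set
```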
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), + "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), + "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), + + "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), + "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), + "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), + + "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), + "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + +})) + +options_templates.update(options_section(('saving-paths', "Paths for saving"), { + "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), +})) + +options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { + "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), + "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), + "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), + "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs), + "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), +})) + +options_templates.update(options_section(('upscaling', "Upscaling"), { + "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. 
(Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), + "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), +})) + +options_templates.update(options_section(('face-restoration', "Face restoration"), { + "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), + "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), +})) + +options_templates.update(options_section(('system', "System"), { + "show_warnings": OptionInfo(False, "Show warnings in console."), + "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), + "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), +})) + +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. 
Training of embedding or HN can be resumed with the matching optim file."), + "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), + "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), + "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), + "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"), + "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."), + "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."), + "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), +})) + +options_templates.update(options_section(('sd', "Stable Diffusion"), { + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), + "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), + "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), + "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), + "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. 
Requires restart to apply."), + "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), + "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), +})) + +options_templates.update(options_section(('compatibility', "Compatibility"), { + "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), + "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), +})) + +options_templates.update(options_section(('interrogate', "Interrogate Options"), { + "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), + "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), + "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), + "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), + "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), + "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), + "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), + "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), + "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), +})) + +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, 
refresh=reload_hypernetworks), +})) + +options_templates.update(options_section(('ui', "User interface"), { + "return_grid": OptionInfo(True, "Show grid in results for web"), + "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), + "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), + "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), + "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), + "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), + "font": OptionInfo("", "Font for image grids that have text"), + "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), + "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), + "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), + "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), +})) + +options_templates.update(options_section(('ui', "Live previews"), { + "show_progressbar": OptionInfo(True, "Show progressbar"), + "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), + "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. 
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), + "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), + "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), + "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") +})) + +options_templates.update(options_section(('sampler-params', "Sampler parameters"), { + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), + "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), + 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), +})) + +options_templates.update(options_section(('postprocessing', "Postprocessing"), { + 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), +})) + +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable those extensions"), + "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), +})) + +options_templates.update() + + +class Options: + data = None + data_labels = options_templates + typemap = {int: float} + + def __init__(self): + self.data = {k: v.default for k, v in self.data_labels.items()} + + def __setattr__(self, key, value): + if self.data is not None: + if key in self.data or key in self.data_labels: + assert not cmd_opts.freeze_settings, "changing settings is disabled" + + info = opts.data_labels.get(key, None) + comp_args = info.component_args if info else None + if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + self.data[key] = value + return + + return super(Options, self).__setattr__(key, value) + + def __getattr__(self, item): + if self.data is not None: + if item in self.data: + return self.data[item] + + if item in self.data_labels: + 
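An illustrative usage sketch of the `Options` helpers defined in this class, assuming the module-level `opts` instance created below; it shows type coercion of values read back from text (infotext or the API) and attaching a change callback:

```python
from modules.shared import opts

# cast_value coerces a string back to the option's native type (the default for
# eta_noise_seed_delta is an int, so "12" becomes 12).
delta = opts.cast_value("eta_noise_seed_delta", "12")

# Register a callback without invoking it immediately, then apply the value;
# set() returns True only if the stored value actually changed.
opts.onchange("eta_noise_seed_delta", lambda: print("ENSD changed"), call=False)
changed = opts.set("eta_noise_seed_delta", delta)
```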
return self.data_labels[item].default + + return super(Options, self).__getattribute__(item) + + def set(self, key, value): + """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" + + oldval = self.data.get(key, None) + if oldval == value: + return False + + try: + setattr(self, key, value) + except RuntimeError: + return False + + if self.data_labels[key].onchange is not None: + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False + + return True + + def save(self, filename): + assert not cmd_opts.freeze_settings, "saving settings is disabled" + + with open(filename, "w", encoding="utf8") as file: + json.dump(self.data, file, indent=4) + + def same_type(self, x, y): + if x is None or y is None: + return True + + type_x = self.typemap.get(type(x), type(x)) + type_y = self.typemap.get(type(y), type(y)) + + return type_x == type_y + + def load(self, filename): + with open(filename, "r", encoding="utf8") as file: + self.data = json.load(file) + + bad_settings = 0 + for k, v in self.data.items(): + info = self.data_labels.get(k, None) + if info is not None and not self.same_type(info.default, v): + print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr) + bad_settings += 1 + + if bad_settings > 0: + print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) + + def onchange(self, key, func, call=True): + item = self.data_labels.get(key) + item.onchange = func + + if call: + func() + + def dumpjson(self): + d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} + return json.dumps(d) + + def add_option(self, key, info): + self.data_labels[key] = info + + def reorder(self): + """reorder settings so that all items related to a section always go together""" + + section_ids = {} + settings_items = self.data_labels.items() + for k, item in settings_items: + if item.section not in section_ids: + section_ids[item.section] = len(section_ids) + + self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} + + def cast_value(self, key, value): + """casts an arbitrary value to the same type as the value of the setting identified by key + Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str) + """ + + if value is None: + return None + + default_value = self.data_labels[key].default + if default_value is None: + default_value = getattr(self, key, None) + if default_value is None: + return None + + expected_type = type(default_value) + if expected_type == bool and value == "False": + value = False + else: + value = expected_type(value) + + return value + + + +opts = Options() +if os.path.exists(config_filename): + opts.load(config_filename) + +settings_components = None +"""assigned from ui.py, a mapping of setting names to the gradio components responsible for those settings""" + +latent_upscale_default_mode = "Latent" +latent_upscale_modes = { + "Latent": {"mode": "bilinear", "antialias": False}, + "Latent (antialiased)": {"mode": "bilinear", "antialias": True}, + "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, + "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, + "Latent (nearest)": {"mode": "nearest", "antialias": False}, + "Latent (nearest-exact)":
{"mode": "nearest-exact", "antialias": False}, +} + +sd_upscalers = [] + +sd_model = None + +clip_model = None + +progress_print_out = sys.stdout + + +class TotalTQDM: + def __init__(self): + self._tqdm = None + + def reset(self): + self._tqdm = tqdm.tqdm( + desc="Total progress", + total=state.job_count * state.sampling_steps, + position=1, + file=progress_print_out + ) + + def update(self): + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.update() + + def updateTotal(self, new_total): + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.total = new_total + + def clear(self): + if self._tqdm is not None: + self._tqdm.close() + self._tqdm = None + + +total_tqdm = TotalTQDM() + +mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) +mem_mon.start() + + +def listfiles(dirname): + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] + return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" diff --git a/modules/shared_items.py b/modules/shared_items.py new file mode 100644 index 0000000000000000000000000000000000000000..e792a1349a2aaf763f3ede98479d0a2b5d92a454 --- /dev/null +++ b/modules/shared_items.py @@ -0,0 +1,23 @@ + + +def realesrgan_models_names(): + import modules.realesrgan_model + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + + +def postprocessing_scripts(): + import modules.scripts + + return modules.scripts.scripts_postproc.scripts + + +def sd_vae_items(): + import modules.sd_vae + + return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) + + +def refresh_vae_list(): + import modules.sd_vae + + modules.sd_vae.refresh_vae_list() diff --git a/modules/styles.py b/modules/styles.py new file mode 100644 index 0000000000000000000000000000000000000000..990d562369b49c4c5ce593d571a50640d9966301 --- /dev/null +++ b/modules/styles.py @@ -0,0 +1,87 @@ +# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime +from __future__ import annotations + +import csv +import os +import os.path +import typing +import collections.abc as abc +import tempfile +import shutil + +if typing.TYPE_CHECKING: + # Only import this when code is being type-checked, it doesn't have any effect at runtime + from .processing import StableDiffusionProcessing + + +class PromptStyle(typing.NamedTuple): + name: str + prompt: str + negative_prompt: str + + +def merge_prompts(style_prompt: str, prompt: str) -> str: + if "{prompt}" in style_prompt: + res = style_prompt.replace("{prompt}", prompt) + else: + parts = filter(None, (prompt.strip(), style_prompt.strip())) + res = ", ".join(parts) + + return res + + +def apply_styles_to_prompt(prompt, styles): + for style in styles: + prompt = merge_prompts(style, prompt) + + return prompt + + +class StyleDatabase: + def __init__(self, path: str): + self.no_style = PromptStyle("None", "", "") + self.styles = {} + self.path = path + + self.reload() + + def reload(self): + self.styles.clear() + + if not os.path.exists(self.path): + return + + with open(self.path, "r", encoding="utf-8-sig", newline='') as file: + reader = 
csv.DictReader(file) + for row in reader: + # Support loading old CSV format with "name, text"-columns + prompt = row["prompt"] if "prompt" in row else row["text"] + negative_prompt = row.get("negative_prompt", "") + self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + + def get_style_prompts(self, styles): + return [self.styles.get(x, self.no_style).prompt for x in styles] + + def get_negative_style_prompts(self, styles): + return [self.styles.get(x, self.no_style).negative_prompt for x in styles] + + def apply_styles_to_prompt(self, prompt, styles): + return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) + + def apply_negative_styles_to_prompt(self, prompt, styles): + return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) + + def save_styles(self, path: str) -> None: + # Write to temporary file first, so we don't nuke the file if something goes wrong + fd, temp_path = tempfile.mkstemp(".csv") + with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: + # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, + # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() + writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) + writer.writeheader() + writer.writerows(style._asdict() for k, style in self.styles.items()) + + # Always keep a backup file around + if os.path.exists(path): + shutil.move(path, path + ".bak") + shutil.move(temp_path, path) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..055953236e5e1d1401feca93ca6a1cc342cb8595 --- /dev/null +++ b/modules/sub_quadratic_attention.py @@ -0,0 +1,214 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) +# credit: +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +from functools import partial +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import math +from typing import Optional, NamedTuple, List + + +def narrow_trunc( + input: Tensor, + dim: int, + start: int, + length: int +) -> Tensor: + return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + + +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor + + +class SummarizeChunk: + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + + +class ComputeQueryChunkAttn: + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... 
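+# SummarizeChunk and ComputeQueryChunkAttn above are only signature stubs; the actual work happens in
+# _summarize_chunk below: it computes scale * Q @ K^T for one key/value chunk, subtracts the per-query
+# max before exponentiating (log-sum-exp stabilisation), and returns the partial
+# (exp_values, exp_weights_sum, max_score) triple. _query_chunk_attention later rescales each chunk by
+# exp(chunk_max - global_max) and divides the summed values by the summed weights to merge the chunks.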
+ + +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape + + def chunk_scanner(chunk_idx: int) -> AttnChunk: + key_chunk = narrow_trunc( + key, + 1, + chunk_idx, + kv_chunk_size + ) + value_chunk = narrow_trunc( + value, + 1, + chunk_idx, + kv_chunk_size + ) + return summarize_chunk(query, key_chunk, value_chunk) + + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + + +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype) + return hidden_states_slice + + +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk + + +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + use_checkpoint=True, +): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Args: + query: queries for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + key: keys for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + value: values to be used in attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + query_chunk_size: int: query chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. 
changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). + use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) + Returns: + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. + """ + batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + + def get_query_chunk(chunk_idx: int) -> Tensor: + return narrow_trunc( + query, + 1, + chunk_idx, + min(query_chunk_size, q_tokens) + ) + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) + partial( + _query_chunk_attention, + kv_chunk_size=kv_chunk_size, + summarize_chunk=summarize_chunk, + ) + ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) + + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices + res = torch.cat([ + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) + ], dim=1) + return res diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py new file mode 100644 index 0000000000000000000000000000000000000000..68e1103c514c5d2d75f23175126cfea0b3dcfca9 --- /dev/null +++ b/modules/textual_inversion/autocrop.py @@ -0,0 +1,341 @@ +import cv2 +import requests +import os +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + scale_by = settings.crop_height / im.height + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + im_debug = im.copy() + + focus = focal_point(im_debug, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = 
focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + results = [] + + results.append(im.crop(tuple(crop))) + + if settings.annotate_image: + d = ImageDraw.Draw(im_debug) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + results.append(im_debug) + if settings.destop_view_image: + im_debug.show() + + return results + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] + entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] + face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] + + pois = [] + + weight_pref_total = 0 + if len(corner_points) > 0: + weight_pref_total += settings.corner_points_weight + if len(entropy_points) > 0: + weight_pref_total += settings.entropy_points_weight + if len(face_points) > 0: + weight_pref_total += settings.face_points_weight + + corner_centroid = None + if len(corner_points) > 0: + corner_centroid = centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) + + entropy_centroid = None + if len(entropy_points) > 0: + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) + + face_centroid = None + if len(face_points) > 0: + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) + + average_point = poi_average(pois, settings) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) + + d.ellipse(average_point.bounding(max_size), outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + if settings.dnn_model_path is not None: + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.9, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0.33)), # face focus up/down is close to the 
top of the head + size = w, + weight = 1/len(faces[1]) + ) + ) + return results + else: + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] + return [] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.06, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] + + +def image_entropy(im): + # greyscale image entropy + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + +def centroid(pois): + x = [poi.x for poi in pois] + y = [poi.y for poi in pois] + return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) + + +def poi_average(pois, settings): + weight = 0.0 + x = 0.0 + y = 0.0 + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight + avg_x = round(weight and x / weight) + avg_y = round(weight and y / weight) + + return PointOfInterest(avg_x, avg_y) + + +def is_landscape(w, h): + return w > h + + +def is_portrait(w, h): + return h > w + + +def is_square(w, h): + return w 
== h + + +def download_and_cache_models(dirname): + download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' + model_file_name = 'face_detection_yunet.onnx' + + if not os.path.exists(dirname): + os.makedirs(dirname) + + cache_file = os.path.join(dirname, model_file_name) + if not os.path.exists(cache_file): + print(f"downloading face detection model from '{download_url}' to '{cache_file}'") + response = requests.get(download_url) + with open(cache_file, "wb") as f: + f.write(response.content) + + if os.path.exists(cache_file): + return cache_file + return None + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = face_points_weight + self.annotate_image = annotate_image + self.destop_view_image = False + self.dnn_model_path = dnn_model_path diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..af9fbcf288ccc7a5402ed1cae2da853022ab080d --- /dev/null +++ b/modules/textual_inversion/dataset.py @@ -0,0 +1,246 @@ +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset, DataLoader, Sampler +from torchvision import transforms +from collections import defaultdict +from random import shuffle, choices + +import random +import tqdm +from modules import devices, shared +import re + +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +re_numbers_at_start = re.compile(r"^[-\d]+\s*") + + +class DatasetEntry: + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None): + self.filename = filename + self.filename_text = filename_text + self.weight = weight + self.latent_dist = latent_dist + self.latent_sample = latent_sample + self.cond = cond + self.cond_text = cond_text + self.pixel_values = pixel_values + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + + self.placeholder_token = placeholder_token + + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + + self.image_paths = 
[os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + + self.shuffle_tags = shuffle_tags + self.tag_drop_out = tag_drop_out + groups = defaultdict(list) + + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + alpha_channel = None + if shared.state.interrupted: + raise Exception("interrupted") + try: + image = Image.open(path) + #Currently does not work for single color transparency + #We would need to read image.info['transparency'] for that + if use_weight and 'A' in image.getbands(): + alpha_channel = image.getchannel('A') + image = image.convert('RGB') + if not varsize: + image = image.resize((width, height), PIL.Image.BICUBIC) + except Exception: + continue + + text_filename = os.path.splitext(path)[0] + ".txt" + filename = os.path.basename(path) + + if os.path.exists(text_filename): + with open(text_filename, "r", encoding="utf8") as file: + filename_text = file.read() + else: + filename_text = os.path.splitext(filename)[0] + filename_text = re.sub(re_numbers_at_start, '', filename_text) + if re_word: + tokens = re_word.findall(filename_text) + filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) + latent_sample = None + + with devices.autocast(): + latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) + + #Perform latent sampling, even for random sampling. + #We need the sample dimensions for the weights + if latent_sampling_method == "deterministic": + if isinstance(latent_dist, DiagonalGaussianDistribution): + # Works only for DiagonalGaussianDistribution + latent_dist.std = 0 + else: + latent_sampling_method = "once" + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + + if use_weight and alpha_channel is not None: + channels, *latent_size = latent_sample.shape + weight_img = alpha_channel.resize(latent_size) + npweight = np.array(weight_img).astype(np.float32) + #Repeat for every channel in the latent sample + weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size) + #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. + weight -= weight.min() + weight /= weight.mean() + elif use_weight: + #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later + weight = torch.ones(latent_sample.shape) + else: + weight = None + + if latent_sampling_method == "random": + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) + else: + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight) + + if not (self.tag_drop_out != 0 or self.shuffle_tags): + entry.cond_text = self.create_text(filename_text) + + if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): + with devices.autocast(): + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + groups[image.size].append(len(self.dataset)) + self.dataset.append(entry) + del torchdata + del latent_dist + del latent_sample + del weight + + self.length = len(self.dataset) + self.groups = list(groups.values()) + assert self.length > 0, "No images have been found in the dataset." 
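+ # Clamp the batch size to the dataset length and the gradient accumulation steps to the number of
+ # available batches, so a very small dataset never requests more data per optimizer step than exists.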
+ self.batch_size = min(batch_size, self.length) + self.gradient_step = min(gradient_step, self.length // self.batch_size) + self.latent_sampling_method = latent_sampling_method + + if len(groups) > 1: + print("Buckets:") + for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): + print(f" {w}x{h}: {len(ids)}") + print() + + def create_text(self, filename_text): + text = random.choice(self.lines) + tags = filename_text.split(',') + if self.tag_drop_out != 0: + tags = [t for t in tags if random.random() > self.tag_drop_out] + if self.shuffle_tags: + random.shuffle(tags) + text = text.replace("[filewords]", ','.join(tags)) + text = text.replace("[name]", self.placeholder_token) + return text + + def __len__(self): + return self.length + + def __getitem__(self, i): + entry = self.dataset[i] + if self.tag_drop_out != 0 or self.shuffle_tags: + entry.cond_text = self.create_text(entry.filename_text) + if self.latent_sampling_method == "random": + entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) + return entry + + +class GroupedBatchSampler(Sampler): + def __init__(self, data_source: PersonalizedBase, batch_size: int): + super().__init__(data_source) + + n = len(data_source) + self.groups = data_source.groups + self.len = n_batch = n // batch_size + expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] + self.base = [int(e) // batch_size for e in expected] + self.n_rand_batches = nrb = n_batch - sum(self.base) + self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] + self.batch_size = batch_size + + def __len__(self): + return self.len + + def __iter__(self): + b = self.batch_size + + for g in self.groups: + shuffle(g) + + batches = [] + for g in self.groups: + batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) + for _ in range(self.n_rand_batches): + rand_group = choices(self.groups, self.probs)[0] + batches.append(choices(rand_group, k=b)) + + shuffle(batches) + + yield from batches + + +class PersonalizedDataLoader(DataLoader): + def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): + super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) + if latent_sampling_method == "random": + self.collate_fn = collate_wrapper_random + else: + self.collate_fn = collate_wrapper + + +class BatchLoader: + def __init__(self, data): + self.cond_text = [entry.cond_text for entry in data] + self.cond = [entry.cond for entry in data] + self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + if all(entry.weight is not None for entry in data): + self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) + else: + self.weight = None + #self.emb_index = [entry.emb_index for entry in data] + #print(self.latent_sample.device) + + def pin_memory(self): + self.latent_sample = self.latent_sample.pin_memory() + return self + +def collate_wrapper(batch): + return BatchLoader(batch) + +class BatchLoaderRandom(BatchLoader): + def __init__(self, data): + super().__init__(data) + + def pin_memory(self): + return self + +def collate_wrapper_random(batch): + return BatchLoaderRandom(batch) \ No newline at end of file diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..5593f88c799c53a8529ae8fd12a84c2e92396b26 --- /dev/null +++ 
b/modules/textual_inversion/image_embedding.py @@ -0,0 +1,220 @@ +import base64 +import json +import numpy as np +import zlib +from PIL import Image, PngImagePlugin, ImageDraw, ImageFont +from fonts.ttf import Roboto +import torch +from modules.shared import opts + + +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, obj) + + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, d): + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) + return d + + +def embedding_to_b64(data): + d = json.dumps(data, cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + + +def embedding_from_b64(data): + d = base64.b64decode(data) + return json.loads(d, cls=EmbeddingDecoder) + + +def lcg(m=2**32, a=1664525, c=1013904223, seed=0): + while True: + seed = (a * seed + c) % m + yield seed % 255 + + +def xor_block(block): + g = lcg() + randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F) + + +def style_block(block, sequence): + im = Image.new('RGB', (block.shape[1], block.shape[0])) + draw = ImageDraw.Draw(im) + i = 0 + for x in range(-6, im.size[0], 8): + for yi, y in enumerate(range(-6, im.size[1], 8)): + offset = 0 + if yi % 2 == 0: + offset = 4 + shade = sequence[i % len(sequence)] + i += 1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade)) + + fg = np.array(im).astype(np.uint8) & 0xF0 + + return block ^ fg + + +def insert_image_data_embed(image, data): + d = 3 + data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9) + data_np_ = np.frombuffer(data_compressed, np.uint8).copy() + data_np_high = data_np_ >> 4 + data_np_low = data_np_ & 0x0F + + h = image.size[1] + next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) + next_size = next_size + ((h*d)-(next_size % (h*d))) + + data_np_low = np.resize(data_np_low, next_size) + data_np_low = data_np_low.reshape((h, -1, d)) + + data_np_high = np.resize(data_np_high, next_size) + data_np_high = data_np_high.reshape((h, -1, d)) + + edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] + edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8) + + data_np_low = style_block(data_np_low, sequence=edge_style) + data_np_low = xor_block(data_np_low) + data_np_high = style_block(data_np_high, sequence=edge_style[::-1]) + data_np_high = xor_block(data_np_high) + + im_low = Image.fromarray(data_np_low, mode='RGB') + im_high = Image.fromarray(data_np_high, mode='RGB') + + background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0)) + background.paste(im_low, (0, 0)) + background.paste(image, (im_low.size[0]+1, 0)) + background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0)) + + return background + + +def crop_black(img, tol=0): + mask = (img > tol).all(2) + mask0, mask1 = mask.any(0), mask.any(1) + col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax() + row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end, col_start:col_end] + + +def extract_image_data_embed(image): + d = 3 + outarr = 
crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F + black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0) + if black_cols[0].shape[0] < 2: + print('No Image data blocks found.') + return None + + data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8) + data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8) + + data_block_lower = xor_block(data_block_lower) + data_block_upper = xor_block(data_block_upper) + + data_block = (data_block_upper << 4) | (data_block_lower) + data_block = data_block.flatten().tobytes() + + data = zlib.decompress(data_block) + return json.loads(data, cls=EmbeddingDecoder) + + +def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None): + from math import cos + + image = srcimage.copy() + fontsize = 32 + if textfont is None: + try: + textfont = ImageFont.truetype(opts.font or Roboto, fontsize) + textfont = opts.font or Roboto + except Exception: + textfont = Roboto + + factor = 1.5 + gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0)) + for y in range(image.size[1]): + mag = 1-cos(y/image.size[1]*factor) + mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) + gradient.putpixel((0, y), (0, 0, 0, int(mag*255))) + image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) + + draw = ImageDraw.Draw(image) + + font = ImageFont.truetype(textfont, fontsize) + padding = 10 + + _, _, w, h = draw.textbbox((0, 0), title, font=font) + fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72) + font = ImageFont.truetype(textfont, fontsize) + _, _, w, h = draw.textbbox((0, 0), title, font=font) + draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230)) + + _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font) + fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + _, _, w, h = draw.textbbox((0, 0), footerMid, font=font) + fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + _, _, w, h = draw.textbbox((0, 0), footerRight, font=font) + fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + + font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right)) + + draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230)) + draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230)) + draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230)) + + return image + + +if __name__ == '__main__': + + testEmbed = Image.open('test_embedding.png') + data = extract_image_data_embed(testEmbed) + assert data is not None + + data = embedding_from_b64(testEmbed.text['sd-ti-embedding']) + assert data is not None + + image = Image.new('RGBA', (512, 512), (255, 255, 200, 255)) + cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') + + test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}} + + embedded_image = insert_image_data_embed(cap_image, test_embed) + + retrived_embed = extract_image_data_embed(embedded_image) + + assert str(retrived_embed) == str(test_embed) + + embedded_image2 = insert_image_data_embed(cap_image, retrived_embed) + + assert embedded_image == embedded_image2 + + g = lcg() + shared_random = np.array([next(g) for _ in 
range(100)]).astype(np.uint8).tolist() + + reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, + 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, + 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, + 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, + 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, + 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, + 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, + 204, 86, 73, 222, 44, 198, 118, 240, 97] + + assert shared_random == reference_random + + hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) + + assert 12731374 == hunna_kay_random_sum diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..f63fc72ff8d4269967dc0f92b61a278d068324b5 --- /dev/null +++ b/modules/textual_inversion/learn_schedule.py @@ -0,0 +1,81 @@ +import tqdm + + +class LearnScheduleIterator: + def __init__(self, learn_rate, max_steps, cur_step=0): + """ + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 + """ + + pairs = learn_rate.split(',') + self.rates = [] + self.it = 0 + self.maxit = 0 + try: + for i, pair in enumerate(pairs): + if not pair.strip(): + continue + tmp = pair.split(':') + if len(tmp) == 2: + step = int(tmp[1]) + if step > cur_step: + self.rates.append((float(tmp[0]), min(step, max_steps))) + self.maxit += 1 + if step > max_steps: + return + elif step == -1: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 + return + else: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 + return + assert self.rates + except (ValueError, AssertionError): + raise Exception('Invalid learning rate schedule. 
It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') + + + def __iter__(self): + return self + + def __next__(self): + if self.it < self.maxit: + self.it += 1 + return self.rates[self.it - 1] + else: + raise StopIteration + + +class LearnRateScheduler: + def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): + self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) + (self.learn_rate, self.end_step) = next(self.schedules) + self.verbose = verbose + + if self.verbose: + print(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + self.finished = False + + def step(self, step_number): + if step_number < self.end_step: + return False + + try: + (self.learn_rate, self.end_step) = next(self.schedules) + except StopIteration: + self.finished = True + return False + return True + + def apply(self, optimizer, step_number): + if not self.step(step_number): + return + + if self.verbose: + tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + for pg in optimizer.param_groups: + pg['lr'] = self.learn_rate + diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..734a4b6f463d11685933771825b2792c8270e53a --- /dev/null +++ b/modules/textual_inversion/logging.py @@ -0,0 +1,24 @@ +import datetime +import json +import os + +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} +saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} +saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..2239cb842ed5d25f9bde19f10509220ad6c01c18 --- /dev/null +++ b/modules/textual_inversion/preprocess.py @@ -0,0 +1,230 @@ +import os +from PIL import Image, ImageOps +import math +import platform +import sys +import tqdm +import time + +from modules import paths, shared, images, deepbooru +from modules.shared import opts, cmd_opts +from modules.textual_inversion import autocrop + + +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, 
process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): + try: + if process_caption: + shared.interrogator.load() + + if process_caption_deepbooru: + deepbooru.model.start() + + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) + + finally: + + if process_caption: + shared.interrogator.send_blip_to_ram() + + if process_caption_deepbooru: + deepbooru.model.stop() + + +def listfiles(dirname): + return os.listdir(dirname) + + +class PreprocessParams: + src = None + dstdir = None + subindex = 0 + flip = False + process_caption = False + process_caption_deepbooru = False + preprocess_txt_action = None + + +def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None): + caption = "" + + if params.process_caption: + caption += shared.interrogator.generate_caption(image) + + if params.process_caption_deepbooru: + if len(caption) > 0: + caption += ", " + caption += deepbooru.model.tag_multi(image) + + filename_part = params.src + filename_part = os.path.splitext(filename_part)[0] + filename_part = os.path.basename(filename_part) + + basename = f"{index:05}-{params.subindex}-{filename_part}" + image.save(os.path.join(params.dstdir, f"{basename}.png")) + + if params.preprocess_txt_action == 'prepend' and existing_caption: + caption = existing_caption + ' ' + caption + elif params.preprocess_txt_action == 'append' and existing_caption: + caption = caption + ' ' + existing_caption + elif params.preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + + if len(caption) > 0: + with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: + file.write(caption) + + params.subindex += 1 + + +def save_pic(image, index, params, existing_caption=None): + save_pic_with_caption(image, index, params, existing_caption=existing_caption) + + if params.flip: + save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption) + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, 
y, to_w, y + to_h)) + yield splitted + +# not using torchvision.transforms.CenterCrop because it doesn't allow float regions +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + + +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) + wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1], + default=None + ) + return wh and center_crop(image, *wh) + + +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): + width = process_width + height = process_height + src = os.path.abspath(process_src) + dst = os.path.abspath(process_dst) + split_threshold = max(0.0, min(1.0, split_threshold)) + overlap_ratio = max(0.0, min(0.9, overlap_ratio)) + + assert src != dst, 'same directory specified as source and destination' + + os.makedirs(dst, exist_ok=True) + + files = listfiles(src) + + shared.state.job = "preprocess" + shared.state.textinfo = "Preprocessing..." 
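+ # One UI job per source image: shared.state.job_count drives the progress display, and
+ # shared.state.nextjob() is called at the end of each image's loop iteration below.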
+ shared.state.job_count = len(files) + + params = PreprocessParams() + params.dstdir = dst + params.flip = process_flip + params.process_caption = process_caption + params.process_caption_deepbooru = process_caption_deepbooru + params.preprocess_txt_action = preprocess_txt_action + + pbar = tqdm.tqdm(files) + for index, imagefile in enumerate(pbar): + params.subindex = 0 + filename = os.path.join(src, imagefile) + try: + img = Image.open(filename).convert("RGB") + except Exception: + continue + + description = f"Preprocessing [Image {index}/{len(files)}]" + pbar.set_description(description) + shared.state.textinfo = description + + params.src = filename + + existing_caption = None + existing_caption_filename = os.path.splitext(filename)[0] + '.txt' + if os.path.exists(existing_caption_filename): + with open(existing_caption_filename, 'r', encoding="utf8") as file: + existing_caption = file.read() + + if shared.state.interrupted: + break + + if img.height > img.width: + ratio = (img.width * height) / (img.height * width) + inverse_xy = False + else: + ratio = (img.height * width) / (img.width * height) + inverse_xy = True + + process_default_resize = True + + if process_split and ratio < 1.0 and ratio <= split_threshold: + for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio): + save_pic(splitted, index, params, existing_caption=existing_caption) + process_default_resize = False + + if process_focal_crop and img.height != img.width: + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) + except Exception as e: + print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) + + autocrop_settings = autocrop.Settings( + crop_width = width, + crop_height = height, + face_points_weight = process_focal_crop_face_weight, + entropy_points_weight = process_focal_crop_entropy_weight, + corner_points_weight = process_focal_crop_edges_weight, + annotate_image = process_focal_crop_debug, + dnn_model_path = dnn_model_path, + ) + for focal in autocrop.crop_image(img, autocrop_settings): + save_pic(focal, index, params, existing_caption=existing_caption) + process_default_resize = False + + if process_multicrop: + cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) + if cropped is not None: + save_pic(cropped, index, params, existing_caption=existing_caption) + else: + print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)") + process_default_resize = False + + if process_default_resize: + img = images.resize_image(1, img, width, height) + save_pic(img, index, params, existing_caption=existing_caption) + + shared.state.nextjob() diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png new file mode 100644 index 0000000000000000000000000000000000000000..5a570c7be308aa1043c6b5e1d0c53df3227e001c --- /dev/null +++ b/modules/textual_inversion/test_embedding.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ceb3de4098040013be6ce7169b0c0e67c2de86a8cfb43d02b16013f6af2d352e +size 489220 diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..c63c7d1dda0840fb337cedcd95d16e05e60935a7 --- 
/dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,657 @@ +import os +import sys +import traceback +import inspect +from collections import namedtuple + +import torch +import tqdm +import html +import datetime +import csv +import safetensors.torch + +import numpy as np +from PIL import Image, PngImagePlugin +from torch.utils.tensorboard import SummaryWriter + +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint +import modules.textual_inversion.dataset +from modules.textual_inversion.learn_schedule import LearnRateScheduler + +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay +from modules.textual_inversion.logging import save_settings_to_file + + +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.shape = None + self.vectors = 0 + self.cached_checksum = None + self.sd_checkpoint = None + self.sd_checkpoint_name = None + self.optimizer_state_dict = None + self.filename = None + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + "sd_checkpoint": self.sd_checkpoint, + "sd_checkpoint_name": self.sd_checkpoint_name, + } + + torch.save(embedding_data, filename) + + if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: + optimizer_saved_dict = { + 'hash': self.checksum(), + 'optimizer_state_dict': self.optimizer_state_dict, + } + torch.save(optimizer_saved_dict, filename + '.optim') + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + + +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) + + +class EmbeddingDatabase: + def __init__(self): + self.ids_lookup = {} + self.word_embeddings = {} + self.skipped_embeddings = {} + self.expected_shape = -1 + self.embedding_dirs = {} + self.previously_displayed_embeddings = () + + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + + def clear_embedding_dirs(self): + self.embedding_dirs.clear() + + def register_embedding(self, embedding, model): + self.word_embeddings[embedding.name] = embedding + + ids = model.cond_stage_model.tokenize([embedding.name])[0] + + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + + 
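+ # Entries sharing the same first token id are kept sorted longest-first, so that
+ # find_embedding_at_position prefers a multi-token embedding over a shorter one with the same prefix.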
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) + + return embedding + + def get_expected_shape(self): + vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) + return vec.shape[1] + + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() + + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) + else: + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + data = safetensors.torch.load_file(path, device="cpu") + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + embedding.filename = path + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding + + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, dirs, fns in os.walk(embdir.path, followlinks=True): + for fn in fns: + try: + fullfn = os.path.join(root, fn) + + if os.stat(fullfn).st_size == 0: + continue + + self.load_from_file(fullfn, fn) + except Exception: + print(f"Error loading embedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + continue + + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for path, embdir in self.embedding_dirs.items(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for path, embdir in self.embedding_dirs.items(): + self.load_from_dir(embdir) + embdir.update() + + displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) + if self.previously_displayed_embeddings != displayed_embeddings: + self.previously_displayed_embeddings = displayed_embeddings + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', 
'.join(self.word_embeddings.keys())}") + if len(self.skipped_embeddings) > 0: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None, None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding, len(ids) + + return None, None + + +def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): + cond_model = shared.sd_model.cond_stage_model + + with devices.autocast(): + cond_model([""]) # will send cond model to GPU if lowvram/medvram is active + + #cond_model expects at least some text, so we provide '*' as backup. + embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + #Only copy if we provided an init_text, otherwise keep vectors as zeros + if init_text: + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def write_loss(log_directory, filename, step, epoch_len, values): + if shared.opts.training_write_csv_every == 0: + return + + if step % shared.opts.training_write_csv_every != 0: + return + write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True + + with open(os.path.join(log_directory, filename), "a+", newline='') as fout: + csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) + + if write_csv_header: + csv_writer.writeheader() + + epoch = (step - 1) // epoch_len + epoch_step = (step - 1) % epoch_len + + csv_writer.writerow({ + "step": step, + "epoch": epoch, + "epoch_step": epoch_step, + **values, + }) + +def tensorboard_setup(log_directory): + os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) + return SummaryWriter( + log_dir=os.path.join(log_directory, "tensorboard"), + flush_secs=shared.opts.training_tensorboard_flush_every) + +def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num): + tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step) + tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) + +def tensorboard_add_scaler(tensorboard_writer, tag, value, step): + tensorboard_writer.add_scalar(tag=tag, + scalar_value=value, global_step=step) + +def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): + # Convert a pil image to a torch tensor + img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + len(pil_image.getbands())) + img_tensor = img_tensor.permute((2, 0, 1)) + + tensorboard_writer.add_image(tag, img_tensor, global_step=step) + +def 
validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): + assert model_name, f"{name} not selected" + assert learn_rate, "Learning rate is empty or 0" + assert isinstance(batch_size, int), "Batch size must be integer" + assert batch_size > 0, "Batch size must be positive" + assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" + assert gradient_step > 0, "Gradient accumulation step must be positive" + assert data_root, "Dataset directory is empty" + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" + assert steps, "Max steps is empty or 0" + assert isinstance(steps, int), "Max steps must be integer" + assert steps > 0, "Max steps must be positive" + assert isinstance(save_model_every, int), "Save {name} must be integer" + assert save_model_every >= 0, "Save {name} must be positive or 0" + assert isinstance(create_image_every, int), "Create image must be integer" + assert create_image_every >= 0, "Create image must be positive or 0" + if save_model_every or create_image_every: + assert log_directory, "Log directory is empty" + + +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + save_embedding_every = save_embedding_every or 0 + create_image_every = create_image_every or 0 + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path + + shared.state.job = "train-embedding" + shared.state.textinfo = "Initializing textual inversion training..." 
+ shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) + unload = shared.opts.unload_models_when_training + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() + + initial_step = embedding.step or 0 + if initial_step >= steps: + shared.state.textinfo = "Model has already been trained beyond specified max steps" + return embedding, filename + + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ + torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ + None + if clip_grad: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) + # dataset loading may take a while, so input validations and early returns should be done before this + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + old_parallel_processing_allowed = shared.parallel_processing_allowed + + if shared.opts.training_enable_tensorboard: + tensorboard_writer = tensorboard_setup(log_directory) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight) + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) + + latent_sampling_method = ds.latent_sampling_method + + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) + + if unload: + shared.parallel_processing_allowed = False + shared.sd_model.first_stage_model.to(devices.cpu) + + embedding.vec.requires_grad = True + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) + if shared.opts.save_optimizer_state: + optimizer_state_dict = None + if os.path.exists(filename + '.optim'): + optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') + if embedding.checksum() == optimizer_saved_dict.get('hash', None): + optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + + if optimizer_state_dict is not None: + optimizer.load_state_dict(optimizer_state_dict) 
+ print("Loaded existing optimizer from checkpoint") + else: + print("No saved optimizer exists in checkpoint") + + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + + last_saved_file = "" + last_saved_image = "" + forced_filename = "" + embedding_yet_to_be_embedded = False + + is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} + img_c = None + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + sd_hijack_checkpoint.add() + + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + if shared.state.interrupted: + break + + if clip_grad: + clip_grad_sched.step(embedding.step) + + with devices.autocast(): + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + if use_weight: + w = batch.weight.to(devices.device, non_blocking=pin_memory) + c = shared.sd_model.cond_stage_model(batch.cond_text) + + if is_training_inpainting_model: + if img_c is None: + img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height) + + cond = {"c_concat": [img_c], "c_crossattn": [c]} + else: + cond = c + + if use_weight: + loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step + del w + else: + loss = shared.sd_model.forward(x, cond)[0] / gradient_step + del x + + _loss_step += loss.item() + scaler.scale(loss).backward() + + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + + if clip_grad: + clip_grad(embedding.vec, clip_grad_sched.learn_rate) + + scaler.step(optimizer) + scaler.update() + embedding.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // steps_per_epoch + epoch_step = embedding.step % steps_per_epoch + + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" + pbar.set_description(description) + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. 
+ embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None + + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.assign_current_image(image) + + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) + + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) + + title = "<{}>".format(data.get('name', '???')) + + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' + + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.shorthash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) + + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False + + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +

    +<p>
    +Loss: {loss_step:.7f}<br/>
    +Step: {steps_done}<br/>
    +Last prompt: {html.escape(batch.cond_text[0])}<br/>
    +Last saved embedding: {html.escape(last_saved_file)}<br/>
    +Last saved image: {html.escape(last_saved_image)}<br/>
    +</p>

    +""" + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + pass + finally: + pbar.leave = False + pbar.close() + shared.sd_model.first_stage_model.to(devices.device) + shared.parallel_processing_allowed = old_parallel_processing_allowed + sd_hijack_checkpoint.remove() + + return embedding, filename + + +def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): + old_embedding_name = embedding.name + old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None + old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None + old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None + try: + embedding.sd_checkpoint = checkpoint.shorthash + embedding.sd_checkpoint_name = checkpoint.model_name + if remove_cached_checksum: + embedding.cached_checksum = None + embedding.name = embedding_name + embedding.optimizer_state_dict = optimizer.state_dict() + embedding.save(filename) + except: + embedding.sd_checkpoint = old_sd_checkpoint + embedding.sd_checkpoint_name = old_sd_checkpoint_name + embedding.name = old_embedding_name + embedding.cached_checksum = old_cached_checksum + raise diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..35c4feeff455b6bd2b699dd72b27f852932c78c2 --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,45 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared + + +def create_embedding(name, initialization_text, nvpt, overwrite_old): + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def preprocess(*args): + modules.textual_inversion.preprocess.preprocess(*args) + + return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" + + +def train_embedding(*args): + + assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' + + apply_optimizations = shared.opts.training_xattention_optimizations + try: + if not apply_optimizations: + sd_hijack.undo_optimizations() + + embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. 
+Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + if not apply_optimizations: + sd_hijack.apply_optimizations() + diff --git a/modules/timer.py b/modules/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..57a4f17a16b4cb46fafb690b1180c054db284a17 --- /dev/null +++ b/modules/timer.py @@ -0,0 +1,35 @@ +import time + + +class Timer: + def __init__(self): + self.start = time.time() + self.records = {} + self.total = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def record(self, category, extra_time=0): + e = self.elapsed() + if category not in self.records: + self.records[category] = 0 + + self.records[category] += e + extra_time + self.total += e + extra_time + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [x for x in self.records.items() if x[1] >= 0.1] + if not additions: + return res + + res += " (" + res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) + res += ")" + + return res diff --git a/modules/txt2img.py b/modules/txt2img.py new file mode 100644 index 0000000000000000000000000000000000000000..16841d0f2f3095ef27667ef90e5c0f1baeef733d --- /dev/null +++ b/modules/txt2img.py @@ -0,0 +1,69 @@ +import modules.scripts +from modules import sd_samplers +from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ + StableDiffusionProcessingImg2Img, process_images +from modules.shared import opts, cmd_opts +import modules.shared as shared +import modules.processing as processing +from modules.ui import plaintext_to_html + + +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args): + override_settings = create_override_settings_dict(override_settings_texts) + + p = StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, + outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, + prompt=prompt, + styles=prompt_styles, + negative_prompt=negative_prompt, + seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, + seed_enable_extras=seed_enable_extras, + sampler_name=sd_samplers.samplers[sampler_index].name, + batch_size=batch_size, + n_iter=n_iter, + steps=steps, + cfg_scale=cfg_scale, + width=width, + height=height, + restore_faces=restore_faces, + tiling=tiling, + enable_hr=enable_hr, + denoising_strength=denoising_strength if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, + hr_second_pass_steps=hr_second_pass_steps, + hr_resize_x=hr_resize_x, + hr_resize_y=hr_resize_y, + override_settings=override_settings, + ) + + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args + + if cmd_opts.enable_console_prompts: + print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + + processed = 
modules.scripts.scripts_txt2img.run(p, *args) + + if processed is None: + processed = process_images(p) + + p.close() + + shared.total_tqdm.clear() + + generation_info_js = processed.js() + if opts.samples_log_stdout: + print(generation_info_js) + + if opts.do_not_show_images: + processed.images = [] + + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) diff --git a/modules/ui.py b/modules/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..0516c6436081f41bc34652932fa3d71f5c35decd --- /dev/null +++ b/modules/ui.py @@ -0,0 +1,1798 @@ +import html +import json +import math +import mimetypes +import os +import platform +import random +import sys +import tempfile +import time +import traceback +from functools import partial, reduce +import warnings + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path, data_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text +import modules.extras + +warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.wrap.cover-bg .z-20::before { content:"" } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. 
+# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 +switch_values_symbol = '\U000021C5' # ⇅ + + +def plaintext_to_html(text): + return ui_common.plaintext_to_html(text) + + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, styles): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] + + +def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles): + if mode in {0, 1, 3, 4}: + return [interrogation_function(ii_singles[mode]), None] + elif mode == 2: + return [interrogation_function(ii_singles[mode]["image"]), None] + elif mode == 5: + assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + images = shared.listfiles(ii_input_dir) + print(f"Will process {len(images)} images.") + if ii_output_dir != "": + os.makedirs(ii_output_dir, exist_ok=True) + else: + ii_output_dir = ii_input_dir + + for image in images: + img = Image.open(image) + filename = os.path.basename(image) + left, _ = os.path.splitext(filename) + print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) + + return [gr.update(), None] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + return gr.update() if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr.update() if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + 
seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') + + with FormRow(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + + +def connect_clear_prompt(button): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" + button.click( + _js="clear_prompt", + fn=None, + inputs=[], + outputs=[], + ) + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. 
no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + text, _ = extra_networks.parse_prompt(text) + + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): + with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)") + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): + with gr.Row(elem_id=f"{id_part}_generate_box"): + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(elem_id=f"{id_part}_tools"): + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + prompt_style_apply = ToolButton(value=apply_style_symbol, 
elem_id=f"{id_part}_style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") + + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") + negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + with gr.Row(elem_id=f"{id_part}_styles_row"): + prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) + create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") + + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button + + +def setup_progressbar(*args, **kwargs): + pass + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return getattr(opts, key) + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + return ui_common.create_output_panel(tabname, outdir) + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], 
value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)): + yield category + + +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + +def create_override_settings_dropdown(tabname, row): + dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) + + dropdown.change( + fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + inputs=[dropdown], + outputs=[dropdown], + ) + + return dropdown + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='compact', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes", variant="compact"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling 
= gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "override_settings": + with FormRow(elem_id="txt2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('txt2img', row) + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + dummy_component, + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_styles, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + override_settings, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + 
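The txt2img_args dict defined above is bound to both the prompt's submit event and the Generate button's click event, so pressing Enter in the prompt box and clicking Generate run the same wrapped txt2img call with identical inputs and outputs. Below is a minimal standalone sketch of that binding pattern; it is illustrative only, uses placeholder names (generate, prompt, result, generate_btn), and is not part of this diff.

import gradio as gr

def generate(prompt_text):
    # stand-in for the wrapped txt2img call
    return f"generated for: {prompt_text}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    generate_btn = gr.Button("Generate")

    # one kwargs dict reused for both events, mirroring txt2img_args above
    args = dict(fn=generate, inputs=[prompt], outputs=[result])
    prompt.submit(**args)
    generate_btn.click(**args)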
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, + )) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + + with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='compact', elem_id="img2img_settings"): + copy_image_buttons = [] + copy_image_destinations = {} + + def add_copy_image_controls(tab_name, elem): + with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): + gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") + + for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + if name == tab_name: + gr.Button(title, interactive=False) + copy_image_destinations[name] = elem + continue + + button 
= gr.Button(title) + copy_image_buttons.append((button, name, elem)) + + with gr.Tabs(elem_id="mode_img2img"): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
    <br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML( + f"<p>Process images in a directory on the same machine where the server is running." + + f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}</p>
    " + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + + return img + + for button, name, elem in copy_image_buttons: + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js="switch_to_"+name.replace(" ", "_"), + inputs=[], + outputs=[], + ) + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + with FormRow(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes", variant="compact"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "override_settings": + with FormRow(elem_id="img2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('img2img', row) + + elif category == "scripts": + with 
FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + elif category == "inpaint": + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + dummy_component, + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + image_cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + img2img_batch_inpaint_mask_dir, + override_settings, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + interrogate_args = dict( + _js="get_img2img_tab_index", + inputs=[ + dummy_component, + img2img_batch_input_dir, + img2img_batch_output_dir, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + init_img_inpaint, + ], + outputs=[img2img_prompt, dummy_component], + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) + + img2img_interrogate.click( + 
fn=lambda *args: process_interrogate(interrogate, *args), + **interrogate_args, + ) + + img2img_deepbooru.click( + fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), + **interrogate_args, + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_styles, img2img_prompt_styles], + ) + + for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, styles], + outputs=[prompt, negative_prompt, styles], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (image_cfg_scale, "Image CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + )) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + ui_postprocessing.create_ui() + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + + for tabname, button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image, + )) + + 
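The style-save wiring above passes a hidden dummy component because the JavaScript prompt and the Python handler must accept the same number of arguments, even though the style name only exists after the prompt runs. A minimal, self-contained sketch of that pattern (the component names and the inline JS prompt are illustrative, not taken from this codebase):

import gradio as gr

def save_named_value(name, value):
    # 'name' arrives from the JavaScript prompt, 'value' from the visible textbox
    return f"saved '{value}' as '{name}'"

with gr.Blocks() as demo:
    dummy = gr.Textbox(visible=False)   # placeholder slot filled in by the JS function
    value = gr.Textbox(label="Value")
    result = gr.Textbox(label="Result")
    save = gr.Button("Save")

    save.click(
        fn=save_named_value,
        _js="(name, value) => [prompt('Name?'), value]",  # returns one value per input, same arity as fn
        inputs=[dummy, value],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()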
image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + def update_interp_description(value): + interp_description_css = "

<p style='margin-bottom: 2.5em'>{}</p>
    " + interp_descriptions = { + "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), + "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), + "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") + } + return interp_descriptions[value] + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='compact'): + interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") + + with FormRow(elem_id="modelmerger_models"): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description]) + + with FormRow(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + with FormRow(): + with gr.Column(): + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + with gr.Column(): + with FormRow(): + bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") + create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") + + with FormRow(): + discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") + + with gr.Row(): + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') + + with gr.Column(variant='compact', elem_id="modelmerger_results_container"): + with 
gr.Group(elem_id="modelmerger_results_panel"): + modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
    ") + + with gr.Row(variant="compact").style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Column(visible=False) as process_multicrop_col: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") + process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") + with gr.Row(): + process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") + process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") + with gr.Row(): + 
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") + process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + process_multicrop.change( + fn=lambda show: gr_show(show), + inputs=[process_multicrop], + outputs=[process_multicrop_col], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="

<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
    ") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(elem_id='ti_gallery_container'): + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + dummy_component, + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + process_multicrop, + process_multicrop_mindim, + process_multicrop_maxdim, + process_multicrop_minarea, + process_multicrop_maxarea, + process_multicrop_objective, + process_multicrop_threshold, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + dummy_component, + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + use_weight, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + 
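The training buttons above are wrapped with wrap_gradio_gpu_call(..., extra_outputs=[gr.update()]) so that a failed run still returns a value for every declared Gradio output. A rough, simplified sketch of that error-padding idea (wrap_call, train and safe_train are hypothetical names, not the actual implementation):

import traceback

def wrap_call(func, extra_outputs=None):
    # On success, pass the result through; on failure, print the traceback and
    # return one placeholder per missing output plus a short error message,
    # so the UI always receives the number of values it expects.
    extra_outputs = extra_outputs or []

    def wrapped(*args, **kwargs):
        try:
            res = func(*args, **kwargs)
            return res if isinstance(res, tuple) else (res,)
        except Exception as e:
            traceback.print_exc()
            return tuple(extra_outputs) + (f"Error: {e}",)

    return wrapped

def train(steps):
    if steps <= 0:
        raise ValueError("steps must be positive")
    return f"trained for {steps} steps", ""

safe_train = wrap_call(train, extra_outputs=[None])
print(safe_train(100))  # ('trained for 100 steps', '')
print(safe_train(0))    # (None, 'Error: steps must be positive')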
train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + dummy_component, + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + use_weight, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + shared.settings_components = component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
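run_settings above checks every submitted value against the expected type before applying any of them, so a single bad value cannot leave the options half-updated. A small standalone sketch of that validate-then-apply pattern (apply_settings and the example keys are hypothetical, not this module's API):

def apply_settings(current, updates):
    # pass 1: validate everything up front so nothing is applied on failure
    for key, value in updates.items():
        if key not in current:
            raise KeyError(f"unknown setting: {key}")
        if not isinstance(value, type(current[key])):
            raise TypeError(f"bad value for {key}: expected {type(current[key]).__name__}")

    # pass 2: apply and record which settings actually changed
    changed = []
    for key, value in updates.items():
        if current[key] != value:
            current[key] = value
            changed.append(key)
    return changed

if __name__ == "__main__":
    opts = {"samples_format": "png", "save_selected_only": True}
    print(apply_settings(opts, {"samples_format": "jpg"}))  # ['samples_format']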
+ + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return get_value_for_setting(key), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + current_row = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + gr.Group() + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + with gr.TabItem("Licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if 
not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings", variant="compact"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.connect_paste_params_buttons() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + footer = shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + text_settings.change( + fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"), + inputs=[], + outputs=[image_cfg_scale], + ) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [get_value_for_setting(key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] + return results + + modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result]) + modelmerger_merge.click( + fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in 
range(4)]), + _js='modelmerger', + inputs=[ + dummy_component, + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + config_source, + bake_in_vae, + discard_weights, + ], + outputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + modelmerger_result, + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + pass + + # this warning is generally not useful; + # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + def check_dropdown(val): + if getattr(x, 'multiselect', False): + return all([value in x.choices for value in val]) + else: + return val in x.choices + + apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + # Required as a workaround for change() event not triggering when loading values from ui-config.json + interp_description.value = update_interp_description(interp_method.value) + + return demo + + +def reload_javascript(): + head = f'\n' + + inline = f"{localization.localization_js(shared.opts.localization)};" + if cmd_opts.theme is not None: + inline += f"set_theme('{cmd_opts.theme}');" + + for script in modules.scripts.list_scripts("javascript", ".js"): + head += f'\n' + + head += f'\n' + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{head}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = 
template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {getattr(torch, '__long_version__',torch.__version__)} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} + •  +checkpoint: N/A +""" diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 0000000000000000000000000000000000000000..fd047f318879ebb038981aeaa0f96549e1c91bcf --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,206 @@ +import json +import html +import os +import platform +import sys + +import gradio as gr +import subprocess as sp + +from modules import call_queue, shared +from modules.generation_parameters_copypaste import image_from_url_text +import modules.images + +folder_symbol = '\U0001f4c2' # 📂 + + +def update_generation_info(generation_info, html_info, img_index): + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info, gr.update() + + +def plaintext_to_html(text): + text = "

    " + "
    \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

    " + return text + + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = shared.opts.outdir_save + save_to_dirs = shared.opts.use_save_to_dirs_for_ui + extension: str = shared.opts.samples_format + start_index = 0 + + if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(shared.opts.outdir_save, exist_ok=True) + + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def create_output_panel(tabname, outdir): + from modules import shared + import modules.generation_parameters_copypaste as parameters_copypaste + + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. 
+Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(shared.opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], + ) + + save.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ], + show_progress=False, + ) + + save_zip.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + for paste_tabname, paste_button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery + )) + + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/modules/ui_components.py b/modules/ui_components.py new file mode 100644 index 0000000000000000000000000000000000000000..284ca0cf2d6951be1b2cedeffe5fb6f11000c78b --- /dev/null +++ b/modules/ui_components.py @@ -0,0 +1,58 @@ +import gradio as gr + + +class ToolButton(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, fits inside gradio forms""" + + def __init__(self, **kwargs): + 
super().__init__(variant="tool", **kwargs) + + def get_block_name(self): + return "button" + + +class ToolButtonTop(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool-top", **kwargs) + + def get_block_name(self): + return "button" + + +class FormRow(gr.Row, gr.components.FormComponent): + """Same as gr.Row but fits inside gradio forms""" + + def get_block_name(self): + return "row" + + +class FormGroup(gr.Group, gr.components.FormComponent): + """Same as gr.Row but fits inside gradio forms""" + + def get_block_name(self): + return "group" + + +class FormHTML(gr.HTML, gr.components.FormComponent): + """Same as gr.HTML but fits inside gradio forms""" + + def get_block_name(self): + return "html" + + +class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): + """Same as gr.ColorPicker but fits inside gradio forms""" + + def get_block_name(self): + return "colorpicker" + + +class DropdownMulti(gr.Dropdown): + """Same as gr.Dropdown but always multiselect""" + def __init__(self, **kwargs): + super().__init__(multiselect=True, **kwargs) + + def get_block_name(self): + return "dropdown" diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4308ef028c55dfdfb22ec6dc36702f274445a3 --- /dev/null +++ b/modules/ui_extensions.py @@ -0,0 +1,354 @@ +import json +import os.path +import shutil +import sys +import time +import traceback + +import git + +import gradio as gr +import html +import shutil +import errno + +from modules import extensions, shared, paths +from modules.call_queue import wrap_gradio_gpu_call + +available_extensions = {"extensions": []} + + +def check_access(): + assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags" + + +def apply_and_restart(disable_list, update_list): + check_access() + + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + update = json.loads(update_list) + assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + + update = set(update) + + for ext in extensions.extensions: + if ext.name not in update: + continue + + try: + ext.fetch_and_reset_hard() + except Exception: + print(f"Error getting updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + shared.state.interrupt() + shared.state.need_restart = True + + +def check_updates(id_task, disable_list): + check_access() + + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.state.job_count = len(exts) + + for ext in exts: + shared.state.textinfo = ext.name + + try: + ext.check_updates() + except Exception: + print(f"Error checking updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + shared.state.nextjob() + + return extension_table(), "" + + +def extension_table(): + code = f""" + + + + + + + + + + + """ + + for ext in extensions.extensions: + remote = f"""{html.escape("built-in" if ext.is_builtin else ext.remote or '')}""" + + if 
ext.can_update: + ext_status = f"""""" + else: + ext_status = ext.status + + code += f""" + + + + + {ext_status} + + """ + + code += """ + +
<thead><tr> <th>Extension</th> <th>URL</th> <th>Version</th> <th>Update</th> </tr></thead>
<td>{remote}</td> <td>{ext.version}</td>
    + """ + + return code + + +def normalize_git_url(url): + if url is None: + return "" + + url = url.replace(".git", "") + return url + + +def install_extension_from_url(dirname, url): + check_access() + + assert url, 'No URL specified' + + if dirname is None or dirname == "": + *parts, last_part = url.split('/') + last_part = normalize_git_url(last_part) + + dirname = last_part + + target_dir = os.path.join(extensions.extensions_dir, dirname) + assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' + + normalized_url = normalize_git_url(url) + assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' + + tmpdir = os.path.join(paths.data_path, "tmp", dirname) + + try: + shutil.rmtree(tmpdir, True) + + repo = git.Repo.clone_from(url, tmpdir) + repo.remote().fetch() + + try: + os.rename(tmpdir, target_dir) + except OSError as err: + # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it + # Shouldn't cause any new issues at least but we probably want to handle it there too. + if err.errno == errno.EXDEV: + # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems + # Since we can't use a rename, do the slower but more versitile shutil.move() + shutil.move(tmpdir, target_dir) + else: + # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled. + raise(err) + + import launch + launch.run_extension_installer(target_dir) + + extensions.list_extensions() + return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] + finally: + shutil.rmtree(tmpdir, True) + + +def install_extension_from_index(url, hide_tags, sort_column): + ext_table, message = install_extension_from_url(None, url) + + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) + + return code, ext_table, message + + +def refresh_available_extensions(url, hide_tags, sort_column): + global available_extensions + + import urllib.request + with urllib.request.urlopen(url) as response: + text = response.read() + + available_extensions = json.loads(text) + + code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) + + return url, code, gr.CheckboxGroup.update(choices=tags), '' + + +def refresh_available_extensions_for_tags(hide_tags, sort_column): + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) + + return code, '' + + +sort_ordering = [ + # (reverse, order_by_function) + (True, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('name', 'z')), + (True, lambda x: x.get('name', 'z')), + (False, lambda x: 'z'), +] + + +def refresh_available_extensions_from_data(hide_tags, sort_column): + extlist = available_extensions["extensions"] + installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + + tags = available_extensions.get("tags", {}) + tags_to_hide = set(hide_tags) + hidden = 0 + + code = f""" + + + + + + + + + + """ + + sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0] + + for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): + name = ext.get("name", "noname") + added = ext.get('added', 'unknown') + url = ext.get("url", None) + description = ext.get("description", "") + extension_tags = 
ext.get("tags", []) + + if url is None: + continue + + existing = installed_extension_urls.get(normalize_git_url(url), None) + extension_tags = extension_tags + ["installed"] if existing else extension_tags + + if len([x for x in extension_tags if x in tags_to_hide]) > 0: + hidden += 1 + continue + + install_code = f"""""" + + tags_text = ", ".join([f"{x}" for x in extension_tags]) + + code += f""" + + + + + + + """ + + for tag in [x for x in extension_tags if x not in tags]: + tags[tag] = tag + + code += """ + +
<thead><tr> <th>Extension</th> <th>Description</th> <th>Action</th> </tr></thead>
<tr>
    <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
    <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
    <td>{install_code}</td>
</tr>
    + """ + + if hidden > 0: + code += f"

<p>Extension hidden: {hidden}</p>
    " + + return code, list(tags) + + +def create_ui(): + import modules.ui + + with gr.Blocks(analytics_enabled=False) as ui: + with gr.Tabs(elem_id="tabs_extensions") as tabs: + with gr.TabItem("Installed"): + + with gr.Row(elem_id="extensions_installed_top"): + apply = gr.Button(value="Apply and restart UI", variant="primary") + check = gr.Button(value="Check for updates") + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + + info = gr.HTML() + extensions_table = gr.HTML(lambda: extension_table()) + + apply.click( + fn=apply_and_restart, + _js="extensions_apply", + inputs=[extensions_disabled_list, extensions_update_list], + outputs=[], + ) + + check.click( + fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), + _js="extensions_check", + inputs=[info, extensions_disabled_list], + outputs=[extensions_table, info], + ) + + with gr.TabItem("Available"): + with gr.Row(): + refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") + available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False) + extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) + install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + + with gr.Row(): + hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") + + install_result = gr.HTML() + available_extensions_table = gr.HTML() + + refresh_available_extensions_button.click( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[available_extensions_index, hide_tags, sort_column], + outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], + ) + + install_extension_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), + inputs=[extension_to_install, hide_tags, sort_column], + outputs=[available_extensions_table, extensions_table, install_result], + ) + + hide_tags.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column], + outputs=[available_extensions_table, install_result] + ) + + sort_column.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column], + outputs=[available_extensions_table, install_result] + ) + + with gr.TabItem("Install from URL"): + install_url = gr.Text(label="URL for extension's git repository") + install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") + install_button = gr.Button(value="Install", variant="primary") + install_result = gr.HTML(elem_id="extension_install_result") + + install_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), + inputs=[install_dirname, install_url], + outputs=[extensions_table, install_result], + ) + + return ui diff --git 
a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..71f1d81f23eb92a1ed2fdfd5a2d986ffd13e8e9f --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,251 @@ +import glob +import os.path +import urllib.parse +from pathlib import Path + +from modules import shared +import gradio as gr +import json +import html + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] +allowed_dirs = set() + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" + + extra_pages.append(page) + allowed_dirs.clear() + allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + + +def add_pages_to_demo(app): + def fetch_file(filename: str = ""): + from starlette.responses import FileResponse + + if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]): + raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + ext = os.path.splitext(filename)[1].lower() + if ext not in (".png", ".jpg"): + raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.") + + # would profit from returning 304 + return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) + + app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.name = title.lower() + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + + def refresh(self): + pass + + def link_preview(self, filename): + return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + + def search_terms_from_path(self, filename, possible_directories=None): + abspath = os.path.abspath(filename) + + for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + parentdir = os.path.abspath(parentdir) + if abspath.startswith(parentdir): + return abspath[len(parentdir):].replace('\\', '/') + + return "" + + def create_html(self, tabname): + view = shared.opts.extra_networks_default_view + items_html = '' + + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] + + is_empty = len(os.listdir(x)) == 0 + if not is_empty and not subdir.endswith("/"): + subdir = subdir + "/" + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + subdirs_html = "".join([f""" + +""" for subdir in subdirs]) + + for item in self.list_items(): + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) + items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + + self_name_id = self.name.replace(" ", "_") + + res = f""" +
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+{subdirs_html}
+</div>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
+{items_html}
+</div>
    +""" + + return res + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + preview = item.get("preview", None) + + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + + args = { + "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', + "prompt": item.get("prompt", None), + "tabname": json.dumps(tabname), + "local_preview": json.dumps(item["local_preview"]), + "name": item["name"], + "card_clicked": onclick, + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "search_term": item.get("search_term", ""), + } + + return self.card_page.format(**args) + + +def intialize(): + extra_pages.clear() + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) + ui.tabname = tabname + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + for page in ui.stored_extra_pages: + with gr.Tab(page.title): + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + def toggle_visibility(is_visible): + is_visible = not is_visible + return is_visible, gr.update(visible=is_visible) + + state_visible = gr.State(value=False) + button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) + button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) + + def refresh(): + res = [] + + for pg in ui.stored_extra_pages: + pg.refresh() + res.append(pg.create_html(ui.tabname)) + + return res + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return child_path.startswith(parent_path) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in 
ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + image.save(filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..04097a7945911b8f73be49963325134da929d10f --- /dev/null +++ b/modules/ui_extra_networks_checkpoints.py @@ -0,0 +1,39 @@ +import html +import json +import os +import urllib.parse + +from modules import shared, ui_extra_networks, sd_models + + +class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Checkpoints') + + def refresh(self): + shared.refresh_checkpoints() + + def list_items(self): + checkpoint: sd_models.CheckpointInfo + for name, checkpoint in sd_models.checkpoints_list.items(): + path, ext = os.path.splitext(checkpoint.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = self.link_preview(file) + break + + yield { + "name": checkpoint.name_for_extra, + "filename": path, + "preview": preview, + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 0000000000000000000000000000000000000000..57851088734daa77c904cccf120b2c00c44c994e --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,36 @@ +import json +import os + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def list_items(self): + for name, path in shared.hypernetworks.items(): + path, ext = os.path.splitext(path) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = self.link_preview(file) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "search_term": self.search_terms_from_path(path), + "prompt": json.dumps(f""), + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..bb64eb81e294ebb6a94adf94d069a013e37331f1 --- /dev/null +++ 
b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,34 @@ +import json +import os + +from modules import ui_extra_networks, sd_hijack + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def list_items(self): + for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + path, ext = os.path.splitext(embedding.filename) + preview_file = path + ".preview.png" + + preview = None + if os.path.isfile(preview_file): + preview = self.link_preview(preview_file) + + yield { + "name": embedding.name, + "filename": embedding.filename, + "preview": preview, + "search_term": self.search_terms_from_path(embedding.filename), + "prompt": json.dumps(embedding.name), + "local_preview": path + ".preview.png", + } + + def allowed_directories_for_previews(self): + return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..b418d9553059d946c34b4e251062c46fe098fd75 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr +from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue +import modules.generation_parameters_copypaste as parameters_copypaste + + +def create_ui(): + tab_index = gr.State(value=0) + + with gr.Row().style(equal_height=False, variant='compact'): + with gr.Column(variant='compact'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + script_inputs = scripts.scripts_postproc.setup_ui() + + with gr.Column(): + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + + tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) + tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) + tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + + submit.click( + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + inputs=[ + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ], + outputs=[ + result_images, + 
html_info_x, + html_info, + ] + ) + + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=scripts.scripts_postproc.image_changed, + inputs=[], outputs=[] + ) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py new file mode 100644 index 0000000000000000000000000000000000000000..21945235ef360823fafbc92c358078f343e874da --- /dev/null +++ b/modules/ui_tempdir.py @@ -0,0 +1,82 @@ +import os +import tempfile +from collections import namedtuple +from pathlib import Path + +import gradio as gr + +from PIL import PngImagePlugin + +from modules import shared + + +Savedfile = namedtuple("Savedfile", ["name"]) + + +def register_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 + gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} + + if hasattr(gradio, 'temp_dirs'): # gradio 3.9 + gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} + + +def check_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): + return any([filename in fileset for fileset in gradio.temp_file_sets]) + + if hasattr(gradio, 'temp_dirs'): + return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) + + return False + + +def save_pil_to_file(pil_image, dir=None): + already_saved_as = getattr(pil_image, 'already_saved_as', None) + if already_saved_as and os.path.isfile(already_saved_as): + register_tmp_file(shared.demo, already_saved_as) + + file_obj = Savedfile(already_saved_as) + return file_obj + + if shared.opts.temp_dir != "": + dir = shared.opts.temp_dir + + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + +def on_tmpdir_changed(): + if shared.opts.temp_dir == "" or shared.demo is None: + return + + os.makedirs(shared.opts.temp_dir, exist_ok=True) + + register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) + + +def cleanup_tmpdr(): + temp_dir = shared.opts.temp_dir + if temp_dir == "" or not os.path.isdir(temp_dir): + return + + for root, dirs, files in os.walk(temp_dir, topdown=False): + for name in files: + _, extension = os.path.splitext(name) + if extension != ".png": + continue + + filename = os.path.join(root, name) + os.remove(filename) diff --git a/modules/upscaler.py b/modules/upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..e2eaa7308af0091b6e8f407e889b2e446679e149 --- /dev/null +++ b/modules/upscaler.py @@ -0,0 +1,145 @@ +import os +from abc import abstractmethod + +import PIL +import numpy as np +import torch +from PIL import Image + +import modules.shared +from modules import modelloader, shared + +LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) + + +class Upscaler: + name = None + model_path = None + model_name = None + model_url = None + enable = True + filter = None + model = None + user_path = None + scalers: [] + tile = True + + def __init__(self, create_dirs=False): + 
self.mod_pad_h = None + self.tile_size = modules.shared.opts.ESRGAN_tile + self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap + self.device = modules.shared.device + self.img = None + self.output = None + self.scale = 1 + self.half = not modules.shared.cmd_opts.no_half + self.pre_pad = 0 + self.mod_scale = None + + if self.model_path is None and self.name: + self.model_path = os.path.join(shared.models_path, self.name) + if self.model_path and create_dirs: + os.makedirs(self.model_path, exist_ok=True) + + try: + import cv2 + self.can_tile = True + except: + pass + + @abstractmethod + def do_upscale(self, img: PIL.Image, selected_model: str): + return img + + def upscale(self, img: PIL.Image, scale, selected_model: str = None): + self.scale = scale + dest_w = int(img.width * scale) + dest_h = int(img.height * scale) + + for i in range(3): + shape = (img.width, img.height) + + img = self.do_upscale(img, selected_model) + + if shape == (img.width, img.height): + break + + if img.width >= dest_w and img.height >= dest_h: + break + + if img.width != dest_w or img.height != dest_h: + img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) + + return img + + @abstractmethod + def load_model(self, path: str): + pass + + def find_models(self, ext_filter=None) -> list: + return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) + + def update_status(self, prompt): + print(f"\nextras: {prompt}", file=shared.progress_print_out) + + +class UpscalerData: + name = None + data_path = None + scale: int = 4 + scaler: Upscaler = None + model: None + + def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): + self.name = name + self.data_path = path + self.local_data_path = path + self.scaler = upscaler + self.scale = scale + self.model = model + + +class UpscalerNone(Upscaler): + name = "None" + scalers = [] + + def load_model(self, path): + pass + + def do_upscale(self, img, selected_model=None): + return img + + def __init__(self, dirname=None): + super().__init__(False) + self.scalers = [UpscalerData("None", None, self)] + + +class UpscalerLanczos(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Lanczos" + self.scalers = [UpscalerData("Lanczos", None, self)] + + +class UpscalerNearest(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Nearest" + self.scalers = [UpscalerData("Nearest", None, self)] diff --git a/modules/xlmr.py b/modules/xlmr.py new file mode 100644 index 0000000000000000000000000000000000000000..beab3fdf55e7bcffd96f3b36679e7a90c0f390dc --- /dev/null +++ b/modules/xlmr.py @@ -0,0 +1,137 @@ +from transformers import BertPreTrainedModel,BertModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, 
intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 768 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + self.pooler = lambda x: x[:,0] + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if 
return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # last module outputs + sequence_output = outputs[0] + + + # project every module + sequence_output_ln = self.pre_LN(sequence_output) + + # pooler + pooler_output = self.pooler(sequence_output_ln) + pooler_output = self.transformation(pooler_output) + projection_state = self.transformation(outputs.last_hidden_state) + + return { + 'pooler_output':pooler_output, + 'last_hidden_state':outputs.last_hidden_state, + 'hidden_states':outputs.hidden_states, + 'attentions':outputs.attentions, + 'projection_state':projection_state, + 'sequence_out': sequence_output + } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d53f08933a7f02e6266f9ae4fc8f23dfa119ecd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,32 @@ +blendmodes +accelerate +basicsr +fonts +font-roboto +gfpgan +gradio==3.16.2 +invisible-watermark +numpy +omegaconf +opencv-contrib-python +requests +piexif +Pillow +pytorch_lightning==1.7.7 +realesrgan +scikit-image>=0.19 +timm==0.4.12 +transformers==4.25.1 +torch +einops +jsonmerge +clean-fid +resize-right +torchdiffeq +kornia +lark +inflection +GitPython +torchsde +safetensors +psutil diff --git a/requirements_versions.txt b/requirements_versions.txt new file mode 100644 index 0000000000000000000000000000000000000000..331d0fe86513ae9e0dbe3cd8299878de78281e90 --- /dev/null +++ b/requirements_versions.txt @@ -0,0 +1,30 @@ +blendmodes==2022 +transformers==4.25.1 +accelerate==0.12.0 +basicsr==1.4.2 +gfpgan==1.3.8 +gradio==3.16.2 +numpy==1.23.3 +Pillow==9.4.0 +realesrgan==0.3.0 +torch +omegaconf==2.2.3 +pytorch_lightning==1.7.6 +scikit-image==0.19.2 +fonts +font-roboto +timm==0.6.7 +piexif==1.1.3 +einops==0.4.1 +jsonmerge==1.8.0 +clean-fid==0.1.29 +resize-right==0.0.2 +torchdiffeq==0.2.3 +kornia==0.6.7 +lark==1.1.2 +inflection==0.5.1 +GitPython==3.1.27 +torchsde==0.2.5 +safetensors==0.2.7 +httpcore<=0.15 +fastapi==0.90.1 diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a7b75245e46e3ad2f8783e4cc42d37bbe720a8 --- /dev/null +++ b/run.py @@ -0,0 +1,6 @@ +# import os + +# if not os.path.exists("extensions/deforum"): +# exec(open("deforum.sh").read()) + +exec(open("run.sh").read()) diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9752b2f4cd97e0dbc0f28d7aad76f7b2e32406ad --- /dev/null +++ b/run.sh @@ -0,0 +1,5 @@ +#!usr/bin/env bash + +[ -d extensions/deforum ] || git clone https://github.com/deforum-art/deforum-for-automatic1111-webui extensions/deforum + +. 
webui.sh diff --git a/screenshot.png b/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..c9b754b5e18cad82ef2042579194ac5965176ec4 --- /dev/null +++ b/screenshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1e371a6d937b5504ceb69f0577086d562d94f42288e314821094ce7c79b2c09 +size 420577 diff --git a/script.js b/script.js new file mode 100644 index 0000000000000000000000000000000000000000..97e0bfcf9fa3cc6b5823d86b0949ab9c947c6418 --- /dev/null +++ b/script.js @@ -0,0 +1,102 @@ +function gradioApp() { + const elems = document.getElementsByTagName('gradio-app') + const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot + return !!gradioShadowRoot ? gradioShadowRoot : document; +} + +function get_uiCurrentTab() { + return gradioApp().querySelector('#tabs button:not(.border-transparent)') +} + +function get_uiCurrentTabContent() { + return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])') +} + +uiUpdateCallbacks = [] +uiLoadedCallbacks = [] +uiTabChangeCallbacks = [] +optionsChangedCallbacks = [] +let uiCurrentTab = null + +function onUiUpdate(callback){ + uiUpdateCallbacks.push(callback) +} +function onUiLoaded(callback){ + uiLoadedCallbacks.push(callback) +} +function onUiTabChange(callback){ + uiTabChangeCallbacks.push(callback) +} +function onOptionsChanged(callback){ + optionsChangedCallbacks.push(callback) +} + +function runCallback(x, m){ + try { + x(m) + } catch (e) { + (console.error || console.log).call(console, e.message, e); + } +} +function executeCallbacks(queue, m) { + queue.forEach(function(x){runCallback(x, m)}) +} + +var executedOnLoaded = false; + +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m){ + if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){ + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + + executeCallbacks(uiUpdateCallbacks, m); + const newTab = get_uiCurrentTab(); + if ( newTab && ( newTab !== uiCurrentTab ) ) { + uiCurrentTab = newTab; + executeCallbacks(uiTabChangeCallbacks); + } + }); + mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) +}); + +/** + * Add a ctrl+enter as a shortcut to start a generation + */ +document.addEventListener('keydown', function(e) { + var handled = false; + if (e.key !== undefined) { + if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; + } else if (e.keyCode !== undefined) { + if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; + } + if (handled) { + button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + if (button) { + button.click(); + } + e.preventDefault(); + } +}) + +/** + * checks that a UI element is not in another hidden element or tab content + */ +function uiElementIsVisible(el) { + let isVisible = !el.closest('.\\!hidden'); + if ( ! isVisible ) { + return false; + } + + while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) { + if ( ! 
isVisible ) { + return false; + } else if ( el.parentElement ) { + el = el.parentElement + } else { + break; + } + } + return isVisible; +} diff --git a/scripts/custom_code.py b/scripts/custom_code.py new file mode 100644 index 0000000000000000000000000000000000000000..d29113e671070d10bdb7d7a23059255681fa0ef2 --- /dev/null +++ b/scripts/custom_code.py @@ -0,0 +1,41 @@ +import modules.scripts as scripts +import gradio as gr + +from modules.processing import Processed +from modules.shared import opts, cmd_opts, state + +class Script(scripts.Script): + + def title(self): + return "Custom code" + + def show(self, is_img2img): + return cmd_opts.allow_code + + def ui(self, is_img2img): + code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code")) + + return [code] + + + def run(self, p, code): + assert cmd_opts.allow_code, '--allow-code option must be enabled' + + display_result_data = [[], -1, ""] + + def display(imgs, s=display_result_data[1], i=display_result_data[2]): + display_result_data[0] = imgs + display_result_data[1] = s + display_result_data[2] = i + + from types import ModuleType + compiled = compile(code, '', 'exec') + module = ModuleType("testmodule") + module.__dict__.update(globals()) + module.p = p + module.display = display + exec(compiled, module.__dict__) + + return Processed(p, *display_result_data) + + \ No newline at end of file diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py new file mode 100644 index 0000000000000000000000000000000000000000..2572443f97478ad984e8128e3749939946e7f59b --- /dev/null +++ b/scripts/img2imgalt.py @@ -0,0 +1,216 @@ +from collections import namedtuple + +import numpy as np +from tqdm import trange + +import modules.scripts as scripts +import gradio as gr + +from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common +from modules.processing import Processed +from modules.shared import opts, cmd_opts, state + +import torch +import k_diffusion as K + +from PIL import Image +from torch import autocast +from einops import rearrange, repeat + + +def find_noise_for_image(p, cond, uncond, cfg_scale, steps): + x = p.init_latent + + s_in = x.new_ones([x.shape[0]]) + dnw = K.external.CompVisDenoiser(shared.sd_model) + sigmas = dnw.get_sigmas(steps).flip(0) + + shared.state.sampling_steps = steps + + for i in trange(1, len(sigmas)): + shared.state.sampling_step += 1 + + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + + image_conditioning = torch.cat([p.image_conditioning] * 2) + cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} + + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] + t = dnw.sigma_to_t(sigma_in) + + eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale + + d = (x - denoised) / sigmas[i] + dt = sigmas[i] - sigmas[i - 1] + + x = x + d * dt + + sd_samplers_common.store_latent(x) + + # This shouldn't be necessary, but solved some VRAM issues + del x_in, sigma_in, cond_in, c_out, c_in, t, + del eps, denoised_uncond, denoised_cond, denoised, d, dt + + shared.state.nextjob() + + return x / x.std() + + +Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"]) + + +# Based on changes suggested by briansemrau in 
https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736 +def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): + x = p.init_latent + + s_in = x.new_ones([x.shape[0]]) + dnw = K.external.CompVisDenoiser(shared.sd_model) + sigmas = dnw.get_sigmas(steps).flip(0) + + shared.state.sampling_steps = steps + + for i in trange(1, len(sigmas)): + shared.state.sampling_step += 1 + + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + + image_conditioning = torch.cat([p.image_conditioning] * 2) + cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} + + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] + + if i == 1: + t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) + else: + t = dnw.sigma_to_t(sigma_in) + + eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale + + if i == 1: + d = (x - denoised) / (2 * sigmas[i]) + else: + d = (x - denoised) / sigmas[i - 1] + + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + + sd_samplers_common.store_latent(x) + + # This shouldn't be necessary, but solved some VRAM issues + del x_in, sigma_in, cond_in, c_out, c_in, t, + del eps, denoised_uncond, denoised_cond, denoised, d, dt + + shared.state.nextjob() + + return x / sigmas[-1] + + +class Script(scripts.Script): + def __init__(self): + self.cache = None + + def title(self): + return "img2img alternative test" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + info = gr.Markdown(''' + * `CFG Scale` should be 2 or lower. + ''') + + override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler")) + + override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt")) + original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt")) + original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt")) + + override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps")) + st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st")) + + override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength")) + + cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg")) + randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness")) + sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) + + return [ + info, + override_sampler, + override_prompt, original_prompt, original_negative_prompt, + override_steps, st, + override_strength, + cfg, randomness, sigma_adjustment, + ] + + def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment): + # Override + if override_sampler: + 
p.sampler_name = "Euler" + if override_prompt: + p.prompt = original_prompt + p.negative_prompt = original_negative_prompt + if override_steps: + p.steps = st + if override_strength: + p.denoising_strength = 1.0 + + def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): + lat = (p.init_latent.cpu().numpy() * 10).astype(int) + + same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \ + and self.cache.original_prompt == original_prompt \ + and self.cache.original_negative_prompt == original_negative_prompt \ + and self.cache.sigma_adjustment == sigma_adjustment + same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 + + if same_everything: + rec_noise = self.cache.noise + else: + shared.state.job_count += 1 + cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) + uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt]) + if sigma_adjustment: + rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st) + else: + rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) + self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) + + rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) + + combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) + + sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) + + sigmas = sampler.model_wrap.get_sigmas(p.steps) + + noise_dt = combined_noise - (p.init_latent / sigmas[0]) + + p.seed = p.seed + 1 + + return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) + + p.sample = sample_extra + + p.extra_generation_params["Decode prompt"] = original_prompt + p.extra_generation_params["Decode negative prompt"] = original_negative_prompt + p.extra_generation_params["Decode CFG scale"] = cfg + p.extra_generation_params["Decode steps"] = st + p.extra_generation_params["Randomness"] = randomness + p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment + + processed = processing.process_images(p) + + return processed + diff --git a/scripts/loopback.py b/scripts/loopback.py new file mode 100644 index 0000000000000000000000000000000000000000..ec1f85e5891cc97d15ff06faaf749e29d98e362d --- /dev/null +++ b/scripts/loopback.py @@ -0,0 +1,98 @@ +import numpy as np +from tqdm import trange + +import modules.scripts as scripts +import gradio as gr + +from modules import processing, shared, sd_samplers, images +from modules.processing import Processed +from modules.sd_samplers import samplers +from modules.shared import opts, cmd_opts, state +from modules import deepbooru + + +class Script(scripts.Script): + def title(self): + return "Loopback" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) + denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor")) + append_interrogation = gr.Dropdown(label="Append 
interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None") + + return [loops, denoising_strength_change_factor, append_interrogation] + + def run(self, p, loops, denoising_strength_change_factor, append_interrogation): + processing.fix_seed(p) + batch_count = p.n_iter + p.extra_generation_params = { + "Denoising strength change factor": denoising_strength_change_factor, + } + + p.batch_size = 1 + p.n_iter = 1 + + output_images, info = None, None + initial_seed = None + initial_info = None + + grids = [] + all_images = [] + original_init_image = p.init_images + original_prompt = p.prompt + state.job_count = loops * batch_count + + initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] + + for n in range(batch_count): + history = [] + + # Reset to original init image at the start of each batch + p.init_images = original_init_image + + for i in range(loops): + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + + if opts.img2img_color_correction: + p.color_corrections = initial_color_corrections + + if append_interrogation != "None": + p.prompt = original_prompt + ", " if original_prompt != "" else "" + if append_interrogation == "CLIP": + p.prompt += shared.interrogator.interrogate(p.init_images[0]) + elif append_interrogation == "DeepBooru": + p.prompt += deepbooru.model.tag(p.init_images[0]) + + state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}" + + processed = processing.process_images(p) + + if initial_seed is None: + initial_seed = processed.seed + initial_info = processed.info + + init_img = processed.images[0] + + p.init_images = [init_img] + p.seed = processed.seed + 1 + p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) + history.append(processed.images[0]) + + grid = images.image_grid(history, rows=1) + if opts.grid_save: + images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) + + grids.append(grid) + all_images += history + + if opts.return_grid: + all_images = grids + all_images + + processed = Processed(p, all_images, initial_seed, initial_info) + + return processed diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py new file mode 100644 index 0000000000000000000000000000000000000000..0906da6ae7da2281e83a517f2278f6cf1315c518 --- /dev/null +++ b/scripts/outpainting_mk_2.py @@ -0,0 +1,283 @@ +import math + +import numpy as np +import skimage + +import modules.scripts as scripts +import gradio as gr +from PIL import Image, ImageDraw + +from modules import images, processing, devices +from modules.processing import Processed, process_images +from modules.shared import opts, cmd_opts, state + + +# this function is taken from https://github.com/parlance-zz/g-diffuser-bot +def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05): + # helper fft routines that keep ortho normalization and auto-shift before and after fft + def _fft2(data): + if data.ndim > 2: # has channels + out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho") + out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c]) + else: # one channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), 
norm="ortho") + out_fft[:, :] = np.fft.ifftshift(out_fft[:, :]) + + return out_fft + + def _ifft2(data): + if data.ndim > 2: # has channels + out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:, :, c] + out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho") + out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c]) + else: # one channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho") + out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :]) + + return out_ifft + + def _get_gaussian_window(width, height, std=3.14, mode=0): + window_scale_x = float(width / min(width, height)) + window_scale_y = float(height / min(width, height)) + + window = np.zeros((width, height)) + x = (np.arange(width) / width * 2. - 1.) * window_scale_x + for y in range(height): + fy = (y / height * 2. - 1.) * window_scale_y + if mode == 0: + window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std) + else: + window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian + + return window + + def _get_masked_window_rgb(np_mask_grey, hardness=1.): + np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) + if hardness != 1.: + hardened = np_mask_grey[:] ** hardness + else: + hardened = np_mask_grey[:] + for c in range(3): + np_mask_rgb[:, :, c] = hardened[:] + return np_mask_rgb + + width = _np_src_image.shape[0] + height = _np_src_image.shape[1] + num_channels = _np_src_image.shape[2] + + np_src_image = _np_src_image[:] * (1. - np_mask_rgb) + np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.) + img_mask = np_mask_grey > 1e-6 + ref_mask = np_mask_grey < 1e-3 + + windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey)) + windowed_image /= np.max(windowed_image) + windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color + + src_fft = _fft2(windowed_image) # get feature statistics from masked src img + src_dist = np.absolute(src_fft) + src_phase = src_fft / src_dist + + # create a generator with a static seed to make outpainting deterministic / only follow global seed + rng = np.random.default_rng(0) + + noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise + noise_rgb = rng.random((width, height, num_channels)) + noise_grey = (np.sum(noise_rgb, axis=2) / 3.) + noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter + for c in range(num_channels): + noise_rgb[:, :, c] += (1. - color_variation) * noise_grey + + noise_fft = _fft2(noise_rgb) + for c in range(num_channels): + noise_fft[:, :, c] *= noise_window + noise_rgb = np.real(_ifft2(noise_fft)) + shaped_noise_fft = _fft2(noise_rgb) + shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping + + brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now + contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. + + # scikit-image is used for histogram matching, very convenient! 
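
Aside (not part of the patch): the comment above refers to the histogram-matching call a few lines below. A tiny self-contained sketch of that scikit-image step is shown here with random stand-in arrays rather than the script's shaped noise and source pixels; the real code matches flattened masked pixels with `channel_axis=1`, while this sketch matches whole images with `channel_axis=-1` (available in scikit-image >= 0.19, as pinned in requirements).

```python
import numpy as np
from skimage import exposure

rng = np.random.default_rng(0)
noise = rng.random((64, 64, 3))            # stand-in for the FFT-shaped noise
reference = 0.5 * rng.random((64, 64, 3))  # stand-in for the unmasked source pixels

# transfer the reference's per-channel intensity distribution onto the noise
matched = exposure.match_histograms(noise, reference, channel_axis=-1)
print(matched.shape, float(matched.min()), float(matched.max()))
```
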
+ shaped_noise = np.real(_ifft2(shaped_noise_fft)) + shaped_noise -= np.min(shaped_noise) + shaped_noise /= np.max(shaped_noise) + shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1) + shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb + + matched_noise = shaped_noise[:] + + return np.clip(matched_noise, 0., 1.) + + + +class Script(scripts.Script): + def title(self): + return "Outpainting mk2" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>
    ") + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) + + return [info, pixels, mask_blur, direction, noise_q, color_variation] + + def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): + initial_seed_and_info = [None, None] + + process_width = p.width + process_height = p.height + + p.mask_blur = mask_blur*4 + p.inpaint_full_res = False + p.inpainting_fill = 1 + p.do_not_save_samples = True + p.do_not_save_grid = True + + left = pixels if "left" in direction else 0 + right = pixels if "right" in direction else 0 + up = pixels if "up" in direction else 0 + down = pixels if "down" in direction else 0 + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + left + right) / 64) * 64 + target_h = math.ceil((init_img.height + up + down) / 64) * 64 + + if left > 0: + left = left * (target_w - init_img.width) // (left + right) + + if right > 0: + right = target_w - init_img.width - left + + if up > 0: + up = up * (target_h - init_img.height) // (up + down) + + if down > 0: + down = target_h - init_img.height - up + + def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): + is_horiz = is_left or is_right + is_vert = is_top or is_bottom + pixels_horiz = expand_pixels if is_horiz else 0 + pixels_vert = expand_pixels if is_vert else 0 + + images_to_process = [] + output_images = [] + for n in range(count): + res_w = init[n].width + pixels_horiz + res_h = init[n].height + pixels_vert + process_res_w = math.ceil(res_w / 64) * 64 + process_res_h = math.ceil(res_h / 64) * 64 + + img = Image.new("RGB", (process_res_w, process_res_h)) + img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) + mask = Image.new("RGB", (process_res_w, process_res_h), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + expand_pixels + mask_blur if is_left else 0, + expand_pixels + mask_blur if is_top else 0, + mask.width - expand_pixels - mask_blur if is_right else res_w, + mask.height - expand_pixels - mask_blur if is_bottom else res_h, + ), fill="black") + + np_image = (np.asarray(img) / 255.0).astype(np.float64) + np_mask = (np.asarray(mask) / 255.0).astype(np.float64) + noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) + output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")) + + target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width + target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height + p.width = target_width if is_horiz else img.width + p.height = target_height if is_vert else img.height + + crop_region = ( + 0 if is_left else output_images[n].width - target_width, + 0 if is_top else output_images[n].height - target_height, + target_width if is_left else output_images[n].width, + target_height if is_top 
else output_images[n].height, + ) + mask = mask.crop(crop_region) + p.image_mask = mask + + image_to_process = output_images[n].crop(crop_region) + images_to_process.append(image_to_process) + + p.init_images = images_to_process + + latent_mask = Image.new("RGB", (p.width, p.height), "white") + draw = ImageDraw.Draw(latent_mask) + draw.rectangle(( + expand_pixels + mask_blur * 2 if is_left else 0, + expand_pixels + mask_blur * 2 if is_top else 0, + mask.width - expand_pixels - mask_blur * 2 if is_right else res_w, + mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h, + ), fill="black") + p.latent_mask = latent_mask + + proc = process_images(p) + + if initial_seed_and_info[0] is None: + initial_seed_and_info[0] = proc.seed + initial_seed_and_info[1] = proc.info + + for n in range(count): + output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height)) + output_images[n] = output_images[n].crop((0, 0, res_w, res_h)) + + return output_images + + batch_count = p.n_iter + batch_size = p.batch_size + p.n_iter = 1 + state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)) + all_processed_images = [] + + for i in range(batch_count): + imgs = [init_img] * batch_size + state.job = f"Batch {i + 1} out of {batch_count}" + + if left > 0: + imgs = expand(imgs, batch_size, left, is_left=True) + if right > 0: + imgs = expand(imgs, batch_size, right, is_right=True) + if up > 0: + imgs = expand(imgs, batch_size, up, is_top=True) + if down > 0: + imgs = expand(imgs, batch_size, down, is_bottom=True) + + all_processed_images += imgs + + all_images = all_processed_images + + combined_grid_image = images.image_grid(all_processed_images) + unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple + if opts.return_grid and not unwanted_grid_because_of_img_count: + all_images = [combined_grid_image] + all_processed_images + + res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1]) + + if opts.samples_save: + for img in all_processed_images: + images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p) + + if opts.grid_save and not unwanted_grid_because_of_img_count: + images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p) + + return res diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..d8feda00acde4e3ff644330cc62eb2dfdee060f6 --- /dev/null +++ b/scripts/poor_mans_outpainting.py @@ -0,0 +1,146 @@ +import math + +import modules.scripts as scripts +import gradio as gr +from PIL import Image, ImageDraw + +from modules import images, processing, devices +from modules.processing import Processed, process_images +from modules.shared import opts, cmd_opts, state + + +class Script(scripts.Script): + def title(self): + return "Poor man's outpainting" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, 
elem_id=self.elem_id("mask_blur")) + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + + return [pixels, mask_blur, inpainting_fill, direction] + + def run(self, p, pixels, mask_blur, inpainting_fill, direction): + initial_seed = None + initial_info = None + + p.mask_blur = mask_blur * 2 + p.inpainting_fill = inpainting_fill + p.inpaint_full_res = False + + left = pixels if "left" in direction else 0 + right = pixels if "right" in direction else 0 + up = pixels if "up" in direction else 0 + down = pixels if "down" in direction else 0 + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + left + right) / 64) * 64 + target_h = math.ceil((init_img.height + up + down) / 64) * 64 + + if left > 0: + left = left * (target_w - init_img.width) // (left + right) + if right > 0: + right = target_w - init_img.width - left + + if up > 0: + up = up * (target_h - init_img.height) // (up + down) + + if down > 0: + down = target_h - init_img.height - up + + img = Image.new("RGB", (target_w, target_h)) + img.paste(init_img, (left, up)) + + mask = Image.new("L", (img.width, img.height), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + left + (mask_blur * 2 if left > 0 else 0), + up + (mask_blur * 2 if up > 0 else 0), + mask.width - right - (mask_blur * 2 if right > 0 else 0), + mask.height - down - (mask_blur * 2 if down > 0 else 0) + ), fill="black") + + latent_mask = Image.new("L", (img.width, img.height), "white") + latent_draw = ImageDraw.Draw(latent_mask) + latent_draw.rectangle(( + left + (mask_blur//2 if left > 0 else 0), + up + (mask_blur//2 if up > 0 else 0), + mask.width - right - (mask_blur//2 if right > 0 else 0), + mask.height - down - (mask_blur//2 if down > 0 else 0) + ), fill="black") + + devices.torch_gc() + + grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + work_mask = [] + work_latent_mask = [] + work_results = [] + + for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): + for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + + work.append(tiledata[2]) + work_mask.append(tiledata_mask[2]) + work_latent_mask.append(tiledata_latent_mask[2]) + + batch_count = len(work) + print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") + + state.job_count = batch_count + + for i in range(batch_count): + p.init_images = [work[i]] + p.image_mask = work_mask[i] + p.latent_mask = work_latent_mask[i] + + state.job = f"Batch {i + 1} out of {batch_count}" + processed = process_images(p) + + if initial_seed is None: + initial_seed = processed.seed + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + + 
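
Aside (not part of the patch): both outpainting scripts above share the same canvas-size arithmetic, so a small worked example may help: the expanded canvas is rounded up to a multiple of 64, and the extra pixels actually gained are then split between the two requested sides in proportion to what was asked for. The pixel numbers here are made up for illustration.

```python
import math

init_w, left, right = 500, 128, 128                       # hypothetical source width and requested padding
target_w = math.ceil((init_w + left + right) / 64) * 64   # round the new canvas up to a multiple of 64 -> 768
left = left * (target_w - init_w) // (left + right)       # left's share of the pixels actually available
right = target_w - init_w - left                          # remainder goes to the right side
print(target_w, left, right)                              # 768 134 134
```
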
image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p) + + processed = Processed(p, [combined_image], initial_seed, initial_info) + + return processed + diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d80d40e2b946fd35206f8bbc302a3cf476f081 --- /dev/null +++ b/scripts/postprocessing_codeformer.py @@ -0,0 +1,36 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, codeformer_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): + name = "CodeFormer" + order = 3000 + + def ui(self): + with FormRow(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + + return { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0: + return + + restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(pp.image, res, codeformer_visibility) + + pp.image = res + pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) + pp.info["CodeFormer weight"] = round(codeformer_weight, 3) diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py new file mode 100644 index 0000000000000000000000000000000000000000..d854f3f7748dd8dec9575eb8914344db86a7f0c0 --- /dev/null +++ b/scripts/postprocessing_gfpgan.py @@ -0,0 +1,33 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, gfpgan_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): + name = "GFPGAN" + order = 2000 + + def ui(self): + with FormRow(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + + return { + "gfpgan_visibility": gfpgan_visibility, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): + if gfpgan_visibility == 0: + return + + restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(pp.image, res, gfpgan_visibility) + + pp.image = res + pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py new file mode 100644 index 
0000000000000000000000000000000000000000..8842bd91c926be4382bf2e5c2877021e43f37a69 --- /dev/null +++ b/scripts/postprocessing_upscale.py @@ -0,0 +1,131 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, shared +import gradio as gr + +from modules.ui_components import FormRow + + +upscale_cache = {} + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Upscale" + order = 1000 + + def ui(self): + selected_tab = gr.State(value=0) + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } + + def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): + if upscale_mode == 1: + upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) + info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" + else: + info["Postprocess upscale by"] = upscale_by + + cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + cached_image = upscale_cache.pop(cache_key, None) + + if cached_image is not None: + image = cached_image + else: + image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) + + upscale_cache[cache_key] = image + if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: + upscale_cache.pop(next(iter(upscale_cache), None), None) + + if upscale_mode == 1 and upscale_crop: + cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) + cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) + image = cropped + info["Postprocess crop to"] = f"{image.width}x{image.height}" + + return image + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, 
upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscaler_1_name == "None": + upscaler_1_name = None + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) + assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' + + if not upscaler1: + return + + if upscaler_2_name == "None": + upscaler_2_name = None + + upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) + assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' + + upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + pp.info[f"Postprocess upscaler"] = upscaler1.name + + if upscaler2 and upscaler_2_visibility > 0: + second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) + + pp.info[f"Postprocess upscaler 2"] = upscaler2.name + + pp.image = upscaled_image + + def image_changed(self): + upscale_cache.clear() + + +class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): + name = "Simple Upscale" + order = 900 + + def ui(self): + with FormRow(): + upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) + + return { + "upscale_by": upscale_by, + "upscaler_name": upscaler_name, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): + if upscaler_name is None or upscaler_name == "None": + return + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) + assert upscaler1, f'could not find upscaler named {upscaler_name}' + + pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) + pp.info[f"Postprocess upscaler"] = upscaler1.name diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c486d44e3c3e4df4dc663e598239a7632b0bf2 --- /dev/null +++ b/scripts/prompt_matrix.py @@ -0,0 +1,111 @@ +import math +from collections import namedtuple +from copy import copy +import random + +import modules.scripts as scripts +import gradio as gr + +from modules import images +from modules.processing import process_images, Processed +from modules.shared import opts, cmd_opts, state +import modules.sd_samplers + + +def draw_xy_grid(xs, ys, x_label, y_label, cell): + res = [] + + ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] + hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] + + first_processed = None + + state.job_count = len(xs) * len(ys) + + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed = cell(x, y) + if first_processed is None: + first_processed = processed + + res.append(processed.images[0]) + + grid = images.image_grid(res, rows=len(ys)) + grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) + + first_processed.images = [grid] + + return first_processed + + +class Script(scripts.Script): + def title(self): + return "Prompt matrix" + + def ui(self, 
is_img2img): + gr.HTML('<br />
    ') + with gr.Row(): + with gr.Column(): + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) + with gr.Column(): + prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive") + variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma") + with gr.Column(): + margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + + return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size] + + def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size): + modules.processing.fix_seed(p) + # Raise error if promp type is not positive or negative + if prompt_type not in ["positive", "negative"]: + raise ValueError(f"Unknown prompt type {prompt_type}") + # Raise error if variations delimiter is not comma or space + if variations_delimiter not in ["comma", "space"]: + raise ValueError(f"Unknown variations delimiter {variations_delimiter}") + + prompt = p.prompt if prompt_type == "positive" else p.negative_prompt + original_prompt = prompt[0] if type(prompt) == list else prompt + positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt + + delimiter = ", " if variations_delimiter == "comma" else " " + + all_prompts = [] + prompt_matrix_parts = original_prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts) - 1) + for combination_num in range(combination_count): + selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] + + if put_at_start: + selected_prompts = selected_prompts + [prompt_matrix_parts[0]] + else: + selected_prompts = [prompt_matrix_parts[0]] + selected_prompts + + all_prompts.append(delimiter.join(selected_prompts)) + + p.n_iter = math.ceil(len(all_prompts) / p.batch_size) + p.do_not_save_grid = True + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") + + if prompt_type == "positive": + p.prompt = all_prompts + else: + p.negative_prompt = all_prompts + p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))] + p.prompt_for_display = positive_prompt + processed = process_images(p) + + grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) + grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size) + processed.images.insert(0, grid) + processed.index_of_first_image = 1 + processed.infotexts.insert(0, processed.infotexts[0]) + + if opts.grid_save: + images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p) + + return processed diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py new file mode 100644 index 0000000000000000000000000000000000000000..76dc5778b2754c3eda49d1b66220ebec42265623 --- /dev/null +++ b/scripts/prompts_from_file.py @@ -0,0 +1,177 @@ +import copy +import math +import os +import random +import sys +import traceback +import shlex + +import modules.scripts as scripts +import gradio as gr 
+ +from modules import sd_samplers +from modules.processing import Processed, process_images +from PIL import Image +from modules.shared import opts, cmd_opts, state + + +def process_string_tag(tag): + return tag + + +def process_int_tag(tag): + return int(tag) + + +def process_float_tag(tag): + return float(tag) + + +def process_boolean_tag(tag): + return True if (tag == "true") else False + + +prompt_tags = { + "sd_model": None, + "outpath_samples": process_string_tag, + "outpath_grids": process_string_tag, + "prompt_for_display": process_string_tag, + "prompt": process_string_tag, + "negative_prompt": process_string_tag, + "styles": process_string_tag, + "seed": process_int_tag, + "subseed_strength": process_float_tag, + "subseed": process_int_tag, + "seed_resize_from_h": process_int_tag, + "seed_resize_from_w": process_int_tag, + "sampler_index": process_int_tag, + "sampler_name": process_string_tag, + "batch_size": process_int_tag, + "n_iter": process_int_tag, + "steps": process_int_tag, + "cfg_scale": process_float_tag, + "width": process_int_tag, + "height": process_int_tag, + "restore_faces": process_boolean_tag, + "tiling": process_boolean_tag, + "do_not_save_samples": process_boolean_tag, + "do_not_save_grid": process_boolean_tag +} + + +def cmdargs(line): + args = shlex.split(line) + pos = 0 + res = {} + + while pos < len(args): + arg = args[pos] + + assert arg.startswith("--"), f'must start with "--": {arg}' + assert pos+1 < len(args), f'missing argument for command line option {arg}' + + tag = arg[2:] + + if tag == "prompt" or tag == "negative_prompt": + pos += 1 + prompt = args[pos] + pos += 1 + while pos < len(args) and not args[pos].startswith("--"): + prompt += " " + prompt += args[pos] + pos += 1 + res[tag] = prompt + continue + + + func = prompt_tags.get(tag, None) + assert func, f'unknown commandline option: {arg}' + + val = args[pos+1] + if tag == "sampler_name": + val = sd_samplers.samplers_map.get(val.lower(), None) + + res[tag] = func(val) + + pos += 2 + + return res + + +def load_prompt_file(file): + if file is None: + lines = [] + else: + lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] + + return None, "\n".join(lines), gr.update(lines=7) + + +class Script(scripts.Script): + def title(self): + return "Prompts from file or textbox" + + def ui(self, is_img2img): + checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) + checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + + prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) + file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) + + file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) + + # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n. + # We don't shrink back to 1, because that causes the control to ignore [enter], and it may + # be unclear to the user that shift-enter is needed. 
+ prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) + return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] + + def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): + lines = [x.strip() for x in prompt_txt.splitlines()] + lines = [x for x in lines if len(x) > 0] + + p.do_not_save_grid = True + + job_count = 0 + jobs = [] + + for line in lines: + if "--" in line: + try: + args = cmdargs(line) + except Exception: + print(f"Error parsing line {line} as commandline:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + args = {"prompt": line} + else: + args = {"prompt": line} + + job_count += args.get("n_iter", p.n_iter) + + jobs.append(args) + + print(f"Will process {len(lines)} lines in {job_count} jobs.") + if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1: + p.seed = int(random.randrange(4294967294)) + + state.job_count = job_count + + images = [] + all_prompts = [] + infotexts = [] + for n, args in enumerate(jobs): + state.job = f"{state.job_no + 1} out of {state.job_count}" + + copy_p = copy.copy(p) + for k, v in args.items(): + setattr(copy_p, k, v) + + proc = process_images(copy_p) + images += proc.images + + if checkbox_iterate: + p.seed = p.seed + (p.batch_size * p.n_iter) + all_prompts += proc.all_prompts + infotexts += proc.infotexts + + return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..332d76d918e53e5c2ed8d8cc1371413f1f29d7ee --- /dev/null +++ b/scripts/sd_upscale.py @@ -0,0 +1,101 @@ +import math + +import modules.scripts as scripts +import gradio as gr +from PIL import Image + +from modules import processing, shared, sd_samplers, images, devices +from modules.processing import Processed +from modules.shared import opts, cmd_opts, state + + +class Script(scripts.Script): + def title(self): + return "SD upscale" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + info = gr.HTML("

    <p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>
    ") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) + + return [info, overlap, upscaler_index, scale_factor] + + def run(self, p, _, overlap, upscaler_index, scale_factor): + if isinstance(upscaler_index, str): + upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) + processing.fix_seed(p) + upscaler = shared.sd_upscalers[upscaler_index] + + p.extra_generation_params["SD upscale overlap"] = overlap + p.extra_generation_params["SD upscale upscaler"] = upscaler.name + + initial_info = None + seed = p.seed + + init_img = p.init_images[0] + init_img = images.flatten(init_img, opts.img2img_background_color) + + if upscaler.name != "None": + img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) + else: + img = init_img + + devices.torch_gc() + + grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) + + batch_size = p.batch_size + upscale_count = p.n_iter + p.n_iter = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + + for y, h, row in grid.tiles: + for tiledata in row: + work.append(tiledata[2]) + + batch_count = math.ceil(len(work) / batch_size) + state.job_count = batch_count * upscale_count + + print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") + + result_images = [] + for n in range(upscale_count): + start_seed = seed + n + p.seed = start_seed + + work_results = [] + for i in range(batch_count): + p.batch_size = batch_size + p.init_images = work[i * batch_size:(i + 1) * batch_size] + + state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" + processed = processing.process_images(p) + + if initial_info is None: + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + result_images.append(combined_image) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) + + processed = Processed(p, result_images, seed, initial_info) + + return processed diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..53511b121a2c401f6980d9a3c41cabe9b1c099fe --- /dev/null +++ b/scripts/xyz_grid.py @@ -0,0 +1,620 @@ +from collections import namedtuple +from copy import copy +from itertools import permutations, chain +import random +import csv +from io import StringIO +from PIL import Image +import numpy as np + +import modules.scripts as scripts +import gradio as gr + +from modules import images, paths, sd_samplers, processing, sd_models, sd_vae +from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img +from modules.shared import opts, cmd_opts, state +import modules.shared as shared +import 
modules.sd_samplers +import modules.sd_models +import modules.sd_vae +import glob +import os +import re + +from modules.ui_components import ToolButton + +fill_values_symbol = "\U0001f4d2" # 📒 + +AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) + + +def apply_field(field): + def fun(p, x, xs): + setattr(p, field, x) + + return fun + + +def apply_prompt(p, x, xs): + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") + + p.prompt = p.prompt.replace(xs[0], x) + p.negative_prompt = p.negative_prompt.replace(xs[0], x) + + +def apply_order(p, x, xs): + token_order = [] + + # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen + for token in x: + token_order.append((p.prompt.find(token), token)) + + token_order.sort(key=lambda t: t[0]) + + prompt_parts = [] + + # Split the prompt up, taking out the tokens + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + + # Rebuild the prompt with the tokens in the order we want + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + + +def apply_sampler(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + raise RuntimeError(f"Unknown sampler: {x}") + + p.sampler_name = sampler_name + + +def confirm_samplers(p, xs): + for x in xs: + if x.lower() not in sd_samplers.samplers_map: + raise RuntimeError(f"Unknown sampler: {x}") + + +def apply_checkpoint(p, x, xs): + info = modules.sd_models.get_closet_checkpoint_match(x) + if info is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + modules.sd_models.reload_model_weights(shared.sd_model, info) + + +def confirm_checkpoints(p, xs): + for x in xs: + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + + +def apply_clip_skip(p, x, xs): + opts.data["CLIP_stop_at_last_layers"] = x + + +def apply_upscale_latent_space(p, x, xs): + if x.lower().strip() != '0': + opts.data["use_scale_latent_for_hires_fix"] = True + else: + opts.data["use_scale_latent_for_hires_fix"] = False + + +def find_vae(name: str): + if name.lower() in ['auto', 'automatic']: + return modules.sd_vae.unspecified + if name.lower() == 'none': + return None + else: + choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + print(f"No VAE found for {name}; using automatic") + return modules.sd_vae.unspecified + else: + return modules.sd_vae.vae_dict[choices[0]] + + +def apply_vae(p, x, xs): + modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) + + +def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): + p.styles.extend(x.split(',')) + + +def format_value_add_label(p, opt, x): + if type(x) == float: + x = round(x, 8) + + return f"{opt.label}: {x}" + + +def format_value(p, opt, x): + if type(x) == float: + x = round(x, 8) + return x + + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + +def do_nothing(p, x, xs): + pass + + +def format_nothing(p, opt, x): + return "" + + +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + +class AxisOption: + def __init__(self, label, type, apply, format_value=format_value_add_label, 
confirm=None, cost=0.0, choices=None): + self.label = label + self.type = type + self.apply = apply + self.format_value = format_value + self.confirm = confirm + self.cost = cost + self.choices = choices + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False + + +axis_options = [ + AxisOption("Nothing", str, do_nothing, format_value=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), + AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), +] + + +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size): + hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] + title_texts = [[images.GridAnnotation(z)] for z in z_labels] + + # Temporary list of all the images that are generated to be populated into the grid. 
+ # Will be filled with empty images for any individual step that fails to process properly + image_cache = [None] * (len(xs) * len(ys) * len(zs)) + + processed_result = None + cell_mode = "P" + cell_size = (1, 1) + + state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter + + def process_cell(x, y, z, ix, iy, iz): + nonlocal image_cache, processed_result, cell_mode, cell_size + + def index(ix, iy, iz): + return ix + iy * len(xs) + iz * len(xs) * len(ys) + + state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" + + processed: Processed = cell(x, y, z) + + try: + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + processed_result.all_prompts = [processed.prompt] + processed_result.all_seeds = [processed.seed] + processed_result.infotexts = [processed.infotexts[0]] + + image_cache[index(ix, iy, iz)] = processed_image + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) + except: + image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) + + if first_axes_processed == 'x': + for ix, x in enumerate(xs): + if second_axes_processed == 'y': + for iy, y in enumerate(ys): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'y': + for iy, y in enumerate(ys): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'z': + for iz, z in enumerate(zs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + + if not processed_result: + print("Unexpected error: draw_xyz_grid failed to return even a single processed image") + return Processed(p, []) + + sub_grids = [None] * len(zs) + for i in range(len(zs)): + start_index = i * len(xs) * len(ys) + end_index = start_index + len(xs) * len(ys) + grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size) + sub_grids[i] = grid + if include_sub_grids and len(zs) > 1: + processed_result.images.insert(i+1, grid) + + sub_grid_size = sub_grids[0].size + z_grid = images.image_grid(sub_grids, rows=1) + if draw_legend: + z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) + processed_result.images[0] = z_grid + + return processed_result, sub_grids + + +class SharedSettingsStackHelper(object): + def __enter__(self): + self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers + 
self.vae = opts.sd_vae + + def __exit__(self, exc_type, exc_value, tb): + opts.data["sd_vae"] = self.vae + modules.sd_models.reload_model_weights() + modules.sd_vae.reload_vae_weights() + + opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + + +re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") +re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") + +re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") +re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + + +class Script(scripts.Script): + def title(self): + return "X/Y/Z plot" + + def ui(self, is_img2img): + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + + with gr.Row(): + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) + + with gr.Row(): + z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) + z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) + fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) + + with gr.Row(variant="compact", elem_id="axis_options"): + with gr.Column(): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + with gr.Column(): + include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) + with gr.Column(): + margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + + with gr.Row(variant="compact", elem_id="swap_axes"): + swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") + swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") + swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") + + def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values): + return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values + + xy_swap_args = [x_type, x_values, y_type, y_values] + swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) + yz_swap_args = [y_type, y_values, 
z_type, z_values] + swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) + xz_swap_args = [x_type, x_values, z_type, z_values] + swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) + + def fill(x_type): + axis = self.current_axis_options[x_type] + return ", ".join(axis.choices()) if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values]) + + def select_axis(x_type): + return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) + + x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) + y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) + + self.infotext_fields = ( + (x_type, "X Type"), + (x_values, "X Values"), + (y_type, "Y Type"), + (y_values, "Y Values"), + (z_type, "Z Type"), + (z_values, "Z Values"), + ) + + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] + + def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): + if not no_fixed_seeds: + modules.processing.fix_seed(p) + + if not opts.return_grid: + p.batch_size = 1 + + def process_axis(opt, vals): + if opt.label == 'Nothing': + return [0] + + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] + + if opt.type == int: + valslist_ext = [] + + for val in valslist: + m = re_range.fullmatch(val) + mc = re_range_count.fullmatch(val) + if m is not None: + start = int(m.group(1)) + end = int(m.group(2))+1 + step = int(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += list(range(start, end, step)) + elif mc is not None: + start = int(mc.group(1)) + end = int(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + + valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] + else: + valslist_ext.append(val) + + valslist = valslist_ext + elif opt.type == float: + valslist_ext = [] + + for val in valslist: + m = re_range_float.fullmatch(val) + mc = re_range_count_float.fullmatch(val) + if m is not None: + start = float(m.group(1)) + end = float(m.group(2)) + step = float(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += np.arange(start, end + step, step).tolist() + elif mc is not None: + start = float(mc.group(1)) + end = float(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + + valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() + else: + valslist_ext.append(val) + + valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) + + valslist = [opt.type(x) for x in valslist] + + # Confirm options are valid before starting + if opt.confirm: + opt.confirm(p, valslist) + + return valslist + + x_opt = self.current_axis_options[x_type] + xs = process_axis(x_opt, x_values) + + y_opt = self.current_axis_options[y_type] + ys = process_axis(y_opt, y_values) + + z_opt = self.current_axis_options[z_type] + zs = process_axis(z_opt, z_values) + + def fix_axis_seeds(axis_opt, axis_list): + if axis_opt.label in ['Seed', 'Var. 
seed']: + return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] + else: + return axis_list + + if not no_fixed_seeds: + xs = fix_axis_seeds(x_opt, xs) + ys = fix_axis_seeds(y_opt, ys) + zs = fix_axis_seeds(z_opt, zs) + + if x_opt.label == 'Steps': + total_steps = sum(xs) * len(ys) * len(zs) + elif y_opt.label == 'Steps': + total_steps = sum(ys) * len(xs) * len(zs) + elif z_opt.label == 'Steps': + total_steps = sum(zs) * len(xs) * len(ys) + else: + total_steps = p.steps * len(xs) * len(ys) * len(zs) + + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) * len(zs) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) * len(zs) + elif z_opt.label == "Hires steps": + total_steps += sum(zs) * len(xs) * len(ys) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) + else: + total_steps *= 2 + + total_steps *= p.n_iter + + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + plural_s = 's' if len(zs) > 1 else '' + print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) + + grid_infotext = [None] + + state.xyz_plot_x = AxisInfo(x_opt, xs) + state.xyz_plot_y = AxisInfo(y_opt, ys) + state.xyz_plot_z = AxisInfo(z_opt, zs) + + # If one of the axes is very slow to change between (like SD model + # checkpoint), then make sure it is in the outer iteration of the nested + # `for` loop. + first_axes_processed = 'x' + second_axes_processed = 'y' + if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: + first_axes_processed = 'x' + if y_opt.cost > z_opt.cost: + second_axes_processed = 'y' + else: + second_axes_processed = 'z' + elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: + first_axes_processed = 'y' + if x_opt.cost > z_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'z' + elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: + first_axes_processed = 'z' + if x_opt.cost > y_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'y' + + def cell(x, y, z): + if shared.state.interrupted: + return Processed(p, [], p.seed, "") + + pc = copy(p) + pc.styles = pc.styles[:] + x_opt.apply(pc, x, xs) + y_opt.apply(pc, y, ys) + z_opt.apply(pc, z, zs) + + res = process_images(pc) + + if grid_infotext[0] is None: + pc.extra_generation_params = copy(pc.extra_generation_params) + pc.extra_generation_params['Script'] = self.title() + + if x_opt.label != 'Nothing': + pc.extra_generation_params["X Type"] = x_opt.label + pc.extra_generation_params["X Values"] = x_values + if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) + + if y_opt.label != 'Nothing': + pc.extra_generation_params["Y Type"] = y_opt.label + pc.extra_generation_params["Y Values"] = y_values + if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) + + if z_opt.label != 'Nothing': + pc.extra_generation_params["Z Type"] = z_opt.label + pc.extra_generation_params["Z Values"] = z_values + if z_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs]) + + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) + + return res + + with SharedSettingsStackHelper(): + processed, sub_grids = draw_xyz_grid( + p, + xs=xs, + ys=ys, + zs=zs, + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + z_labels=[z_opt.format_value(p, z_opt, z) for z in zs], + cell=cell, + draw_legend=draw_legend, + include_lone_images=include_lone_images, + include_sub_grids=include_sub_grids, + first_axes_processed=first_axes_processed, + second_axes_processed=second_axes_processed, + margin_size=margin_size + ) + + if opts.grid_save and len(sub_grids) > 1: + for sub_grid in sub_grids: + images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + + if opts.grid_save: + images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + + return processed diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..05572f662d1aa817d4000e16789c70899c51c648 --- /dev/null +++ b/style.css @@ -0,0 +1,961 @@ +.container { + max-width: 100%; +} + +.token-counter{ + position: absolute; + display: inline-block; + right: 2em; + min-width: 0 !important; + width: auto; + z-index: 100; +} + +.token-counter.error span{ + box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); + border: 2px solid rgba(255,0,0,0.4) !important; +} + +.token-counter div{ + display: inline; +} + +.token-counter span{ + padding: 0.1em 0.75em; +} + +#sh{ + min-width: 2em; + min-height: 2em; + max-width: 2em; + max-height: 2em; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; + margin: 0.1em 0; + opacity: 0%; + cursor: default; +} + +.output-html p {margin: 0 0.5em;} + +.row > *, +.row > .gr-form > * { + min-width: min(120px, 100%); + flex: 1 1 0%; +} + +.performance { + font-size: 0.85em; + color: #444; +} + +.performance p{ + display: inline-block; +} + +.performance .time { + margin-right: 0; +} + +.performance .vram { +} + +#txt2img_generate, #img2img_generate { + min-height: 4.5em; +} + +@media screen and (min-width: 2500px) { + #txt2img_gallery, #img2img_gallery { + min-height: 768px; + } +} + +#txt2img_gallery img, #img2img_gallery img{ + object-fit: scale-down; +} +#txt2img_actions_column, #img2img_actions_column { + margin: 0.35rem 0.75rem 0.35rem 0; +} +#script_list { + padding: .625rem .75rem 0 .625rem; +} +.justify-center.overflow-x-scroll { + justify-content: left; +} + +.justify-center.overflow-x-scroll button:first-of-type { + margin-left: auto; +} + +.justify-center.overflow-x-scroll button:last-of-type { + margin-right: auto; +} + +[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{ + min-width: 2.3em; + height: 2.5em; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; +} + +#hidden_element{ + display: none; +} + +[id$=_seed_row], [id$=_subseed_row]{ + gap: 0.5rem; + padding: 0.6em; +} + +[id$=_subseed_show_box]{ + min-width: auto; + flex-grow: 0; +} + +[id$=_subseed_show_box] > div{ + border: 0; + height: 100%; +} + +[id$=_subseed_show]{ + min-width: auto; + flex-grow: 0; + padding: 0; +} + +[id$=_subseed_show] label{ + height: 
100%; +} + +#txt2img_actions_column, #img2img_actions_column{ + gap: 0; + margin-right: .75rem; +} + +#txt2img_tools, #img2img_tools{ + gap: 0.4em; +} + +#interrogate_col{ + min-width: 0 !important; + max-width: 8em !important; + margin-right: 1em; + gap: 0; +} +#interrogate, #deepbooru{ + margin: 0em 0.25em 0.5em 0.25em; + min-width: 8em; + max-width: 8em; +} + +#style_pos_col, #style_neg_col{ + min-width: 8em !important; +} + +#txt2img_styles_row, #img2img_styles_row{ + gap: 0.25em; + margin-top: 0.3em; +} + +#txt2img_styles_row > button, #img2img_styles_row > button{ + margin: 0; +} + +#txt2img_styles, #img2img_styles{ + padding: 0; +} + +#txt2img_styles > label > div, #img2img_styles > label > div{ + min-height: 3.2em; +} + +ul.list-none{ + max-height: 35em; + z-index: 2000; +} + +.gr-form{ + background: transparent; +} + +.my-4{ + margin-top: 0; + margin-bottom: 0; +} + +#resize_mode{ + flex: 1.5; +} + +button{ + align-self: stretch !important; +} + +.overflow-hidden, .gr-panel{ + overflow: visible !important; +} + +#x_type, #y_type{ + max-width: 10em; +} + +#txt2img_preview, #img2img_preview, #ti_preview{ + position: absolute; + width: 320px; + left: 0; + right: 0; + margin-left: auto; + margin-right: auto; + margin-top: 34px; + z-index: 100; + border: none; + border-top-left-radius: 0; + border-top-right-radius: 0; +} + +@media screen and (min-width: 768px) { + #txt2img_preview, #img2img_preview, #ti_preview { + position: absolute; + } +} + +@media screen and (max-width: 767px) { + #txt2img_preview, #img2img_preview, #ti_preview { + position: relative; + } +} + +#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{ + display: none; +} + +fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ + position: absolute; + top: -0.7em; + line-height: 1.2em; + padding: 0; + margin: 0 0.5em; + + background-color: white; + box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white; + + z-index: 300; +} + +.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ + background-color: rgb(31, 41, 55); + box-shadow: none; + border: 1px solid rgba(128, 128, 128, 0.1); + border-radius: 6px; + padding: 0.1em 0.5em; +} + +#txt2img_column_batch, #img2img_column_batch{ + min-width: min(13.5em, 100%) !important; +} + +#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ + position: relative; + border: none; + margin-right: 8em; +} + +#settings .gr-panel div.flex-col div.justify-between div{ + position: relative; + z-index: 200; +} + +#settings{ + display: block; +} + +#settings > div{ + border: none; + margin-left: 10em; +} + +#settings > div.flex-wrap{ + float: left; + display: block; + margin-left: 0; + width: 10em; +} + +#settings > div.flex-wrap button{ + display: block; + border: none; + text-align: left; +} + +#settings_result{ + height: 1.4em; + margin: 0 1.2em; +} + +input[type="range"]{ + margin: 0.5em 0 -0.3em 0; +} + +#mask_bug_info { + text-align: center; + display: block; + margin-top: -0.75em; + margin-bottom: -0.75em; +} + +#txt2img_negative_prompt, #img2img_negative_prompt{ +} + +/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */ +.transition.opacity-20 { + opacity: 1 !important; +} + +/* more gradio's garbage cleanup */ +.min-h-\[4rem\] { min-height: unset !important; } +.min-h-\[6rem\] { min-height: unset !important; } + +.progressDiv{ + position: relative; + height: 20px; + 
background: #b4c0cc; + border-radius: 3px !important; + margin-bottom: -3px; +} + +.dark .progressDiv{ + background: #424c5b; +} + +.progressDiv .progress{ + width: 0%; + height: 20px; + background: #0060df; + color: white; + font-weight: bold; + line-height: 20px; + padding: 0 8px 0 0; + text-align: right; + border-radius: 3px; + overflow: visible; + white-space: nowrap; + padding: 0 0.5em; +} + +.livePreview{ + position: absolute; + z-index: 300; + background-color: white; + margin: -4px; +} + +.dark .livePreview{ + background-color: rgb(17 24 39 / var(--tw-bg-opacity)); +} + +.livePreview img{ + position: absolute; + object-fit: contain; + width: 100%; + height: 100%; +} + +#lightboxModal{ + display: none; + position: fixed; + z-index: 1001; + padding-top: 100px; + left: 0; + top: 0; + width: 100%; + height: 100%; + overflow: auto; + background-color: rgba(20, 20, 20, 0.95); + user-select: none; + -webkit-user-select: none; +} + +.modalControls { + display: grid; + grid-template-columns: 32px 32px 32px 1fr 32px; + grid-template-areas: "zoom tile save space close"; + position: absolute; + top: 0; + left: 0; + right: 0; + padding: 16px; + gap: 16px; + background-color: rgba(0,0,0,0.2); +} + +.modalClose { + grid-area: close; +} + +.modalZoom { + grid-area: zoom; +} + +.modalSave { + grid-area: save; +} + +.modalTileImage { + grid-area: tile; +} + +.modalClose, +.modalZoom, +.modalTileImage { + color: white; + font-size: 35px; + font-weight: bold; + cursor: pointer; +} + +.modalSave { + color: white; + font-size: 28px; + margin-top: 8px; + font-weight: bold; + cursor: pointer; +} + +.modalClose:hover, +.modalClose:focus, +.modalSave:hover, +.modalSave:focus, +.modalZoom:hover, +.modalZoom:focus { + color: #999; + text-decoration: none; + cursor: pointer; +} + +#modalImage { + display: block; + margin-left: auto; + margin-right: auto; + margin-top: auto; + width: auto; +} + +.modalImageFullscreen { + object-fit: contain; + height: 90%; +} + +.modalPrev, +.modalNext { + cursor: pointer; + position: absolute; + top: 50%; + width: auto; + padding: 16px; + margin-top: -50px; + color: white; + font-weight: bold; + font-size: 20px; + transition: 0.6s ease; + border-radius: 0 3px 3px 0; + user-select: none; + -webkit-user-select: none; +} + +.modalNext { + right: 0; + border-radius: 3px 0 0 3px; +} + +.modalPrev:hover, +.modalNext:hover { + background-color: rgba(0, 0, 0, 0.8); +} + +#imageARPreview{ + position:absolute; + top:0px; + left:0px; + border:2px solid red; + background:rgba(255, 0, 0, 0.3); + z-index: 900; + pointer-events:none; + display:none +} + +#txt2img_generate_box, #img2img_generate_box{ + position: relative; +} + +#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + height: 100%; + background: #b4c0cc; + display: none; +} + +#txt2img_interrupt, #img2img_interrupt{ + left: 0; + border-radius: 0.5rem 0 0 0.5rem; +} +#txt2img_skip, #img2img_skip{ + right: 0; + border-radius: 0 0.5rem 0.5rem 0; +} + +.red { + color: red; +} + +.gallery-item { + --tw-bg-opacity: 0 !important; +} + +#context-menu{ + z-index:9999; + position:absolute; + display:block; + padding:0px 0; + border:2px solid #a55000; + border-radius:8px; + box-shadow:1px 1px 2px #CE6400; + width: 200px; +} + +.context-menu-items{ + list-style: none; + margin: 0; + padding: 0; +} + +.context-menu-items a{ + display:block; + padding:5px; + cursor:pointer; +} + +.context-menu-items a:hover{ + background: #a55000; +} + +#quicksettings { + width: fit-content; +} + 
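+/* quicksettings bar: the rules below give each quicksettings control a fixed 24em footprint and reposition its label */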
+#quicksettings > div, #quicksettings > fieldset{ + max-width: 24em; + min-width: 24em; + padding: 0; + border: none; + box-shadow: none; + background: none; + margin-right: 10px; +} + +#quicksettings > div > div > div > label > span { + position: relative; + margin-right: 9em; + margin-bottom: -1em; +} + +canvas[key="mask"] { + z-index: 12 !important; + filter: invert(); + mix-blend-mode: multiply; + pointer-events: none; +} + + +/* gradio 3.4.1 stuff for editable scrollbar values */ +.gr-box > div > div > input.gr-text-input{ + position: absolute; + right: 0.5em; + top: -0.6em; + z-index: 400; + width: 6em; +} +#quicksettings .gr-box > div > div > input.gr-text-input { + top: -1.12em; +} + +.row.gr-compact{ + overflow: visible; +} + +#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, +#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img, +#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img, +#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img +{ + height: 480px !important; + max-height: 480px !important; + min-height: 480px !important; +} + +/* Extensions */ + +#tab_extensions table{ + border-collapse: collapse; +} + +#tab_extensions table td, #tab_extensions table th{ + border: 1px solid #ccc; + padding: 0.25em 0.5em; +} + +#tab_extensions table input[type="checkbox"]{ + margin-right: 0.5em; +} + +#tab_extensions button{ + max-width: 16em; +} + +#tab_extensions input[disabled="disabled"]{ + opacity: 0.5; +} + +.extension-tag{ + font-weight: bold; + font-size: 95%; +} + +#available_extensions .info{ + margin: 0; +} + +#available_extensions .date_added{ + opacity: 0.85; + font-size: 90%; +} + +#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{ + min-width: auto; + padding-left: 0.5em; + padding-right: 0.5em; +} + +.gr-form{ + background-color: white; +} + +.dark .gr-form{ + background-color: rgb(31 41 55 / var(--tw-bg-opacity)); +} + +.gr-button-tool, .gr-button-tool-top{ + max-width: 2.5em; + min-width: 2.5em !important; + height: 2.4em; +} + +.gr-button-tool{ + margin: 0.6em 0em 0.55em 0; +} + +.gr-button-tool-top, #settings .gr-button-tool{ + margin: 1.6em 0.7em 0.55em 0; +} + + +#modelmerger_results_container{ + margin-top: 1em; + overflow: visible; +} + +#modelmerger_models{ + gap: 0; +} + + +#quicksettings .gr-button-tool{ + margin: 0; + border-color: unset; + background-color: unset; +} + +#modelmerger_interp_description>p { + margin: 0!important; + text-align: center; +} +#modelmerger_interp_description { + margin: 0.35rem 0.75rem 1.23rem; +} +#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form { + padding-top: 0.9em; + padding-bottom: 0.9em; +} +#txt2img_settings { + padding-top: 1.16em; + padding-bottom: 0.9em; +} +#img2img_settings { + padding-bottom: 0.9em; +} + +#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{ + border: none; + padding-bottom: 0.5em; +} + +footer { + display: none !important; +} + +#footer{ + text-align: center; +} + +#footer div{ + display: inline-block; +} + +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + +#txtimg_hr_finalres{ + min-height: 0 !important; + padding: .625rem .75rem; + margin-left: -0.75em + +} + +#txtimg_hr_finalres .resolution{ + font-weight: bold; +} + +#txt2img_checkboxes, #img2img_checkboxes{ + 
margin-bottom: 0.5em; + margin-left: 0em; +} +#txt2img_checkboxes > div, #img2img_checkboxes > div{ + flex: 0; + white-space: nowrap; + min-width: auto; +} + +#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ + margin-left: 0em; +} + +#axis_options { + margin-left: 0em; +} + +.inactive{ + opacity: 0.5; +} + +[id*='_prompt_container']{ + gap: 0; +} + +[id*='_prompt_container'] > div{ + margin: -0.4em 0 0 0; +} + +.gr-compact { + border: none; +} + +.dark .gr-compact{ + background-color: rgb(31 41 55 / var(--tw-bg-opacity)); + margin-left: 0; +} + +.gr-compact{ + overflow: visible; +} + +.gr-compact > *{ +} + +.gr-compact .gr-block, .gr-compact .gr-form{ + border: none; + box-shadow: none; +} + +.gr-compact .gr-box{ + border-radius: .5rem !important; + border-width: 1px !important; +} + +#mode_img2img > div > div{ + gap: 0 !important; +} + +[id*='img2img_copy_to_'] { + border: none; +} + +[id*='img2img_copy_to_'] > button { +} + +[id*='img2img_label_copy_to_'] { + font-size: 1.0em; + font-weight: bold; + text-align: center; + line-height: 2.4em; +} + +.extra-networks > div > [id *= '_extra_']{ + margin: 0.3em; +} + +.extra-network-subdirs{ + padding: 0.2em 0.35em; +} + +.extra-network-subdirs button{ + margin: 0 0.15em; +} + +#txt2img_extra_networks .search, #img2img_extra_networks .search{ + display: inline-block; + max-width: 16em; + margin: 0.3em; + align-self: center; +} + +#txt2img_extra_view, #img2img_extra_view { + width: auto; +} + +.extra-network-cards .nocards, .extra-network-thumbs .nocards{ + margin: 1.25em 0.5em 0.5em 0.5em; +} + +.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{ + font-size: 1.5em; + margin-bottom: 1em; +} + +.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{ + margin-left: 0.5em; +} + +.extra-network-thumbs { + display: flex; + flex-flow: row wrap; + gap: 10px; +} + +.extra-network-thumbs .card { + height: 6em; + width: 6em; + cursor: pointer; + background-image: url('./file=html/card-no-preview.png'); + background-size: cover; + background-position: center center; + position: relative; +} + +.extra-network-thumbs .card:hover .additional a { + display: block; +} + +.extra-network-thumbs .actions .additional a { + background-image: url('./file=html/image-update.svg'); + background-repeat: no-repeat; + background-size: cover; + background-position: center center; + position: absolute; + top: 0; + left: 0; + width: 24px; + height: 24px; + display: none; + font-size: 0; + text-align: -9999; +} + +.extra-network-thumbs .actions .name { + position: absolute; + bottom: 0; + font-size: 10px; + padding: 3px; + width: 100%; + overflow: hidden; + white-space: nowrap; + text-overflow: ellipsis; + background: rgba(0,0,0,.5); + color: white; +} + +.extra-network-thumbs .card:hover .actions .name { + white-space: normal; + word-break: break-all; +} + +.extra-network-cards .card{ + display: inline-block; + margin: 0.5em; + width: 16em; + height: 24em; + box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); + border-radius: 0.2em; + position: relative; + + background-size: auto 100%; + background-position: center; + overflow: hidden; + cursor: pointer; + + background-image: url('./file=html/card-no-preview.png') +} + +.extra-network-cards .card:hover{ + box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); +} + +.extra-network-cards .card .actions .additional{ + display: none; +} + +.extra-network-cards .card .actions{ + position: absolute; + bottom: 0; + left: 0; + right: 0; + padding: 
0.5em; + color: white; + background: rgba(0,0,0,0.5); + box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5); + text-shadow: 0 0 0.2em black; +} + +.extra-network-cards .card .actions:hover{ + box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; +} + +.extra-network-cards .card .actions .name{ + font-size: 1.7em; + font-weight: bold; + line-break: anywhere; +} + +.extra-network-cards .card .actions:hover .additional{ + display: block; +} + +.extra-network-cards .card ul{ + margin: 0.25em 0 0.75em 0.25em; + cursor: unset; +} + +.extra-network-cards .card ul a{ + cursor: pointer; +} + +.extra-network-cards .card ul a:hover{ + color: red; +} + +[id*='_prompt_container'] > div { + margin: 0!important; +} diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/basic_features/__init__.py b/test/basic_features/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0170c511fe54cc6bcf49ec7f75ca7c747de41db5 --- /dev/null +++ b/test/basic_features/extras_test.py @@ -0,0 +1,54 @@ +import unittest +import requests +from gradio.processing_utils import encode_pil_to_base64 +from PIL import Image + +class TestExtrasWorking(unittest.TestCase): + def setUp(self): + self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image" + self.extras_single = { + "resize_mode": 0, + "show_extras_results": True, + "gfpgan_visibility": 0, + "codeformer_visibility": 0, + "codeformer_weight": 0, + "upscaling_resize": 2, + "upscaling_resize_w": 128, + "upscaling_resize_h": 128, + "upscaling_crop": True, + "upscaler_1": "None", + "upscaler_2": "None", + "extras_upscaler_2_visibility": 0, + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + } + + def test_simple_upscaling_performed(self): + self.extras_single["upscaler_1"] = "Lanczos" + self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200) + + +class TestPngInfoWorking(unittest.TestCase): + def setUp(self): + self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" + self.png_info = { + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + } + + def test_png_info_performed(self): + self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200) + + +class TestInterrogateWorking(unittest.TestCase): + def setUp(self): + self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" + self.interrogate = { + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")), + "model": "clip" + } + + def test_interrogate_performed(self): + self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py new file mode 100644 index 0000000000000000000000000000000000000000..08c5c903e8382ef4b969b01da87bc69fb06ff2b4 --- /dev/null +++ b/test/basic_features/img2img_test.py @@ -0,0 +1,66 @@ +import unittest +import requests +from gradio.processing_utils import encode_pil_to_base64 +from PIL import Image + + +class TestImg2ImgWorking(unittest.TestCase): + def setUp(self): + self.url_img2img = 
"http://localhost:7860/sdapi/v1/img2img" + self.simple_img2img = { + "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], + "resize_mode": 0, + "denoising_strength": 0.75, + "mask": None, + "mask_blur": 4, + "inpainting_fill": 0, + "inpaint_full_res": False, + "inpaint_full_res_padding": 0, + "inpainting_mask_invert": False, + "prompt": "example prompt", + "styles": [], + "seed": -1, + "subseed": -1, + "subseed_strength": 0, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "batch_size": 1, + "n_iter": 1, + "steps": 3, + "cfg_scale": 7, + "width": 64, + "height": 64, + "restore_faces": False, + "tiling": False, + "negative_prompt": "", + "eta": 0, + "s_churn": 0, + "s_tmax": 0, + "s_tmin": 0, + "s_noise": 1, + "override_settings": {}, + "sampler_index": "Euler a", + "include_init_images": False + } + + def test_img2img_simple_performed(self): + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + + def test_inpainting_masked_performed(self): + self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + + def test_inpainting_with_inverted_masked_performed(self): + self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) + self.simple_img2img["inpainting_mask_invert"] = True + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + + def test_img2img_sd_upscale_performed(self): + self.simple_img2img["script_name"] = "sd upscale" + self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] + + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa43a44a3e3818d7220a98acf5a6a504cdca3e3 --- /dev/null +++ b/test/basic_features/txt2img_test.py @@ -0,0 +1,80 @@ +import unittest +import requests + + +class TestTxt2ImgWorking(unittest.TestCase): + def setUp(self): + self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" + self.simple_txt2img = { + "enable_hr": False, + "denoising_strength": 0, + "firstphase_width": 0, + "firstphase_height": 0, + "prompt": "example prompt", + "styles": [], + "seed": -1, + "subseed": -1, + "subseed_strength": 0, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "batch_size": 1, + "n_iter": 1, + "steps": 3, + "cfg_scale": 7, + "width": 64, + "height": 64, + "restore_faces": False, + "tiling": False, + "negative_prompt": "", + "eta": 0, + "s_churn": 0, + "s_tmax": 0, + "s_tmin": 0, + "s_noise": 1, + "sampler_index": "Euler a" + } + + def test_txt2img_simple_performed(self): + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_with_negative_prompt_performed(self): + self.simple_txt2img["negative_prompt"] = "example negative prompt" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_with_complex_prompt_performed(self): + self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_not_square_image_performed(self): + 
self.simple_txt2img["height"] = 128 + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_with_hrfix_performed(self): + self.simple_txt2img["enable_hr"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_with_tiling_performed(self): + self.simple_txt2img["tiling"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_with_restore_faces_performed(self): + self.simple_txt2img["restore_faces"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_with_vanilla_sampler_performed(self): + self.simple_txt2img["sampler_index"] = "PLMS" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + self.simple_txt2img["sampler_index"] = "DDIM" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_multiple_batches_performed(self): + self.simple_txt2img["n_iter"] = 2 + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + def test_txt2img_batch_performed(self): + self.simple_txt2img["batch_size"] = 2 + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0bfc28a0d30c070c292ff8154e9b93a74abecb85 --- /dev/null +++ b/test/basic_features/utils_test.py @@ -0,0 +1,62 @@ +import unittest +import requests + +class UtilsTests(unittest.TestCase): + def setUp(self): + self.url_options = "http://localhost:7860/sdapi/v1/options" + self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" + self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" + self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" + self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" + self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" + self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" + self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" + self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" + self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" + + def test_options_get(self): + self.assertEqual(requests.get(self.url_options).status_code, 200) + + def test_options_write(self): + response = requests.get(self.url_options) + self.assertEqual(response.status_code, 200) + + pre_value = response.json()["send_seed"] + + self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) + + response = requests.get(self.url_options) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["send_seed"], not pre_value) + + requests.post(self.url_options, json={"send_seed": pre_value}) + + def test_cmd_flags(self): + self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) + + def test_samplers(self): + self.assertEqual(requests.get(self.url_samplers).status_code, 200) + + def test_upscalers(self): + self.assertEqual(requests.get(self.url_upscalers).status_code, 200) + + def test_sd_models(self): + self.assertEqual(requests.get(self.url_sd_models).status_code, 200) + + def test_hypernetworks(self): + 
self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200) + + def test_face_restorers(self): + self.assertEqual(requests.get(self.url_face_restorers).status_code, 200) + + def test_realesrgan_models(self): + self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200) + + def test_prompt_styles(self): + self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200) + + def test_embeddings(self): + self.assertEqual(requests.get(self.url_embeddings).status_code, 200) + +if __name__ == "__main__": + unittest.main() diff --git a/test/server_poll.py b/test/server_poll.py new file mode 100644 index 0000000000000000000000000000000000000000..42d56a4caacfc40d686dc99668d72238392448cd --- /dev/null +++ b/test/server_poll.py @@ -0,0 +1,24 @@ +import unittest +import requests +import time + + +def run_tests(proc, test_dir): + timeout_threshold = 240 + start_time = time.time() + while time.time()-start_time < timeout_threshold: + try: + requests.head("http://localhost:7860/") + break + except requests.exceptions.ConnectionError: + if proc.poll() is not None: + break + if proc.poll() is None: + if test_dir is None: + test_dir = "test" + suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") + result = unittest.TextTestRunner(verbosity=2).run(suite) + return len(result.failures) + len(result.errors) + else: + print("Launch unsuccessful") + return 1 diff --git a/test/test_files/empty.pt b/test/test_files/empty.pt new file mode 100644 index 0000000000000000000000000000000000000000..72f13b63cc3353c1c7cdc27ee23a2ad242cfdfba --- /dev/null +++ b/test/test_files/empty.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d030ad8db708280fcae77d87e973102039acd23a11bdecc3db8eb6c0ac940ee1 +size 431 diff --git a/test/test_files/img2img_basic.png b/test/test_files/img2img_basic.png new file mode 100644 index 0000000000000000000000000000000000000000..b5c1c9ad35162dbc6ebed5aa46682e1b065b2739 --- /dev/null +++ b/test/test_files/img2img_basic.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:401e359065bac7d28ec38bdb8e1c63667879941d0a46b67a3aaa07b0bd88fb7a +size 9932 diff --git a/test/test_files/mask_basic.png b/test/test_files/mask_basic.png new file mode 100644 index 0000000000000000000000000000000000000000..0500048c96868010a723fffe6f50673737991a8a --- /dev/null +++ b/test/test_files/mask_basic.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f6f760130a1712dbf0312d9315ff9a07efd00f2816f1e9b69d3f0e85b68315d +size 362 diff --git a/textual_inversion_templates/hypernetwork.txt b/textual_inversion_templates/hypernetwork.txt new file mode 100644 index 0000000000000000000000000000000000000000..91e06890571c7e4974d5a76c30fab62e8587c7d2 --- /dev/null +++ b/textual_inversion_templates/hypernetwork.txt @@ -0,0 +1,27 @@ +a photo of a [filewords] +a rendering of a [filewords] +a cropped photo of the [filewords] +the photo of a [filewords] +a photo of a clean [filewords] +a photo of a dirty [filewords] +a dark photo of the [filewords] +a photo of my [filewords] +a photo of the cool [filewords] +a close-up photo of a [filewords] +a bright photo of the [filewords] +a cropped photo of a [filewords] +a photo of the [filewords] +a good photo of the [filewords] +a photo of one [filewords] +a close-up photo of the [filewords] +a rendition of the [filewords] +a photo of the clean [filewords] +a rendition of a [filewords] +a photo of a nice [filewords] +a good photo of a [filewords] +a photo of the 
nice [filewords] +a photo of the small [filewords] +a photo of the weird [filewords] +a photo of the large [filewords] +a photo of a cool [filewords] +a photo of a small [filewords] diff --git a/textual_inversion_templates/none.txt b/textual_inversion_templates/none.txt new file mode 100644 index 0000000000000000000000000000000000000000..f77af4612b289a56b718c3bee62c66a6151f75be --- /dev/null +++ b/textual_inversion_templates/none.txt @@ -0,0 +1 @@ +picture diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt new file mode 100644 index 0000000000000000000000000000000000000000..15af2d6b85f259d0bf41fbe0c8ca7a3340e1b259 --- /dev/null +++ b/textual_inversion_templates/style.txt @@ -0,0 +1,19 @@ +a painting, art by [name] +a rendering, art by [name] +a cropped painting, art by [name] +the painting, art by [name] +a clean painting, art by [name] +a dirty painting, art by [name] +a dark painting, art by [name] +a picture, art by [name] +a cool painting, art by [name] +a close-up painting, art by [name] +a bright painting, art by [name] +a cropped painting, art by [name] +a good painting, art by [name] +a close-up painting, art by [name] +a rendition, art by [name] +a nice painting, art by [name] +a small painting, art by [name] +a weird painting, art by [name] +a large painting, art by [name] diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt new file mode 100644 index 0000000000000000000000000000000000000000..b3a8159a869a7890bdd42470664fadf015e0658d --- /dev/null +++ b/textual_inversion_templates/style_filewords.txt @@ -0,0 +1,19 @@ +a painting of [filewords], art by [name] +a rendering of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +the painting of [filewords], art by [name] +a clean painting of [filewords], art by [name] +a dirty painting of [filewords], art by [name] +a dark painting of [filewords], art by [name] +a picture of [filewords], art by [name] +a cool painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a bright painting of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +a good painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a rendition of [filewords], art by [name] +a nice painting of [filewords], art by [name] +a small painting of [filewords], art by [name] +a weird painting of [filewords], art by [name] +a large painting of [filewords], art by [name] diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt new file mode 100644 index 0000000000000000000000000000000000000000..79f36aa0543fc2151b7f7e28725309c0c9a4912a --- /dev/null +++ b/textual_inversion_templates/subject.txt @@ -0,0 +1,27 @@ +a photo of a [name] +a rendering of a [name] +a cropped photo of the [name] +the photo of a [name] +a photo of a clean [name] +a photo of a dirty [name] +a dark photo of the [name] +a photo of my [name] +a photo of the cool [name] +a close-up photo of a [name] +a bright photo of the [name] +a cropped photo of a [name] +a photo of the [name] +a good photo of the [name] +a photo of one [name] +a close-up photo of the [name] +a rendition of the [name] +a photo of the clean [name] +a rendition of a [name] +a photo of a nice [name] +a good photo of a [name] +a photo of the nice [name] +a photo of the small [name] +a photo of the weird [name] +a photo of the large [name] +a photo of a cool [name] +a photo of a small [name] 
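The template files above are consumed during textual inversion and hypernetwork training: each line is a candidate training prompt, with [name] replaced by the embedding's or hypernetwork's name and [filewords] replaced by the caption stored alongside each training image. As a rough illustration only (the helper below is hypothetical and not the repo's actual API; the real substitution happens inside the training dataset code), the expansion amounts to:

import random

def expand_template(template_path: str, name: str, filewords: str = "") -> str:
    # Pick one prompt template at random and fill in the placeholders.
    with open(template_path, encoding="utf8") as f:
        lines = [line.strip() for line in f if line.strip()]
    return random.choice(lines).replace("[name]", name).replace("[filewords]", filewords)

# e.g. "a cropped photo of the my-embedding"
print(expand_template("textual_inversion_templates/subject.txt", name="my-embedding"))

Templates such as subject.txt use only [name], while the *_filewords.txt variants combine the placeholder with the per-image captions.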
diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt new file mode 100644 index 0000000000000000000000000000000000000000..008652a6bf4277f12a1759f5f3c815ae754dcfcf --- /dev/null +++ b/textual_inversion_templates/subject_filewords.txt @@ -0,0 +1,27 @@ +a photo of a [name], [filewords] +a rendering of a [name], [filewords] +a cropped photo of the [name], [filewords] +the photo of a [name], [filewords] +a photo of a clean [name], [filewords] +a photo of a dirty [name], [filewords] +a dark photo of the [name], [filewords] +a photo of my [name], [filewords] +a photo of the cool [name], [filewords] +a close-up photo of a [name], [filewords] +a bright photo of the [name], [filewords] +a cropped photo of a [name], [filewords] +a photo of the [name], [filewords] +a good photo of the [name], [filewords] +a photo of one [name], [filewords] +a close-up photo of the [name], [filewords] +a rendition of the [name], [filewords] +a photo of the clean [name], [filewords] +a rendition of a [name], [filewords] +a photo of a nice [name], [filewords] +a good photo of a [name], [filewords] +a photo of the nice [name], [filewords] +a photo of the small [name], [filewords] +a photo of the weird [name], [filewords] +a photo of the large [name], [filewords] +a photo of a cool [name], [filewords] +a photo of a small [name], [filewords] diff --git a/webui-macos-env.sh b/webui-macos-env.sh new file mode 100644 index 0000000000000000000000000000000000000000..37cac4fb02c5ea53fa270336729b90ac2ef3ac74 --- /dev/null +++ b/webui-macos-env.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#################################################################### +# macOS defaults # +# Please modify webui-user.sh to change these instead of this file # +#################################################################### + +if [[ -x "$(command -v python3.10)" ]] +then + python_cmd="python3.10" +fi + +export install_dir="$HOME" +export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" +export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" +export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" +export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" +export PYTORCH_ENABLE_MPS_FALLBACK=1 + +#################################################################### diff --git a/webui-user.bat b/webui-user.bat new file mode 100644 index 0000000000000000000000000000000000000000..e5a257bef06f5bfcaff1c8b33c64a767eb8b3fe5 --- /dev/null +++ b/webui-user.bat @@ -0,0 +1,8 @@ +@echo off + +set PYTHON= +set GIT= +set VENV_DIR= +set COMMANDLINE_ARGS= + +call webui.bat diff --git a/webui-user.sh b/webui-user.sh new file mode 100644 index 0000000000000000000000000000000000000000..bfa53cb7c67083ec0a01bfa420269af4d85c6c94 --- /dev/null +++ b/webui-user.sh @@ -0,0 +1,46 @@ +#!/bin/bash +######################################################### +# Uncomment and change the variables below to your need:# +######################################################### + +# Install directory without trailing slash +#install_dir="/home/$(whoami)" + +# Name of the subdirectory +#clone_dir="stable-diffusion-webui" + +# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" +#export COMMANDLINE_ARGS="" + +# python3 executable +#python_cmd="python3" + +# git executable +#export GIT="git" + +# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 
+#venv_dir="venv" + +# script to launch to start the app +#export LAUNCH_SCRIPT="launch.py" + +# install command for torch +#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" + +# Requirements file to use for stable-diffusion-webui +#export REQS_FILE="requirements_versions.txt" + +# Fixed git repos +#export K_DIFFUSION_PACKAGE="" +#export GFPGAN_PACKAGE="" + +# Fixed git commits +#export STABLE_DIFFUSION_COMMIT_HASH="" +#export TAMING_TRANSFORMERS_COMMIT_HASH="" +#export CODEFORMER_COMMIT_HASH="" +#export BLIP_COMMIT_HASH="" + +# Uncomment to enable accelerated launch +#export ACCELERATE="True" + +########################################### diff --git a/webui.bat b/webui.bat new file mode 100644 index 0000000000000000000000000000000000000000..209d972bd755293bc16b4c8b6fe0e5108b3244b3 --- /dev/null +++ b/webui.bat @@ -0,0 +1,85 @@ +@echo off + +if not defined PYTHON (set PYTHON=python) +if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") + + +set ERROR_REPORTING=FALSE + +mkdir tmp 2>NUL + +%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :check_pip +echo Couldn't launch python +goto :show_stdout_stderr + +:check_pip +%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :start_venv +if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr +%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :start_venv +echo Couldn't install pip +goto :show_stdout_stderr + +:start_venv +if ["%VENV_DIR%"] == ["-"] goto :skip_venv +if ["%SKIP_VENV%"] == ["1"] goto :skip_venv + +dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :activate_venv + +for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" +echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% +%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :activate_venv +echo Unable to create venv in directory "%VENV_DIR%" +goto :show_stdout_stderr + +:activate_venv +set PYTHON="%VENV_DIR%\Scripts\Python.exe" +echo venv %PYTHON% + +:skip_venv +if [%ACCELERATE%] == ["True"] goto :accelerate +goto :launch + +:accelerate +echo Checking for accelerate +set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" +if EXIST %ACCELERATE% goto :accelerate_launch + +:launch +%PYTHON% launch.py %* +pause +exit /b + +:accelerate_launch +echo Accelerating +%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py +pause +exit /b + +:show_stdout_stderr + +echo. +echo exit code: %errorlevel% + +for /f %%i in ("tmp\stdout.txt") do set size=%%~zi +if %size% equ 0 goto :show_stderr +echo. +echo stdout: +type tmp\stdout.txt + +:show_stderr +for /f %%i in ("tmp\stderr.txt") do set size=%%~zi +if %size% equ 0 goto :show_stderr +echo. +echo stderr: +type tmp\stderr.txt + +:endofscript + +echo. +echo Launch unsuccessful. Exiting. 
+pause diff --git a/webui.py b/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8b486aa730b9c297d009f0cc4a58f8103a1d0e --- /dev/null +++ b/webui.py @@ -0,0 +1,286 @@ +import os +import sys +import time +import importlib +import signal +import re +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from packaging import version + +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call + +import torch + +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__long_version__ = torch.__version__ + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks +import modules.codeformer_model as codeformer +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img + +import modules.lowvram +import modules.paths +import modules.scripts +import modules.sd_hijack +import modules.sd_models +import modules.sd_vae +import modules.txt2img +import modules.script_callbacks +import modules.textual_inversion.textual_inversion +import modules.progress + +import modules.ui +from modules import modelloader +from modules.shared import cmd_opts +import modules.hypernetworks.hypernetwork + + +if cmd_opts.server_name: + server_name = cmd_opts.server_name +else: + server_name = "0.0.0.0" if cmd_opts.listen else None + + +def check_versions(): + if shared.cmd_opts.skip_version_check: + return + + expected_torch_version = "1.13.1" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + errors.print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + expected_xformers_version = "0.0.16rc425" + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + errors.print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. 
+ """.strip()) + + +def initialize(): + check_versions() + + extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) + + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + modules.scripts.load_scripts() + return + + modelloader.cleanup_models() + modules.sd_models.setup_model() + codeformer.setup_model(cmd_opts.codeformer_models_path) + gfpgan.setup_model(cmd_opts.gfpgan_models_path) + + modelloader.list_builtin_upscalers() + modules.scripts.load_scripts() + modelloader.load_upscalers() + + modules.sd_vae.refresh_vae_list() + + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) + + +def setup_cors(app): + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + + +def create_api(app): + from modules.api.api import Api + api = Api(app, queue_lock) + return api + + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if shared.state.need_restart: + shared.state.need_restart = False + time.sleep(0.5) + demo.close() 
+ time.sleep(0.5) + break + + +def api_only(): + initialize() + + app = FastAPI() + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + modules.script_callbacks.app_started_callback(None, app) + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + + +def webui(): + launch_api = cmd_opts.api + initialize() + + while 1: + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() + + modules.script_callbacks.before_ui_callback() + + shared.demo = modules.ui.create_ui() + + if cmd_opts.gradio_queue: + shared.demo.queue(64) + + gradio_auth_creds = [] + if cmd_opts.gradio_auth: + gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') + if cmd_opts.gradio_auth_path: + with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(',')] + + app, local_url, share_url = shared.demo.launch( + share=cmd_opts.share, + server_name=server_name, + server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + # after initial launch, disable --autolaunch for subsequent restarts + cmd_opts.autolaunch = False + + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for + # an attacker to trick the user into opening a malicious HTML page, which makes a request to the + # running web ui and do whatever the attacker wants, including installing an extension and + # running its code. We disable this here. Suggested by RyotaK. 
+ app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + + setup_cors(app) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + modules.progress.setup_progress_api(app) + + if launch_api: + create_api(app) + + ui_extra_networks.add_pages_to_demo(app) + + modules.script_callbacks.app_started_callback(shared.demo, app) + + wait_on_server(shared.demo) + print('Restarting UI...') + + sd_samplers.set_samplers() + + modules.script_callbacks.script_unloaded_callback() + extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + + modelloader.forbid_loaded_nonbuiltin_upscalers() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + + modules.sd_models.list_models() + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + +if __name__ == "__main__": + if cmd_opts.nowebui: + api_only() + else: + webui() diff --git a/webui.sh b/webui.sh new file mode 100755 index 0000000000000000000000000000000000000000..8cdad22d310fed20f229b09d7a3160aeb1731a85 --- /dev/null +++ b/webui.sh @@ -0,0 +1,186 @@ +#!/usr/bin/env bash +################################################# +# Please do not make any changes to this file, # +# change the variables in webui-user.sh instead # +################################################# + +# If run from macOS, load defaults from webui-macos-env.sh +if [[ "$OSTYPE" == "darwin"* ]]; then + if [[ -f webui-macos-env.sh ]] + then + source ./webui-macos-env.sh + fi +fi + +# Read variables from webui-user.sh +# shellcheck source=/dev/null +if [[ -f webui-user.sh ]] +then + source ./webui-user.sh +fi + +# Set defaults +# Install directory without trailing slash +if [[ -z "${install_dir}" ]] +then + install_dir="/home/$(whoami)" +fi + +# Name of the subdirectory (defaults to stable-diffusion-webui) +if [[ -z "${clone_dir}" ]] +then + clone_dir="stable-diffusion-webui" +fi + +# python3 executable +if [[ -z "${python_cmd}" ]] +then + python_cmd="python3" +fi + +# git executable +if [[ -z "${GIT}" ]] +then + export GIT="git" +fi + +# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) +if [[ -z "${venv_dir}" ]] +then + venv_dir="venv" +fi + +if [[ -z "${LAUNCH_SCRIPT}" ]] +then + LAUNCH_SCRIPT="launch.py" +fi + +# this script cannot be run as root by default +can_run_as_root=0 + +# read any command line flags to the webui.sh script +while getopts "f" flag > /dev/null 2>&1 +do + case ${flag} in + f) can_run_as_root=1;; + *) break;; + esac +done + +# Disable sentry logging +export ERROR_REPORTING=FALSE + +# Do not reinstall existing pip packages on Debian/Ubuntu +export PIP_IGNORE_INSTALLED=0 + +# Pretty print +delimiter="################################################################" + +printf "\n%s\n" "${delimiter}" +printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" +printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" 
+printf "\n%s\n" "${delimiter}" + +# Do not run as root +if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]] +then + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" + printf "\n%s\n" "${delimiter}" + exit 1 +else + printf "\n%s\n" "${delimiter}" + printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)" + printf "\n%s\n" "${delimiter}" +fi + +if [[ -d .git ]] +then + printf "\n%s\n" "${delimiter}" + printf "Repo already cloned, using it as install directory" + printf "\n%s\n" "${delimiter}" + install_dir="${PWD}/../" + clone_dir="${PWD##*/}" +fi + +# Check prerequisites +gpu_info=$(lspci 2>/dev/null | grep VGA) +case "$gpu_info" in + *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 + ;; + *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 + printf "\n%s\n" "${delimiter}" + printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" + printf "\n%s\n" "${delimiter}" + ;; + *) + ;; +esac +if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +then + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" +fi + +for preq in "${GIT}" "${python_cmd}" +do + if ! hash "${preq}" &>/dev/null + then + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}" + printf "\n%s\n" "${delimiter}" + exit 1 + fi +done + +if ! "${python_cmd}" -c "import venv" &>/dev/null +then + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" + printf "\n%s\n" "${delimiter}" + exit 1 +fi + +cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; } +if [[ -d "${clone_dir}" ]] +then + cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } +else + printf "\n%s\n" "${delimiter}" + printf "Clone stable-diffusion-webui" + printf "\n%s\n" "${delimiter}" + "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}" + cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } +fi + +printf "\n%s\n" "${delimiter}" +printf "Create and activate python venv" +printf "\n%s\n" "${delimiter}" +cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } +if [[ ! -d "${venv_dir}" ]] +then + "${python_cmd}" -m venv "${venv_dir}" + first_launch=1 +fi +# shellcheck source=/dev/null +if [[ -f "${venv_dir}"/bin/activate ]] +then + source "${venv_dir}"/bin/activate +else + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" + printf "\n%s\n" "${delimiter}" + exit 1 +fi + +if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] +then + printf "\n%s\n" "${delimiter}" + printf "Accelerating launch.py..." + printf "\n%s\n" "${delimiter}" + exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" +else + printf "\n%s\n" "${delimiter}" + printf "Launching launch.py..." + printf "\n%s\n" "${delimiter}" + exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" +fi