diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..fbcb3f4364cf73b552c0f06abea77439cd3a7b7e --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +# This is a sample Dockerfile that builds a runtime container and runs the sample Gradio app. +# Note: you must pass in the pretrained models when you run the container. + +FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 + +WORKDIR /workspace + +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python-is-python3 \ + python3-pip \ + python3.10-venv \ + libgl1 \ + libgl1-mesa-glx \ + libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +ADD requirements.txt . + +RUN pip install -r requirements.txt + +ADD . . + +RUN chmod a+x docker-entrypoint.sh + +ENV DEMO_PORT=12345 +ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version.
+ + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. 
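For orientation before the demo code that follows: app/app_pixart_dmd.py wraps a one-step PixArt-DMD pipeline in a Gradio UI. Below is a minimal, UI-free sketch of the same inference call, assuming this repo's scripts/diffusers_patches.py is on the import path, a CUDA GPU is available, and the patched __call__ accepts the same arguments the app passes; the prompt and seed here are arbitrary placeholders.

```python
# Minimal sketch of the one-step PixArt-DMD inference that app/app_pixart_dmd.py wires into Gradio.
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel, DDPMScheduler
from scripts.diffusers_patches import pipeline_pixart_alpha_call  # patch shipped in this repo

model_path = "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512"

# Text encoder, tokenizer and VAE come from the 1024-MS pipeline; the transformer
# and scheduler are swapped for the DMD (one-step) checkpoint, as the app does.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", transformer=None, torch_dtype=torch.float16
)
pipe.transformer = Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.float16
)
pipe.scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

# The app replaces __call__ with the repo's patched version before running inference.
setattr(PixArtAlphaPipeline, "__call__", pipeline_pixart_alpha_call)
pipe.to("cuda")

image = pipe(
    prompt="A small cactus with a happy face in the Sahara desert.",
    timesteps=[400],          # fixed DMD timestep used by the demo
    num_inference_steps=1,    # single denoising step
    guidance_scale=1,         # CFG effectively disabled for the distilled model
    width=512,
    height=512,
    generator=torch.Generator().manual_seed(0),
    max_sequence_length=120,  # T5 token limit the app uses for Alpha checkpoints
).images[0]
image.save("pixart_dmd_sample.png")
```

The fixed timesteps=[400] with a single inference step is what distinguishes this distilled DMD demo from the multi-step DPM-Solver / SA-Solver sampling used in app/app_pixart_sigma.py later in this patch.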
diff --git a/app/app_pixart_dmd.py b/app/app_pixart_dmd.py new file mode 100644 index 0000000000000000000000000000000000000000..745665431d587a9937025b4fc1ffde0ffd84046f --- /dev/null +++ b/app/app_pixart_dmd.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python +from __future__ import annotations +import argparse +import os +import sys +from pathlib import Path + +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) +import random +import gradio as gr +import numpy as np +import uuid +from diffusers import ConsistencyDecoderVAE, PixArtAlphaPipeline, Transformer2DModel, DDPMScheduler +import torch +from typing import Tuple +from datetime import datetime +from scripts.diffusers_patches import pipeline_pixart_alpha_call + +DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png) + # PixArt-Alpha One Step 512px + #### [PixArt-Alpha-DMD 512px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-Alpha-DMD-XL-2-512x512](https://huggingface.co/PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512) checkpoint. + #### English prompts ONLY; 提示词仅限英文 + ### We only use 8 V100 GPUs for PixArt-DMD training. There's still plenty of room for improvement. + """ +if not torch.cuda.is_available(): + DESCRIPTION += "\n

<p>Running on CPU 🥶 This demo does not work on CPU.</p>
" + +MAX_SEED = np.iinfo(np.int32).max +CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1" +MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000")) +USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" +ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1" +PORT = int(os.getenv("DEMO_PORT", "15432")) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +style_list = [ + { + "name": "(No style)", + "prompt": "{prompt}", + "negative_prompt": "", + }, + { + "name": "Cinematic", + "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", + "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", + }, + { + "name": "Photographic", + "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", + "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly", + }, + { + "name": "Anime", + "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", + "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast", + }, + { + "name": "Manga", + "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style", + "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style", + }, + { + "name": "Digital Art", + "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", + "negative_prompt": "photo, photorealistic, realism, ugly", + }, + { + "name": "Pixel art", + "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics", + "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic", + }, + { + "name": "Fantasy art", + "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", + "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white", + }, + { + "name": "Neonpunk", + "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", + "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured", + }, + { + "name": "3D Model", + "prompt": "professional 3d model {prompt} . 
octane render, highly detailed, volumetric, dramatic lighting", + "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting", + }, +] + +styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "(No style)" +SCHEDULE_NAME = ["PixArt-DMD"] +DEFAULT_SCHEDULE_NAME = "PixArt-DMD" +NUM_IMAGES_PER_PROMPT = 2 + + +def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: + p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + if not negative: + negative = "" + return p.replace("{prompt}", positive), n + negative + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', default="PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", type=str) + parser.add_argument( + '--pipeline_load_from', default="PixArt-alpha/PixArt-XL-2-1024-MS", type=str, + help="Download for loading text_encoder, " + "tokenizer and vae from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS") + parser.add_argument('--T5_token_max_length', default=120, type=int, help='max length of tokens for T5') + return parser.parse_args() + + +args = get_args() + +if torch.cuda.is_available(): + weight_dtype = torch.float16 + T5_token_max_length = args.T5_token_max_length + model_path = args.model_path + if 'Sigma' in args.model_path: + T5_token_max_length = 300 + + pipe = PixArtAlphaPipeline.from_pretrained( + args.pipeline_load_from, + transformer=None, + torch_dtype=weight_dtype, + ) + pipe.transformer = Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=weight_dtype) + pipe.scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + + print("Changing __call__ method of PixArtAlphaPipeline using scripts.diffusers_patches.pipeline_pixart_alpha_call") + setattr(PixArtAlphaPipeline, '__call__', pipeline_pixart_alpha_call) + + if os.getenv('CONSISTENCY_DECODER', False): + print("Using DALL-E 3 Consistency Decoder") + pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16) + + if ENABLE_CPU_OFFLOAD: + pipe.enable_model_cpu_offload() + else: + pipe.to(device) + print("Loaded on Device!") + + # speed-up T5 + pipe.text_encoder.to_bettertransformer() + + if USE_TORCH_COMPILE: + pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True) + print("Model Compiled!") + + +def save_image(img, seed=''): + unique_name = f"{str(uuid.uuid4())}_{seed}.png" + save_path = os.path.join(f'output/online_demo_img/{datetime.now().date()}') + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(save_path, exist_ok=True) + unique_name = os.path.join(save_path, unique_name) + img.save(unique_name) + return unique_name + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +@torch.no_grad() +@torch.inference_mode() +def generate( + prompt: str, + negative_prompt: str = "", + style: str = DEFAULT_STYLE_NAME, + use_negative_prompt: bool = False, + num_imgs: int = 1, + seed: int = 0, + width: int = 1024, + height: int = 1024, + randomize_seed: bool = False, + use_resolution_binning: bool = True, + progress=gr.Progress(track_tqdm=True), +): + seed = int(randomize_seed_fn(seed, randomize_seed)) + generator = torch.Generator().manual_seed(seed) + print(f"{PORT}: {model_path}") + print(prompt) + + if not use_negative_prompt: + negative_prompt = None # type: ignore + prompt, 
negative_prompt = apply_style(style, prompt, negative_prompt) + + images = pipe( + prompt=prompt, + timesteps=[400], + width=width, + height=height, + guidance_scale=1, + num_inference_steps=1, + generator=generator, + num_images_per_prompt=num_imgs, + use_resolution_binning=use_resolution_binning, + output_type="pil", + max_sequence_length=T5_token_max_length, + ).images + + image_paths = [save_image(img, seed) for img in images] + print(image_paths) + return image_paths, seed + + +examples = [ + "A small cactus with a happy face in the Sahara desert.", + "an astronaut sitting in a diner, eating fries, cinematic, analog film", + "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.", + "stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.", + "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.", + "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background", + "Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism", + "anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur", + "The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. 
occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8", +] + +with gr.Blocks(css="scripts/style.css") as demo: + gr.Markdown(DESCRIPTION) + gr.DuplicateButton( + value="Duplicate Space for private use", + elem_id="duplicate-button", + visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1", + ) + with gr.Row(equal_height=False): + with gr.Group(): + with gr.Row(): + prompt = gr.Text( + label="Prompt", + show_label=False, + max_lines=1, + placeholder="Enter your prompt", + container=False, + ) + run_button = gr.Button("Run", scale=0) + result = gr.Gallery(label="Result", show_label=False) + # with gr.Accordion("Advanced options", open=False): + with gr.Group(): + with gr.Row(): + use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True) + with gr.Row(visible=True): + schedule = gr.Radio( + show_label=True, + container=True, + interactive=True, + choices=SCHEDULE_NAME, + value=DEFAULT_SCHEDULE_NAME, + label="Sampler Schedule", + visible=True, + ) + num_imgs = gr.Slider( + label="Num Images", + minimum=1, + maximum=8, + step=1, + value=NUM_IMAGES_PER_PROMPT, + ) + style_selection = gr.Radio( + show_label=True, + container=True, + interactive=True, + choices=STYLE_NAMES, + value=DEFAULT_STYLE_NAME, + label="Image Style", + ) + negative_prompt = gr.Text( + label="Negative prompt", + max_lines=1, + placeholder="Enter a negative prompt", + visible=True, + ) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + with gr.Row(visible=True): + width = gr.Slider( + label="Width", + minimum=256, + maximum=MAX_IMAGE_SIZE, + step=32, + value=512, + ) + height = gr.Slider( + label="Height", + minimum=256, + maximum=MAX_IMAGE_SIZE, + step=32, + value=512, + ) + + gr.Examples( + examples=examples, + inputs=prompt, + outputs=[result, seed], + fn=generate, + cache_examples=CACHE_EXAMPLES, + ) + + use_negative_prompt.change( + fn=lambda x: gr.update(visible=x), + inputs=use_negative_prompt, + outputs=negative_prompt, + api_name=False, + ) + + gr.on( + triggers=[ + prompt.submit, + negative_prompt.submit, + run_button.click, + ], + fn=generate, + inputs=[ + prompt, + negative_prompt, + style_selection, + use_negative_prompt, + num_imgs, + seed, + width, + height, + schedule, + randomize_seed, + ], + outputs=[result, seed], + api_name="run", + ) + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=PORT, debug=True) diff --git a/app/app_pixart_sigma.py b/app/app_pixart_sigma.py new file mode 100644 index 0000000000000000000000000000000000000000..54f1d8c68455e220976ef1d02766f2f5bc5252a3 --- /dev/null +++ b/app/app_pixart_sigma.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python +from __future__ import annotations +import argparse +import os +import sys +from pathlib import Path +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) +import random +import gradio as gr +import numpy as np +import uuid +from diffusers import ConsistencyDecoderVAE, DPMSolverMultistepScheduler, Transformer2DModel, AutoencoderKL +import torch +from typing import Tuple +from datetime import datetime +from diffusion.sa_solver_diffusers import SASolverScheduler +from peft import PeftModel +from 
scripts.diffusers_patches import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline + + +DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png) + # PixArt-Sigma 1024px + #### [PixArt-Sigma 1024px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint. + #### English prompts ONLY; 提示词仅限英文 + ### If you don't get satisfactory results, try increasing the DPM-Solver inference steps from 14 to 20, or lowering the DPM-Solver guidance scale from 4.5 to 3.5. + """ +if not torch.cuda.is_available(): + DESCRIPTION += "\n

<p>Running on CPU 🥶 This demo does not work on CPU.</p>
" + +MAX_SEED = np.iinfo(np.int32).max +CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1" +MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000")) +USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" +ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1" +PORT = int(os.getenv("DEMO_PORT", "15432")) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +style_list = [ + { + "name": "(No style)", + "prompt": "{prompt}", + "negative_prompt": "", + }, + { + "name": "Cinematic", + "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", + "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", + }, + { + "name": "Photographic", + "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", + "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly", + }, + { + "name": "Anime", + "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", + "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast", + }, + { + "name": "Manga", + "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style", + "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style", + }, + { + "name": "Digital Art", + "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", + "negative_prompt": "photo, photorealistic, realism, ugly", + }, + { + "name": "Pixel art", + "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics", + "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic", + }, + { + "name": "Fantasy art", + "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", + "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white", + }, + { + "name": "Neonpunk", + "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", + "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured", + }, + { + "name": "3D Model", + "prompt": "professional 3d model {prompt} . 
octane render, highly detailed, volumetric, dramatic lighting", + "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting", + }, +] + + +styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "(No style)" +SCHEDULE_NAME = ["DPM-Solver", "SA-Solver"] +DEFAULT_SCHEDULE_NAME = "DPM-Solver" +NUM_IMAGES_PER_PROMPT = 1 + +def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: + p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + if not negative: + negative = "" + return p.replace("{prompt}", positive), n + negative + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--is_lora', action='store_true', help='enable lora ckpt loading') + parser.add_argument('--repo_id', default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", type=str) + parser.add_argument('--lora_repo_id', default=None, type=str) + parser.add_argument('--model_path', default=None, type=str) + parser.add_argument( + '--pipeline_load_from', default="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", type=str, + help="Download for loading text_encoder, tokenizer and vae " + "from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS") + parser.add_argument('--T5_token_max_length', default=120, type=int, help='max length of tokens for T5') + return parser.parse_args() + + +args = get_args() + +if torch.cuda.is_available(): + weight_dtype = torch.float16 + T5_token_max_length = args.T5_token_max_length + model_path = args.model_path + if 'Sigma' in args.model_path: + T5_token_max_length = 300 + + # tmp patches for diffusers PixArtSigmaPipeline Implementation + print( + "Changing _init_patched_inputs method of diffusers.models.Transformer2DModel " + "using scripts.diffusers_patches.pixart_sigma_init_patched_inputs") + setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs) + + if not args.is_lora: + transformer = Transformer2DModel.from_pretrained( + model_path, + subfolder='transformer', + torch_dtype=weight_dtype, + ) + pipe = PixArtSigmaPipeline.from_pretrained( + args.pipeline_load_from, + transformer=transformer, + torch_dtype=weight_dtype, + use_safetensors=True, + ) + else: + assert args.lora_repo_id is not None + transformer = Transformer2DModel.from_pretrained(args.repo_id, subfolder="transformer", torch_dtype=torch.float16) + transformer = PeftModel.from_pretrained(transformer, args.lora_repo_id) + pipe = PixArtSigmaPipeline.from_pretrained( + args.repo_id, + transformer=transformer, + torch_dtype=torch.float16, + use_safetensors=True, + ) + del transformer + + + if os.getenv('CONSISTENCY_DECODER', False): + print("Using DALL-E 3 Consistency Decoder") + pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16) + + if ENABLE_CPU_OFFLOAD: + pipe.enable_model_cpu_offload() + else: + pipe.to(device) + print("Loaded on Device!") + + # speed-up T5 + pipe.text_encoder.to_bettertransformer() + + if USE_TORCH_COMPILE: + pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True) + print("Model Compiled!") + + +def save_image(img, seed=''): + unique_name = f"{str(uuid.uuid4())}_{seed}.png" + save_path = os.path.join(f'output/online_demo_img/{datetime.now().date()}') + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(save_path, exist_ok=True) + unique_name = os.path.join(save_path, unique_name) + img.save(unique_name) + return unique_name + + +def 
randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +@torch.no_grad() +@torch.inference_mode() +def generate( + prompt: str, + negative_prompt: str = "", + style: str = DEFAULT_STYLE_NAME, + use_negative_prompt: bool = False, + num_imgs: int = 1, + seed: int = 0, + width: int = 1024, + height: int = 1024, + schedule: str = 'DPM-Solver', + dpms_guidance_scale: float = 4.5, + sas_guidance_scale: float = 3, + dpms_inference_steps: int = 20, + sas_inference_steps: int = 25, + randomize_seed: bool = False, + use_resolution_binning: bool = True, + progress=gr.Progress(track_tqdm=True), +): + seed = int(randomize_seed_fn(seed, randomize_seed)) + generator = torch.Generator().manual_seed(seed) + print(f"{PORT}: {model_path}") + print(prompt) + + if schedule == 'DPM-Solver': + if not isinstance(pipe.scheduler, DPMSolverMultistepScheduler): + pipe.scheduler = DPMSolverMultistepScheduler() + num_inference_steps = dpms_inference_steps + guidance_scale = dpms_guidance_scale + elif schedule == "SA-Solver": + if not isinstance(pipe.scheduler, SASolverScheduler): + pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config, algorithm_type='data_prediction', tau_func=lambda t: 1 if 200 <= t <= 800 else 0, predictor_order=2, corrector_order=2) + num_inference_steps = sas_inference_steps + guidance_scale = sas_guidance_scale + else: + raise ValueError(f"Unknown schedule: {schedule}") + + if not use_negative_prompt: + negative_prompt = None # type: ignore + prompt, negative_prompt = apply_style(style, prompt, negative_prompt) + + images = pipe( + prompt=prompt, + width=width, + height=height, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + num_images_per_prompt=num_imgs, + use_resolution_binning=use_resolution_binning, + output_type="pil", + max_sequence_length=args.T5_token_max_length, + ).images + + image_paths = [save_image(img, seed) for img in images] + print(image_paths) + return image_paths, seed + + +examples = [ + "A small cactus with a happy face in the Sahara desert.", + "an astronaut sitting in a diner, eating fries, cinematic, analog film", + "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.", + "stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.", + "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.", + "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background", + "Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism", + "anthropomorphic profile of the white snow owl Crystal priestess , 
art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur", + "The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8", +] + +with gr.Blocks(css="scripts/style.css") as demo: + gr.Markdown(DESCRIPTION) + gr.DuplicateButton( + value="Duplicate Space for private use", + elem_id="duplicate-button", + visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1", + ) + with gr.Row(equal_height=False): + with gr.Group(): + with gr.Row(): + prompt = gr.Text( + label="Prompt", + show_label=False, + max_lines=1, + placeholder="Enter your prompt", + container=False, + ) + run_button = gr.Button("Run", scale=0) + result = gr.Gallery(label="Result", show_label=False) + # with gr.Accordion("Advanced options", open=False): + with gr.Group(): + with gr.Row(): + use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True) + with gr.Row(visible=True): + schedule = gr.Radio( + show_label=True, + container=True, + interactive=True, + choices=SCHEDULE_NAME, + value=DEFAULT_SCHEDULE_NAME, + label="Sampler Schedule", + visible=True, + ) + num_imgs = gr.Slider( + label="Num Images", + minimum=1, + maximum=8, + step=1, + value=1, + ) + style_selection = gr.Radio( + show_label=True, + container=True, + interactive=True, + choices=STYLE_NAMES, + value=DEFAULT_STYLE_NAME, + label="Image Style", + ) + negative_prompt = gr.Text( + label="Negative prompt", + max_lines=1, + placeholder="Enter a negative prompt", + visible=True, + ) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + with gr.Row(visible=True): + width = gr.Slider( + label="Width", + minimum=256, + maximum=MAX_IMAGE_SIZE, + step=32, + value=1024, + ) + height = gr.Slider( + label="Height", + minimum=256, + maximum=MAX_IMAGE_SIZE, + step=32, + value=1024, + ) + with gr.Row(): + dpms_guidance_scale = gr.Slider( + label="DPM-Solver Guidance scale", + minimum=1, + maximum=10, + step=0.1, + value=4.5, + ) + dpms_inference_steps = gr.Slider( + label="DPM-Solver inference steps", + minimum=5, + maximum=40, + step=1, + value=14, + ) + with gr.Row(): + sas_guidance_scale = gr.Slider( + label="SA-Solver Guidance scale", + minimum=1, + maximum=10, + step=0.1, + value=3, + ) + sas_inference_steps = gr.Slider( + label="SA-Solver inference steps", + minimum=10, + maximum=40, + step=1, + value=25, + ) + + gr.Examples( + examples=examples, + 
inputs=prompt, + outputs=[result, seed], + fn=generate, + cache_examples=CACHE_EXAMPLES, + ) + + use_negative_prompt.change( + fn=lambda x: gr.update(visible=x), + inputs=use_negative_prompt, + outputs=negative_prompt, + api_name=False, + ) + + gr.on( + triggers=[ + prompt.submit, + negative_prompt.submit, + run_button.click, + ], + fn=generate, + inputs=[ + prompt, + negative_prompt, + style_selection, + use_negative_prompt, + num_imgs, + seed, + width, + height, + schedule, + dpms_guidance_scale, + sas_guidance_scale, + dpms_inference_steps, + sas_inference_steps, + randomize_seed, + ], + outputs=[result, seed], + api_name="run", + ) + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=PORT, debug=True) diff --git a/asset/PixArt.svg b/asset/PixArt.svg new file mode 100644 index 0000000000000000000000000000000000000000..8f31930d5d88cf7ab79dc3edebdd7604fa6c26f1 --- /dev/null +++ b/asset/PixArt.svg @@ -0,0 +1,96 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/asset/docs/pixart.md b/asset/docs/pixart.md new file mode 100644 index 0000000000000000000000000000000000000000..06feb0456f0bfd0a0a64e38b6aa097cfadef3727 --- /dev/null +++ b/asset/docs/pixart.md @@ -0,0 +1,112 @@ + + +[//]: # ((reference from [hugging Face](https://github.com/huggingface/diffusers/blob/docs/8bit-inference-pixart/docs/source/en/api/pipelines/pixart.md))) + +## Running the `PixArtAlphaPipeline` in under 8GB GPU VRAM + +It is possible to run the [`PixArtAlphaPipeline`] under 8GB GPU VRAM by loading the text encoder in 8-bit numerical precision. Let's walk through a full-fledged example. + +First, install the `bitsandbytes` library: + +```bash +pip install -U bitsandbytes +``` + +Then load the text encoder in 8-bit: + +```python +from transformers import T5EncoderModel +from diffusers import PixArtAlphaPipeline + +text_encoder = T5EncoderModel.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + subfolder="text_encoder", + load_in_8bit=True, + device_map="auto", + +) +pipe = PixArtAlphaPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + text_encoder=text_encoder, + transformer=None, + device_map="auto" +) +``` + +Now, use the `pipe` to encode a prompt: + +```python +with torch.no_grad(): + prompt = "cute cat" + prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt) + +del text_encoder +del pipe +flush() +``` + +`flush()` is just a utility function to clear the GPU VRAM and is implemented like so: + +```python +import gc + +def flush(): + gc.collect() + torch.cuda.empty_cache() +``` + +Then compute the latents providing the prompt embeddings as inputs: + +```python +pipe = PixArtAlphaPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + text_encoder=None, + torch_dtype=torch.float16, +).to("cuda") + +latents = pipe( + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_images_per_prompt=1, + output_type="latent", +).images + +del pipe.transformer +flush() +``` + +Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded. 
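+
+If you want to confirm the savings at this point, you can read off the peak GPU memory used by the 8-bit text-encoding stage before the transformer is loaded. This is a minimal sketch (not part of the PixArt code) that relies only on standard `torch.cuda` utilities:
+
+```python
+import torch
+
+# Peak VRAM allocated so far, i.e. by prompt encoding with the 8-bit T5 encoder (in GB).
+print(f"Peak memory after prompt encoding: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+
+# Reset the counter so the denoising stage below is measured on its own.
+torch.cuda.reset_peak_memory_stats()
+```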
+ +Once the latents are computed, pass it off the VAE to decode into a real image: + +```python +with torch.no_grad(): + image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] +image = pipe.image_processor.postprocess(image, output_type="pil") +image.save("cat.png") +``` + +All of this, put together, should allow you to run [`PixArtAlphaPipeline`] under 8GB GPU VRAM. + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png) + +Find the script [here](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e) that can be run end-to-end to report the memory being used. + + + +Text embeddings computed in 8-bit can have an impact on the quality of the generated images because of the information loss in the representation space induced by the reduced precision. It's recommended to compare the outputs with and without 8-bit. + + \ No newline at end of file diff --git a/asset/examples.py b/asset/examples.py new file mode 100644 index 0000000000000000000000000000000000000000..48acf40d78d6af6626427d81577752a7cb38fcc7 --- /dev/null +++ b/asset/examples.py @@ -0,0 +1,36 @@ + +examples = [ + [ + "A small cactus with a happy face in the Sahara desert.", + "dpm-solver", 20, 4.5, + ], + [ + "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history" + "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits " + "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret " + "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile " + "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and " + "the Parisian streets and city in the background, depth of field, cinematic 35mm film.", + "dpm-solver", 20, 4.5, + ], + [ + "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. " + "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. " + "The quote 'Find the universe within you' is etched in bold letters across the horizon." + "blue and pink, brilliantly illuminated in the background.", + "dpm-solver", 20, 4.5, + ], + [ + "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.", + "dpm-solver", 20, 4.5, + ], + [ + "A litter of golden retriever puppies playing in the snow. 
Their heads pop out of the snow, covered in.", + "dpm-solver", 20, 4.5, + ], + [ + "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, " + "national geographic photo, 8k resolution, crayon art, interactive artwork", + "dpm-solver", 20, 4.5, + ] +] diff --git a/asset/logo-sigma.png b/asset/logo-sigma.png new file mode 100644 index 0000000000000000000000000000000000000000..dc2fa7c8defb50b7deae3023186e78c812e96671 Binary files /dev/null and b/asset/logo-sigma.png differ diff --git a/asset/logo.png b/asset/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..fe8f6fc0da5b1b2aa322eeb1437546d21e55a6e2 Binary files /dev/null and b/asset/logo.png differ diff --git a/asset/samples.txt b/asset/samples.txt new file mode 100644 index 0000000000000000000000000000000000000000..be921ef7af8bd237cebc92bd5f70d3d3ec57ad07 --- /dev/null +++ b/asset/samples.txt @@ -0,0 +1,120 @@ +A small cactus with a happy face in the Sahara desert. +Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail. +beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background +stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background. +nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph. +Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism +anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur +The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. 
occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8 +Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens. +8k uhd A man looks up at the starry sky, lonely and ethereal, Minimalism, Chaotic composition Op Art +A middle-aged woman of Asian descent, her dark hair streaked with silver, appears fractured and splintered, intricately embedded within a sea of broken porcelain. The porcelain glistens with splatter paint patterns in a harmonious blend of glossy and matte blues, greens, oranges, and reds, capturing her dance in a surreal juxtaposition of movement and stillness. Her skin tone, a light hue like the porcelain, adds an almost mystical quality to her form. +A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden. +A alpaca made of colorful building blocks, cyberpunk +A baby painter trying to draw very simple picture, white background +A boy and a girl fall in love +A dog that has been meditating all the time +A man is sitting in a chair with his chin resting on his hand. The chair, along with the man's feet, are submerged in the sea. Strikingly, the man's back is on fire. +A painter study hard to learn how to draw with many concepts in the air, white background +A painter with low quality, white background, pixel art +A person standing on the desert, desert waves, gossip illustration, half red, half blue, abstract image of sand, clear style, trendy illustration, outdoor, top view, clear style, precision art, ultra high definition image +A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid sytle of a vintage travel poster. +A sureal parallel world where mankind avoid extinction by preserving nature, epic trees, water streams, various flowers, intricate details, rich colors, rich vegetation, cinematic, symmetrical, beautiful lighting, V-Ray render, sun rays, magical lights, photography +A woman is shopping for fresh produce at the farmer's market. 
+A worker that looks like a mixture of cow and horse is working hard to type code +A young man dressed in ancient Chinese clothing, Asian people, White robe, Handsome, Hand gestures forming a spell, Martial arts and fairy-like vibe, Carrying a legendary-level giant sword on the back, Game character, Surrounded by runes, Cyberpunk style, neon lights, best quality, masterpiece, cg, hdr, high-definition, extremely detailed, photorealistic, epic, character design, detailed face, superhero, hero, detailed UHD, real-time, vfx, 3D rendering, 8k +An alien octopus floats through a protal reading a newspaper +An epressive oil painting of a basketbal player dunking, depicted as an explosion of a nebula +art collection style and fashion shoot, in the style of made of glass, dark blue and light pink, paul rand, solarpunk, camille vivier, beth didonato hair, barbiecore, hyper-realistic +artistic +beautiful secen +Crocodile in a sweater +Design a letter A, 3D stereoscopic Ice material Interior light blue Conceptual product design Futuristic Blind box toy Handcrafted Exquisite 3D effect Full body display Ultra-high precision Ultra-detailed Perfect lighting OC Renderer Blender 8k Ultra-sharp Ultra-noise reduction +Floating,colossal,futuristic statue in the sky, awe-inspiring and serenein the style of Stuart Lippincott:2with detailed composition and subtle geometric elements.This sanctuary-ike atmosphere features crisp clarity and soft amber tones.In contrasttiny human figures surround the statueThe pieceincorporates flowing draperiesreminiscent of Shwedoff and Philip McKay's stylesemphasizing thejuxtaposition between the powerful presence of the statue and thevulnerability of the minuscule human figuresshwedoff +knolling of a drawing tools for painter +Leonardo da Vinci's Last Supper content, Van Goph's Starry Night Style +Luffy from ONEPIECE, handsome face, fantasy +photography shot through an outdoor window of a coffee shop with neon sign lighting, window glares and reflections, depth of field, {little girl with red hair sitting at a table, portrait, kodak portra 800,105 mm f1.8 +poster of a mechanical cat, techical Schematics viewed from front and side view on light white blueprint paper, illustartion drafting style, illustation, typography, conceptual art, dark fantasy steampunk, cinematic, dark fantasy +The girl in the car is filled with goldfish and flowers, goldfish can fly, Kawaguchi Renko's art, natural posture, holiday dadcore, youthful energy and pressure, body stretching, goldfish simulation movies in the sky, super details, and dreamy high photography. Colorful. Covered by water and goldfish, indoor scene, close-up shot in XT4 movie +The image features a woman wearing a red shirt with an icon. She appears to be posing for the camera, and her outfit includes a pair of jeans. The woman seems to be in a good mood, as she is smiling. The background of the image is blurry, focusing more on the woman and her attire. +The towel was on top of the hard counter. +A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds. +I want to supplement vitamin c, please help me paint related food. +A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. 
on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the window. +A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape. +A blue jay standing on a large basket of rainbow macarons. +A bucket bag made of blue suede. The bag is decorated with intricate golden paisley patterns. The handle of the bag is made of rubies and pearls. +An alien octopus floats through a portal reading a newspaper. +bird's eye view of a city. +beautiful scene +A 2D animation of a folk music band composed of anthropomorphic autumn leaves, each playing traditional bluegrass instruments, amidst a rustic forest setting dappled with the soft light of a harvest moon. +In front of a deep black backdrop, a figure of middle years, her Tongan skin rich and glowing, is captured mid-twirl, her curly hair flowing like a storm behind her. Her attire resembles a whirlwind of marble and porcelain fragments. Illuminated by the gleam of scattered porcelain shards, creating a dreamlike atmosphere, the dancer manages to appear fragmented, yet maintains a harmonious and fluid form. +Digital illustration of a beach scene crafted from yarn. The sandy beach is depicted with beige yarn, waves are made of blue and white yarn crashing onto the shore. A yarn sun sets on the horizon, casting a warm glow. Yarn palm trees sway gently, and little yarn seashells dot the shoreline. +Illustration of a chic chair with a design reminiscent of a pumpkin’s form, with deep orange cushioning, in a stylish loft setting. +A detailed oil painting of an old sea captain, steering his ship through a storm. Saltwater is splashing against his weathered face, determination in his eyes. Twirling malevolent clouds are seen above and stern waves threaten to submerge the ship while seagulls dive and twirl through the chaotic landscape. Thunder and lights embark in the distance, illuminating the scene with an eerie green glow. +An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. The quote 'Find the universe within you' is etched in bold letters across the horizon. +A modern architectural building with large glass windows, situated on a cliff overlooking a serene ocean at sunset +photo of an ancient shipwreck nestled on the ocean floor. Marine plants have claimed the wooden structure, and fish swim in and out of its hollow spaces. Sunken treasures and old cannons are scattered around, providing a glimpse into the past +A 3D render of a coffee mug placed on a window sill during a stormy day. The storm outside the window is reflected in the coffee, with miniature lightning bolts and turbulent waves seen inside the mug. The room is dimly lit, adding to the dramatic atmosphere.A minimap diorama of a cafe adorned with indoor plants. Wooden beams crisscross above, and a cold brew station stands out with tiny bottles and glasses. +An antique botanical illustration drawn with fine lines and a touch of watercolour whimsy, depicting a strange lily crossed with a Venus flytrap, its petals poised as if ready to snap shut on any unsuspecting insects.An illustration inspired by old-world botanical sketches blends a cactus with lilac blooms into a Möbius strip, using detailed lines and subtle watercolor touches to capture nature's diverse beauty and mathematical intrigue. 
+An ink sketch style illustration of a small hedgehog holding a piece of watermelon with its tiny paws, taking little bites with its eyes closed in delight.Photo of a lychee-inspired spherical chair, with a bumpy white exterior and plush interior, set against a tropical wallpaper. +3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background +professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest. +an astronaut sitting in a diner, eating fries, cinematic, analog film +Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering. +Ethereal fantasy concept art of thunder god with hammer. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy. +A Japanese girl walking along a path, surrounding by blooming oriental cherry, pink petal slowly falling down to the ground +A Ukiyoe style painting, an astronaut riding a unicorn, In the background there is an ancient Japanese architecture +Steampunk makeup, in the style of vray tracing, colorful impasto, uhd image, indonesian art, fine feather details with bright red and yellow and green and pink and orange colours, intricate patterns and details, dark cyan and amber makeup. Rich colourful plumes. Victorian style. +A cute teddy bear in front of a plain white wall, warm and brown fur, soft and fluffy +The beautiful scenery of Seattle, painting by Al Capp. +Photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang. +An astronaut riding a horse on the moon, oil painting by Van Gogh. +A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky +Realistic oil painting of a stunning model merged in multicolor splash made of finely torn paper, eye contact, walking with class in a street. +a chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic,futuristic style, gray and green light, movie lighting, 32K HD +a handsome 24 years old boy in the middle with sky color background wearing eye glasses, it's super detailed with anime style, it's a portrait with delicated eyes and nice looking face +a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, national geographic photo, 8k resolution, crayon art, interactive artwork +3D rendering miniature scene design, Many tall buildings, A winding urban road runs through the middle,a lot of cars on the road, transparent material pipeline transports Materials, ,there are many people around, in thestyle of light orange and yellow, graphic design- inspired illustrations, classic still-life, beeple, josan gon-zalez, manga-influenced, miniature dioramas, in thestyle of playful and whimsical designs, graphic de-sign-inspired illustrations, minimalism, hyperrealismlomo lca, e-commerce C4D style, e-commerce posterUl, UX, octane render, blender +Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works +A cute orange kitten sliding down an aqua slide. happy excited. 16mm lens in front. we see his excitement and scared in the eye. vibrant colors. 
water splashing on the lens +Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. +A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures. +An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film. +A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. +A New Zealand female business owner stands and is happy that his business is growing by having good VoIP and broadband supplied by Voyager Internet. This business owner is dressed semi casual and is standing with a funky office space in the background. The image is light and bright and is well lit. This image needs to be shot like a professional photo shoot using a Canon R6 with high quality 25mm lens. This image has a shallow depth of field +The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. 
occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8 +Editorial photoshoot of a old woman, high fashion 2000s fashion +Mural Painted of Prince in Purple Rain on side of 5 story brick building next to zen garden vacant lot in the urban center district, rgb +Cozy Scandinavian living room, there is a cat sleeping on the couch, depth of field +Street style centered straight shot photo shot on Afga Vista 400, lense 50mm, of a two women,skin to skin touch face, emotion, hughing, natural blond hair, natural features, ultra detailed, skin texture, Rembrandt light, soft shadows +Frog, in forest, colorful, no watermark, no signature, in forest, 8k +selfie of a woman and her lion cub on the plains +A fisherman fixing his net sitting on a beautiful tropical beach at sunset with bending palm trees fishing gear and a small boat on shore +Coast, decorative painting, horizon, modern, fashionable, full of abstract feeling, full of imagination, the picture reveals the sense of time passing, there is a feeling of the end of the world +A close up of a branch of a tree and a golden bug on the top a leaf, shutterstock contest winner,ecological art, depth of field, shallow depth of field, macro photography +Outdoor style fashion photo, full – body shot of a man with short brown hair, happy and smiling, he is standing on his hipster bicycle wearing a light blue long sleeved blouse with closed buttons and dark blue jeans trousers, in the background the exterior of an Aldi store, fully lit background, natural afternoon lighting +beautiful woman sniper, wearing soviet army uniform, one eye on sniper lens, in snow ground +A very attractive and natural woman, sitting on a yoka mat, breathing, eye closed, no make up, intense satisfaction, she looks like she is intensely relaxed, yoga class, sunrise, 35mm +a close up of a helmet on a person, digital art, inspired by Han Gan, cloisonnism, female, victorian armor, ultramarine, best of behance, anton fadeev 8 k, fined detail, sci-fi character, elegant armor, fantasy art behance +a melting apple +yellow FIAT 500 Cinquecento 1957 driving through liechtenstein castle with a lot of banknotes scattered behind ,filled with wads of cash , car color yellow, license plate R-33 +tented resort in the desert, rocky and sandy terrain, 5 star hotel, beautiful landscape, landscape photography, depth of view, Fujifilm GFX 100 –uplight +Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm. +Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots. +Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light. +Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park. +Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers. 
+Game-Art - An island with different geographical properties and multiple small cities floating in space +Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. +A car made out of vegetables. +A serene lakeside during autumn with trees displaying a palette of fiery colors. +A realistic landscape shot of the Northern Lights dancing over a snowy mountain range in Iceland. +A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky. +Drone view of waves crashing against the rugged cliffs along Big Sur’s Garay Point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. +A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed. +Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture. +Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works. +A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. +A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair, and wears a pair of earrings. Behind are blurred city buildings and streets. 
\ No newline at end of file diff --git a/configs/PixArt_xl2_internal.py b/configs/PixArt_xl2_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..8c44f7d658035dbfdefa322a9522884850db1b54 --- /dev/null +++ b/configs/PixArt_xl2_internal.py @@ -0,0 +1,79 @@ +data_root = '/data/data' +data = dict(type='InternalData', root='images', image_list_json=['data_info.json'], transform='default_train', load_vae_feat=True, load_t5_feat=True) +image_size = 256 # the generated image resolution +train_batch_size = 32 +eval_batch_size = 16 +use_fsdp=False # if use FSDP mode +valid_num=0 # take as valid aspect-ratio when sample number >= valid_num +fp32_attention = True +# model setting +model = 'PixArt_XL_2' +aspect_ratio_type = None # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] +multi_scale = False # if use multiscale dataset model training +pe_interpolation = 1.0 # positional embedding interpolation +# qk norm +qk_norm = False +# kv token compression +kv_compress = False +kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], +} + +# training setting +num_workers=4 +train_sampling_steps = 1000 +visualize=False +eval_sampling_steps = 250 +model_max_length = 120 +lora_rank = 4 +num_epochs = 80 +gradient_accumulation_steps = 1 +grad_checkpointing = False +gradient_clip = 1.0 +gc_step = 1 +auto_lr = dict(rule='sqrt') + +# we use different weight decay with the official implementation since it results better result +optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10) +lr_schedule = 'constant' +lr_schedule_args = dict(num_warmup_steps=500) + +save_image_epochs = 1 +save_model_epochs = 1 +save_model_steps=1000000 + +sample_posterior = True +mixed_precision = 'fp16' +scale_factor = 0.18215 # ldm vae: 0.18215; sdxl vae: 0.13025 +ema_rate = 0.9999 +tensorboard_mox_interval = 50 +log_interval = 50 +cfg_scale = 4 +mask_type='null' +num_group_tokens=0 +mask_loss_coef=0. 
+load_mask_index=False # load prepared mask_type index +# load model settings +vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema" +load_from = None +resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True) +snr_loss=False +real_prompt_ratio = 1.0 +# classifier free guidance +class_dropout_prob = 0.1 +# work dir settings +work_dir = '/cache/exps/' +s3_work_dir = None +micro_condition = False +seed = 43 +skip_step=0 + +# LCM +loss_type = 'huber' +huber_c = 0.001 +num_ddim_timesteps=50 +w_max = 15.0 +w_min = 3.0 +ema_decay = 0.95 diff --git a/configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py b/configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py new file mode 100644 index 0000000000000000000000000000000000000000..886784ddfe787028ccdf1fb90f8d159d7bf2a0c9 --- /dev/null +++ b/configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py @@ -0,0 +1,30 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data/dreambooth/dataset' + +data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True) +image_size = 1024 + +# model setting +model = 'PixArtMS_XL_2' # model for multi-scale training +fp32_attention = True +load_from = 'Path/to/PixArt-XL-2-1024-MS.pth' +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 2.0 + +# training setting +num_workers=1 +train_batch_size = 1 +num_epochs = 200 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=0) +auto_lr = None + +log_interval = 1 +save_model_epochs=10000 +save_model_steps=100 +work_dir = 'output/debug' diff --git a/configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py b/configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e874813859eff844a16bdfdf71dee48ee62bf0 --- /dev/null +++ b/configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py @@ -0,0 +1,29 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json',] + +data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) +image_size = 1024 + +# model setting +model = 'PixArt_XL_2' +fp32_attention = True +load_from = None +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +pe_interpolation = 2.0 + +# training setting +num_workers=10 +train_batch_size = 2 # 32 +num_epochs = 200 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=1000) + +eval_sampling_steps = 200 +log_interval = 20 +save_model_epochs=1 +save_model_steps=2000 +work_dir = 'output/debug' diff --git a/configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py b/configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py new file mode 100644 index 0000000000000000000000000000000000000000..952a2b1e5797b1fabb3407b13579d1cc587d2ead --- /dev/null +++ b/configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py @@ -0,0 +1,32 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json',] + +data = 
dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) +image_size = 1024 + +# model setting +model = 'PixArtMS_XL_2' # model for multi-scale training +fp32_attention = True +load_from = None +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 2.0 + +# training setting +num_workers=10 +train_batch_size = 12 # max 14 for PixArt-xL/2 when grad_checkpoint +num_epochs = 10 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=1000) +save_model_epochs=1 +save_model_steps=2000 + +log_interval = 20 +eval_sampling_steps = 200 +work_dir = 'output/debug' +micro_condition = True diff --git a/configs/pixart_alpha_config/PixArt_xl2_img256_internal.py b/configs/pixart_alpha_config/PixArt_xl2_img256_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..fb4a21eee530851261170576bb360180f9d58a15 --- /dev/null +++ b/configs/pixart_alpha_config/PixArt_xl2_img256_internal.py @@ -0,0 +1,27 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json',] + +data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) +image_size = 256 + +# model setting +model = 'PixArt_XL_2' +fp32_attention = True +load_from = None +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +# training setting +eval_sampling_steps = 200 + +num_workers=10 +train_batch_size = 176 # 32 # max 96 for PixArt-L/4 when grad_checkpoint +num_epochs = 200 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=1000) + +log_interval = 20 +save_model_epochs=5 +work_dir = 'output/debug' diff --git a/configs/pixart_alpha_config/PixArt_xl2_img512_internal.py b/configs/pixart_alpha_config/PixArt_xl2_img512_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..c7d06e958c3008260c92a113352d2e4106e35c01 --- /dev/null +++ b/configs/pixart_alpha_config/PixArt_xl2_img512_internal.py @@ -0,0 +1,29 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json',] + +data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) +image_size = 512 + +# model setting +model = 'PixArt_XL_2' +fp32_attention = True +load_from = None +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +pe_interpolation = 1.0 + +# training setting +use_fsdp=False # if use FSDP mode +num_workers=10 +train_batch_size = 38 # 32 +num_epochs = 200 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=1000) + +eval_sampling_steps = 200 +log_interval = 20 +save_model_epochs=1 +work_dir = 'output/debug' diff --git a/configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py b/configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py new file mode 100644 index 
0000000000000000000000000000000000000000..b0383fbd0963cd3b7ff0458dfb1865ba7b7b1ebe --- /dev/null +++ b/configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py @@ -0,0 +1,31 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json',] + +data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) +image_size = 512 + +# model setting +model = 'PixArtMS_XL_2' # model for multi-scale training +fp32_attention = True +load_from = None +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 1.0 + +# training setting +num_workers=10 +train_batch_size = 40 # max 40 for PixArt-xL/2 when grad_checkpoint +num_epochs = 20 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=1000) +save_model_epochs=1 +save_model_steps=2000 + +log_interval = 20 +eval_sampling_steps = 200 +work_dir = 'output/debug' diff --git a/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py b/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py new file mode 100644 index 0000000000000000000000000000000000000000..bfeddc245c0f0f9e45ebd39b7bbd5b6851bef4bc --- /dev/null +++ b/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py @@ -0,0 +1,46 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'pixart-sigma-toy-dataset' +image_list_json = ['data_info.json'] + +data = dict( + type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', + load_vae_feat=False, load_t5_feat=False +) +image_size = 1024 + +# model setting +model = 'PixArtMS_XL_2' +mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] +fp32_attention = True +load_from = None +resume_from = None +vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae +aspect_ratio_type = 'ASPECT_RATIO_1024' +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 2.0 + +# training setting +num_workers = 10 +train_batch_size = 2 # 3 for w.o feature extraction; 12 for feature extraction +num_epochs = 2 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) +lr_schedule_args = dict(num_warmup_steps=1000) + +eval_sampling_steps = 500 +visualize = True +log_interval = 20 +save_model_epochs = 1 +save_model_steps = 1000 +work_dir = 'output/debug' + +# pixart-sigma +scale_factor = 0.13025 +real_prompt_ratio = 0.5 +model_max_length = 300 +class_dropout_prob = 0.1 + +qk_norm = False +skip_step = 0 # skip steps during data loading diff --git a/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py b/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a7b6de3974f1d789ad7906df240a5e57041b62 --- /dev/null +++ b/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py @@ -0,0 +1,51 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json'] + +data = dict( + type='InternalDataMSSigma', 
root='InternData', image_list_json=image_list_json, transform='default_train', + load_vae_feat=False, load_t5_feat=False +) +image_size = 1024 + +# model setting +model = 'PixArtMS_XL_2' +mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] +fp32_attention = True +load_from = None +resume_from = None +vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae +aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 2.0 + +# training setting +num_workers = 10 +train_batch_size = 4 # 16 +num_epochs = 2 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) +lr_schedule_args = dict(num_warmup_steps=500) + +eval_sampling_steps = 250 +visualize = True +log_interval = 10 +save_model_epochs = 1 +save_model_steps = 1000 +work_dir = 'output/debug' + +# pixart-sigma +scale_factor = 0.13025 +real_prompt_ratio = 0.5 +model_max_length = 300 +class_dropout_prob = 0.1 +kv_compress = True +kv_compress_config = { + 'sampling': 'conv', # ['conv', 'uniform', 'ave'] + 'scale_factor': 2, + 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], +} +qk_norm = False +skip_step = 0 # skip steps during data loading diff --git a/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py b/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py new file mode 100644 index 0000000000000000000000000000000000000000..e3623722f222270c57828f9a83dcaaae4026fd01 --- /dev/null +++ b/configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py @@ -0,0 +1,52 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'pixart-sigma-toy-dataset' +image_list_json = ['data_info.json'] + +data = dict( + type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', + load_vae_feat=True, load_t5_feat=True, +) +image_size = 1024 + +# model setting +model = 'PixArtMS_XL_2' # model for multi-scale training +fp32_attention = False +load_from = None +resume_from = None +vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae +aspect_ratio_type = 'ASPECT_RATIO_1024' +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 2.0 + +# training setting +num_workers = 4 +train_batch_size = 12 # max 12 for PixArt-xL/2 when grad_checkpoint +num_epochs = 10 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='CAMEWrapper', lr=1e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) +lr_schedule_args = dict(num_warmup_steps=100) +save_model_epochs = 10 +save_model_steps = 1000 +valid_num = 0 # take as valid aspect-ratio when sample number >= valid_num + +log_interval = 10 +eval_sampling_steps = 5 +visualize = True +work_dir = 'output/debug' + +# pixart-sigma +scale_factor = 0.13025 +real_prompt_ratio = 0.5 +model_max_length = 300 +class_dropout_prob = 0.1 + +# LCM +loss_type = 'huber' +huber_c = 0.001 +num_ddim_timesteps = 50 +w_max = 15.0 +w_min = 3.0 +ema_decay = 0.95 +cfg_scale = 4.5 diff --git a/configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py b/configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..9a41144f3784c67df3f4cad11985417e93fe56bf --- /dev/null +++ 
b/configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py @@ -0,0 +1,41 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'pixart-sigma-toy-dataset' +image_list_json = ['data_info.json'] + +data = dict( + type='InternalDataSigma', root='InternData', image_list_json=image_list_json, transform='default_train', + load_vae_feat=False, load_t5_feat=False, +) +image_size = 256 + +# model setting +model = 'PixArt_XL_2' +mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] +fp32_attention = True +load_from = "output/pretrained_models/PixArt-Sigma-XL-2-256x256.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma +resume_from = None +vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae +multi_scale = False # if use multiscale dataset model training +pe_interpolation = 0.5 + +# training setting +num_workers = 10 +train_batch_size = 64 # 64 as default +num_epochs = 200 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) +lr_schedule_args = dict(num_warmup_steps=1000) + +eval_sampling_steps = 500 +log_interval = 20 +save_model_epochs = 5 +save_model_steps = 2500 +work_dir = 'output/debug' + +# pixart-sigma +scale_factor = 0.13025 +real_prompt_ratio = 0.5 +model_max_length = 300 +class_dropout_prob = 0.1 diff --git a/configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py b/configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py new file mode 100644 index 0000000000000000000000000000000000000000..0b719dbdb1c636039d3a0ef8559d34e5b42a140e --- /dev/null +++ b/configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py @@ -0,0 +1,49 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'data' +image_list_json = ['data_info.json'] + +data = dict( + type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', + load_vae_feat=False, load_t5_feat=False +) +image_size = 2048 + +# model setting +model = 'PixArtMS_XL_2' +mixed_precision = 'fp16' +fp32_attention = True +load_from = None +resume_from = None +vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae +aspect_ratio_type = 'ASPECT_RATIO_2048' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256] +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 4.0 + +# training setting +num_workers = 10 +train_batch_size = 4 # 48 +num_epochs = 10 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) +lr_schedule_args = dict(num_warmup_steps=100) + +eval_sampling_steps = 100 +visualize = True +log_interval = 10 +save_model_epochs = 10 +save_model_steps = 100 +work_dir = 'output/debug' + +# pixart-sigma +scale_factor = 0.13025 +real_prompt_ratio = 0.5 +model_max_length = 300 +class_dropout_prob = 0.1 +kv_compress = False +kv_compress_config = { + 'sampling': 'conv', + 'scale_factor': 2, + 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], +} diff --git a/configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py b/configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py new file mode 100644 index 0000000000000000000000000000000000000000..93f2e5a27fdaf1a84803b5b7ceca1c24e549cbce --- /dev/null +++ 
b/configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py @@ -0,0 +1,43 @@ +_base_ = ['../PixArt_xl2_internal.py'] +data_root = 'pixart-sigma-toy-dataset' +image_list_json = ['data_info.json'] + +data = dict( + type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train', + load_vae_feat=False, load_t5_feat=False, +) +image_size = 512 + +# model setting +model = 'PixArtMS_XL_2' +mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16'] +fp32_attention = True +load_from = "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma +resume_from = None +vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae +aspect_ratio_type = 'ASPECT_RATIO_512' +multi_scale = True # if use multiscale dataset model training +pe_interpolation = 1.0 + +# training setting +num_workers = 10 +train_batch_size = 2 # 48 as default +num_epochs = 10 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16)) +lr_schedule_args = dict(num_warmup_steps=1000) + +eval_sampling_steps = 500 +visualize = True +log_interval = 20 +save_model_epochs = 5 +save_model_steps = 2500 +work_dir = 'output/debug' + +# pixart-sigma +scale_factor = 0.13025 +real_prompt_ratio = 0.5 +model_max_length = 300 +class_dropout_prob = 0.1 diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54508d56cbff959d571efbc6c1a8041e0ace22a1 --- /dev/null +++ b/diffusion/__init__.py @@ -0,0 +1,8 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from .iddpm import IDDPM +from .dpm_solver import DPMS +from .sa_sampler import SASolverSampler diff --git a/diffusion/data/__init__.py b/diffusion/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bed3472a43d5e1b81d9400560c298e98b971a38 --- /dev/null +++ b/diffusion/data/__init__.py @@ -0,0 +1,2 @@ +from .datasets import * +from .transforms import get_transform diff --git a/diffusion/data/builder.py b/diffusion/data/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..2059a6645c16118e4fd39dd0c521ef351b5f76f8 --- /dev/null +++ b/diffusion/data/builder.py @@ -0,0 +1,50 @@ +import os +import time + +from mmcv import Registry, build_from_cfg +from torch.utils.data import DataLoader + +from diffusion.data.transforms import get_transform +from diffusion.utils.logger import get_root_logger + +DATASETS = Registry('datasets') + +DATA_ROOT = '/cache/data' + + +def set_data_root(data_root): + global DATA_ROOT + DATA_ROOT = data_root + + +def get_data_path(data_dir): + if os.path.isabs(data_dir): + return data_dir + global DATA_ROOT + return os.path.join(DATA_ROOT, data_dir) + + +def build_dataset(cfg, resolution=224, **kwargs): + logger = get_root_logger() + + dataset_type = cfg.get('type') + logger.info(f"Constructing dataset {dataset_type}...") + t = time.time() + transform = cfg.pop('transform', 'default_train') + transform = get_transform(transform, resolution) + dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, 
resolution=resolution, **kwargs)) + logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}") + return dataset + + +def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs): + if 'batch_sampler' in kwargs: + dataloader = DataLoader(dataset, batch_sampler=kwargs['batch_sampler'], num_workers=num_workers, pin_memory=True) + else: + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=True, + **kwargs) + return dataloader diff --git a/diffusion/data/datasets/InternalData.py b/diffusion/data/datasets/InternalData.py new file mode 100644 index 0000000000000000000000000000000000000000..3e287cccdfc190de819be94990366dd1e0242679 --- /dev/null +++ b/diffusion/data/datasets/InternalData.py @@ -0,0 +1,312 @@ +import os +import random +from PIL import Image +import numpy as np +import torch +from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS +from torch.utils.data import Dataset +from diffusers.utils.torch_utils import randn_tensor +from torchvision import transforms as T +from diffusion.data.builder import get_data_path, DATASETS +from diffusion.utils.logger import get_root_logger + +import json + +@DATASETS.register_module() +class InternalData(Dataset): + def __init__(self, + root, + image_list_json='data_info.json', + transform=None, + resolution=256, + sample_subset=None, + load_vae_feat=False, + input_size=32, + patch_size=2, + mask_ratio=0.0, + load_mask_index=False, + max_length=120, + config=None, + **kwargs): + self.root = get_data_path(root) + self.transform = transform + self.load_vae_feat = load_vae_feat + self.ori_imgs_nums = 0 + self.resolution = resolution + self.N = int(resolution // (input_size // patch_size)) + self.mask_ratio = mask_ratio + self.load_mask_index = load_mask_index + self.max_lenth = max_length + self.meta_data_clean = [] + self.img_samples = [] + self.txt_feat_samples = [] + self.vae_feat_samples = [] + self.mask_index_samples = [] + self.prompt_samples = [] + + image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] + for json_file in image_list_json: + meta_data = self.load_json(os.path.join(self.root, 'partition', json_file)) + self.ori_imgs_nums += len(meta_data) + meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] + self.meta_data_clean.extend(meta_data_clean) + self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) + self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) + self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_features_{resolution}resolution/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) + self.prompt_samples.extend([item['prompt'] for item in meta_data_clean]) + + # Set loader and extensions + if load_vae_feat: + self.transform = None + self.loader = self.vae_feat_loader + else: + self.loader = default_loader + + if sample_subset is not None: + self.sample_subset(sample_subset) # sample dataset for local debug + logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + logger.info(f"T5 max token length: {self.max_lenth}") + + def getdata(self, index): + img_path = self.img_samples[index] + 
npz_path = self.txt_feat_samples[index] + npy_path = self.vae_feat_samples[index] + prompt = self.prompt_samples[index] + data_info = { + 'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32), + 'aspect_ratio': torch.tensor(1.) + } + + img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path) + txt_info = np.load(npz_path) + txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096 + attention_mask = torch.ones(1, 1, txt_fea.shape[1]) # 1x1xT + if 'attention_mask' in txt_info.keys(): + attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] + if txt_fea.shape[1] != self.max_lenth: + txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1) + attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1) + + if self.transform: + img = self.transform(img) + + data_info['prompt'] = prompt + return img, txt_fea, attention_mask, data_info + + def __getitem__(self, idx): + for _ in range(20): + try: + return self.getdata(idx) + except Exception as e: + print(f"Error details: {str(e)}") + idx = np.random.randint(len(self)) + raise RuntimeError('Too many bad data.') + + def get_data_info(self, idx): + data_info = self.meta_data_clean[idx] + return {'height': data_info['height'], 'width': data_info['width']} + + @staticmethod + def vae_feat_loader(path): + # [mean, std] + mean, std = torch.from_numpy(np.load(path)).chunk(2) + sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) + return mean + std * sample + + def load_ori_img(self, img_path): + # Load the image and convert it to a tensor + transform = T.Compose([ + T.Resize(256), # Image.BICUBIC + T.CenterCrop(256), + T.ToTensor(), + ]) + return transform(Image.open(img_path)) + + def load_json(self, file_path): + with open(file_path, 'r') as f: + meta_data = json.load(f) + + return meta_data + + def sample_subset(self, ratio): + sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) + self.img_samples = [self.img_samples[i] for i in sampled_idx] + + def __len__(self): + return len(self.img_samples) + + def __getattr__(self, name): + if name == "set_epoch": + return lambda epoch: None + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + +@DATASETS.register_module() +class InternalDataSigma(Dataset): + def __init__(self, + root, + image_list_json='data_info.json', + transform=None, + resolution=256, + sample_subset=None, + load_vae_feat=False, + load_t5_feat=False, + input_size=32, + patch_size=2, + mask_ratio=0.0, + mask_type='null', + load_mask_index=False, + real_prompt_ratio=1.0, + max_length=300, + config=None, + **kwargs): + self.root = get_data_path(root) + self.transform = transform + self.load_vae_feat = load_vae_feat + self.load_t5_feat = load_t5_feat + self.ori_imgs_nums = 0 + self.resolution = resolution + self.N = int(resolution // (input_size // patch_size)) + self.mask_ratio = mask_ratio + self.load_mask_index = load_mask_index + self.mask_type = mask_type + self.real_prompt_ratio = real_prompt_ratio + self.max_lenth = max_length + self.meta_data_clean = [] + self.img_samples = [] + self.txt_samples = [] + self.sharegpt4v_txt_samples = [] + self.txt_feat_samples = [] + self.vae_feat_samples = [] + self.mask_index_samples = [] + self.gpt4v_txt_feat_samples = [] + self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32 + logger = get_root_logger() if config is None
else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + logger.info(f"T5 max token length: {self.max_lenth}") + logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}") + + image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] + for json_file in image_list_json: + meta_data = self.load_json(os.path.join(self.root, json_file)) + logger.info(f"{json_file} data volume: {len(meta_data)}") + self.ori_imgs_nums += len(meta_data) + meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5] + self.meta_data_clean.extend(meta_data_clean) + self.img_samples.extend([ + os.path.join(self.root.replace('InternData', 'InternImgs'), item['path']) for item in meta_data_clean + ]) + self.txt_samples.extend([item['prompt'] for item in meta_data_clean]) + self.sharegpt4v_txt_samples.extend([item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in meta_data_clean]) + self.txt_feat_samples.extend([ + os.path.join( + self.root, + 'caption_features_new', + item['path'].rsplit('/', 1)[-1].replace('.png', '.npz') + ) for item in meta_data_clean + ]) + self.gpt4v_txt_feat_samples.extend([ + os.path.join( + self.root, + 'sharegpt4v_caption_features_new', + item['path'].rsplit('/', 1)[-1].replace('.png', '.npz') + ) for item in meta_data_clean + ]) + self.vae_feat_samples.extend( + [ + os.path.join( + self.root, + f'img_sdxl_vae_features_{resolution}resolution_new', + item['path'].rsplit('/', 1)[-1].replace('.png', '.npy') + ) for item in meta_data_clean + ]) + + # Set loader and extensions + if load_vae_feat: + self.transform = None + self.loader = self.vae_feat_loader + else: + self.loader = default_loader + + if sample_subset is not None: + self.sample_subset(sample_subset) # sample dataset for local debug + + def getdata(self, index): + img_path = self.img_samples[index] + real_prompt = random.random() < self.real_prompt_ratio + npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index] + txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index] + npy_path = self.vae_feat_samples[index] + data_info = {'img_hw': torch.tensor([torch.tensor(self.resolution), torch.tensor(self.resolution)], dtype=torch.float32), + 'aspect_ratio': torch.tensor(1.)} + + if self.load_vae_feat: + img = self.loader(npy_path) + else: + img = self.loader(img_path) + + attention_mask = torch.ones(1, 1, self.max_lenth) # 1x1xT + if self.load_t5_feat: + txt_info = np.load(npz_path) + txt_fea = torch.from_numpy(txt_info['caption_feature']) # 1xTx4096 + if 'attention_mask' in txt_info.keys(): + attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] + if txt_fea.shape[1] != self.max_lenth: + txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1) + attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1) + else: + txt_fea = txt + + if self.transform: + img = self.transform(img) + + data_info["mask_type"] = self.mask_type + return img, txt_fea, attention_mask.to(torch.int16), data_info + + def __getitem__(self, idx): + for _ in range(20): + try: + data = self.getdata(idx) + return data + except Exception as e: + print(f"Error details {self.img_samples[idx]}: {str(e)}") + idx = np.random.randint(len(self)) + raise RuntimeError('Too many bad data.') + + def get_data_info(self, idx): + data_info = self.meta_data_clean[idx] + return {'height': data_info['height'], 'width': data_info['width']} + + 
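Taken together, the pieces above are glued by the DATASETS registry in diffusion/data/builder.py: the `data` dict in a config names a dataset class via its `type` string, build_dataset resolves the named transform and instantiates the class, and build_dataloader wraps the result. The following is a minimal usage sketch, not part of the patch itself; it assumes the toy-dataset layout implied by the img256 config ('pixart-sigma-toy-dataset/InternData/data_info.json' with images under 'InternImgs') and assumes that the training entry point, which is outside this section, calls set_data_root with the config's data_root.

# Illustrative sketch only (not part of this patch): build the img256 Sigma dataset via the registry.
import diffusion.data  # noqa: F401  (importing the package runs the @DATASETS.register_module() decorators)
from diffusion.data.builder import set_data_root, build_dataset, build_dataloader

set_data_root('pixart-sigma-toy-dataset')        # data_root in PixArt_sigma_xl2_img256_internal.py (assumed call site)
dataset = build_dataset(
    dict(
        type='InternalDataSigma',                # the multi-scale configs use 'InternalDataMSSigma' and add aspect_ratio_type
        root='InternData',
        image_list_json=['data_info.json'],
        transform='default_train',
        load_vae_feat=False,
        load_t5_feat=False,
        max_length=300,                          # model_max_length in the config
    ),
    resolution=256,
)
dataloader = build_dataloader(dataset, batch_size=64, num_workers=10)
# With load_t5_feat=False, txt_fea is the raw prompt string; img is Bx3x256x256 and attention_mask is Bx1x1x300.
img, txt_fea, attention_mask, data_info = next(iter(dataloader))

Keeping construction behind the registry is what lets the configs switch between InternalDataSigma and InternalDataMSSigma by changing only the `type` string, with the remaining keys passed through as keyword arguments.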
@staticmethod + def vae_feat_loader(path): + # [mean, std] + mean, std = torch.from_numpy(np.load(path)).chunk(2) + sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype) + return mean + std * sample + + def load_ori_img(self, img_path): + # Load the image and convert it to a tensor + transform = T.Compose([ + T.Resize(256), # Image.BICUBIC + T.CenterCrop(256), + T.ToTensor(), + ]) + img = transform(Image.open(img_path)) + return img + + def load_json(self, file_path): + with open(file_path, 'r') as f: + meta_data = json.load(f) + + return meta_data + + def sample_subset(self, ratio): + sampled_idx = random.sample(list(range(len(self))), int(len(self) * ratio)) + self.img_samples = [self.img_samples[i] for i in sampled_idx] + + def __len__(self): + return len(self.img_samples) + + def __getattr__(self, name): + if name == "set_epoch": + return lambda epoch: None + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + diff --git a/diffusion/data/datasets/InternalData_ms.py b/diffusion/data/datasets/InternalData_ms.py new file mode 100644 index 0000000000000000000000000000000000000000..ce0f6357591745861b228250c9f3ae5fc206f155 --- /dev/null +++ b/diffusion/data/datasets/InternalData_ms.py @@ -0,0 +1,336 @@ +import os +import numpy as np +import torch +import random +from torchvision.datasets.folder import default_loader +from diffusion.data.datasets.InternalData import InternalData, InternalDataSigma +from diffusion.data.builder import get_data_path, DATASETS +from diffusion.utils.logger import get_root_logger +import torchvision.transforms as T +from torchvision.transforms.functional import InterpolationMode +from diffusion.data.datasets.utils import * + +def get_closest_ratio(height: float, width: float, ratios: dict): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return ratios[closest_ratio], float(closest_ratio) + + +@DATASETS.register_module() +class InternalDataMS(InternalData): + def __init__(self, + root, + image_list_json='data_info.json', + transform=None, + resolution=256, + sample_subset=None, + load_vae_feat=False, + input_size=32, + patch_size=2, + mask_ratio=0.0, + mask_type='null', + load_mask_index=False, + real_prompt_ratio=1.0, + max_length=120, + config=None, + **kwargs): + self.root = get_data_path(root) + self.transform = transform + self.load_vae_feat = load_vae_feat + self.ori_imgs_nums = 0 + self.resolution = resolution + self.N = int(resolution // (input_size // patch_size)) + self.mask_ratio = mask_ratio + self.load_mask_index = load_mask_index + self.mask_type = mask_type + self.real_prompt_ratio = real_prompt_ratio + self.max_lenth = max_length + self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1]) + self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio + self.meta_data_clean = [] + self.img_samples = [] + self.txt_feat_samples = [] + self.vae_feat_samples = [] + self.mask_index_samples = [] + self.ratio_index = {} + self.ratio_nums = {} + # self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32 + for k, v in self.aspect_ratio.items(): + self.ratio_index[float(k)] = [] # used for self.getitem + self.ratio_nums[float(k)] = 0 # used for batch-sampler + + image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] + for json_file in image_list_json: + meta_data = self.load_json(os.path.join(self.root, json_file)) + self.ori_imgs_nums += len(meta_data) +
meta_data_clean = [item for item in meta_data if item['ratio'] <= 4] + self.meta_data_clean.extend(meta_data_clean) + self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean]) + self.txt_feat_samples.extend([os.path.join(self.root, 'caption_features', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean]) + self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean]) + + # Set loader and extensions + if load_vae_feat: + self.transform = None + self.loader = self.vae_feat_loader + else: + self.loader = default_loader + + if sample_subset is not None: + self.sample_subset(sample_subset) # sample dataset for local debug + + # scan the dataset for ratio static + for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]): + ori_h, ori_w = info['height'], info['width'] + closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) + self.ratio_nums[closest_ratio] += 1 + if len(self.ratio_index[closest_ratio]) == 0: + self.ratio_index[closest_ratio].append(i) + # print(self.ratio_nums) + logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + logger.info(f"T5 max token length: {self.max_lenth}") + + def getdata(self, index): + img_path = self.img_samples[index] + npz_path = self.txt_feat_samples[index] + npy_path = self.vae_feat_samples[index] + ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width'] + + # Calculate the closest aspect ratio and resize & crop image[w, h] + closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) + closest_size = list(map(lambda x: int(x), closest_size)) + self.closest_ratio = closest_ratio + + if self.load_vae_feat: + try: + img = self.loader(npy_path) + if index not in self.ratio_index[closest_ratio]: + self.ratio_index[closest_ratio].append(index) + except Exception: + index = random.choice(self.ratio_index[closest_ratio]) + return self.getdata(index) + h, w = (img.shape[1], img.shape[2]) + assert h, w == (ori_h//8, ori_w//8) + else: + img = self.loader(img_path) + h, w = (img.size[1], img.size[0]) + assert h, w == (ori_h, ori_w) + + data_info = {'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32)} + data_info['aspect_ratio'] = closest_ratio + data_info["mask_type"] = self.mask_type + + txt_info = np.load(npz_path) + txt_fea = torch.from_numpy(txt_info['caption_feature']) + attention_mask = torch.ones(1, 1, txt_fea.shape[1]) + if 'attention_mask' in txt_info.keys(): + attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] + + if not self.load_vae_feat: + if closest_size[0] / ori_h > closest_size[1] / ori_w: + resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h) + else: + resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1] + self.transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC + T.CenterCrop(closest_size), + T.ToTensor(), + T.Normalize([.5], [.5]), + ]) + + if self.transform: + img = self.transform(img) + + return img, txt_fea, attention_mask, data_info + + def __getitem__(self, idx): + for _ in range(20): + try: + return self.getdata(idx) + except Exception as e: + print(f"Error details: {str(e)}") + idx = 
random.choice(self.ratio_index[self.closest_ratio]) + raise RuntimeError('Too many bad data.') + + +@DATASETS.register_module() +class InternalDataMSSigma(InternalDataSigma): + def __init__(self, + root, + image_list_json='data_info.json', + transform=None, + resolution=256, + sample_subset=None, + load_vae_feat=False, + load_t5_feat=False, + input_size=32, + patch_size=2, + mask_ratio=0.0, + mask_type='null', + load_mask_index=False, + real_prompt_ratio=1.0, + max_length=300, + config=None, + **kwargs): + self.root = get_data_path(root) + self.transform = transform + self.load_vae_feat = load_vae_feat + self.load_t5_feat = load_t5_feat + self.ori_imgs_nums = 0 + self.resolution = resolution + self.N = int(resolution // (input_size // patch_size)) + self.mask_ratio = mask_ratio + self.load_mask_index = load_mask_index + self.mask_type = mask_type + self.real_prompt_ratio = real_prompt_ratio + self.max_lenth = max_length + self.base_size = int(kwargs['aspect_ratio_type'].split('_')[-1]) + self.aspect_ratio = eval(kwargs.pop('aspect_ratio_type')) # base aspect ratio + self.meta_data_clean = [] + self.img_samples = [] + self.txt_samples = [] + self.sharegpt4v_txt_samples = [] + self.txt_feat_samples = [] + self.vae_feat_samples = [] + self.mask_index_samples = [] + self.ratio_index = {} + self.ratio_nums = {} + self.gpt4v_txt_feat_samples = [] + self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32 + self.interpolate_model = InterpolationMode.BICUBIC + if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]: + self.interpolate_model = InterpolationMode.LANCZOS + suffix = '' + for k, v in self.aspect_ratio.items(): + self.ratio_index[float(k)] = [] # used for self.getitem + self.ratio_nums[float(k)] = 0 # used for batch-sampler + logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + logger.info(f"T5 max token length: {self.max_lenth}") + logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}") + + image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] + for json_file in image_list_json: + meta_data = self.load_json(os.path.join(self.root, json_file)) + logger.info(f"{json_file} data volume: {len(meta_data)}") + self.ori_imgs_nums += len(meta_data) + meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5] + self.meta_data_clean.extend(meta_data_clean) + self.img_samples.extend([ + os.path.join(self.root.replace('InternData'+suffix, 'InternImgs'), item['path']) for item in meta_data_clean + ]) + self.txt_samples.extend([item['prompt'] for item in meta_data_clean]) + self.sharegpt4v_txt_samples.extend([item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in meta_data_clean]) + self.txt_feat_samples.extend([ + os.path.join( + self.root, + 'caption_features_new', + '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz') + ) for item in meta_data_clean + ]) + self.gpt4v_txt_feat_samples.extend([ + os.path.join( + self.root, + 'sharegpt4v_caption_features_new', + '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz') + ) for item in meta_data_clean + ]) + self.vae_feat_samples.extend( + [ + os.path.join( + self.root + suffix, + f'img_sdxl_vae_features_{resolution}resolution_ms_new', + '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy') + ) for item in meta_data_clean + ]) + + if self.real_prompt_ratio < 1: + assert len(self.sharegpt4v_txt_samples[0]) != 0 + + # Set loader and extensions + if load_vae_feat: + 
self.transform = None + self.loader = self.vae_feat_loader + else: + self.loader = default_loader + + if sample_subset is not None: + self.sample_subset(sample_subset) # sample dataset for local debug + + # scan the dataset for ratio static + for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]): + ori_h, ori_w = info['height'], info['width'] + closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) + self.ratio_nums[closest_ratio] += 1 + if len(self.ratio_index[closest_ratio]) == 0: + self.ratio_index[closest_ratio].append(i) + + def getdata(self, index): + img_path = self.img_samples[index] + real_prompt = random.random() < self.real_prompt_ratio + npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index] + txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index] + npy_path = self.vae_feat_samples[index] + data_info = {} + ori_h, ori_w = self.meta_data_clean[index]['height'], self.meta_data_clean[index]['width'] + + # Calculate the closest aspect ratio and resize & crop image[w, h] + closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) + closest_size = list(map(lambda x: int(x), closest_size)) + self.closest_ratio = closest_ratio + + if self.load_vae_feat: + img = self.loader(npy_path) + if index not in self.ratio_index[closest_ratio]: + self.ratio_index[closest_ratio].append(index) + h, w = (img.shape[1], img.shape[2]) + assert h, w == (ori_h//8, ori_w//8) + else: + img = self.loader(img_path) + h, w = (img.size[1], img.size[0]) + assert h, w == (ori_h, ori_w) + + data_info['img_hw'] = torch.tensor([ori_h, ori_w], dtype=torch.float32) + data_info['aspect_ratio'] = closest_ratio + data_info["mask_type"] = self.mask_type + + attention_mask = torch.ones(1, 1, self.max_lenth) + if self.load_t5_feat: + txt_info = np.load(npz_path) + txt_fea = torch.from_numpy(txt_info['caption_feature']) + if 'attention_mask' in txt_info.keys(): + attention_mask = torch.from_numpy(txt_info['attention_mask'])[None] + if txt_fea.shape[1] != self.max_lenth: + txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1).to(self.weight_dtype) + attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1) + else: + txt_fea = txt + + if not self.load_vae_feat: + if closest_size[0] / ori_h > closest_size[1] / ori_w: + resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h) + else: + resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1] + self.transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.Resize(resize_size, interpolation=self.interpolate_model), # Image.BICUBIC + T.CenterCrop(closest_size), + T.ToTensor(), + T.Normalize([.5], [.5]), + ]) + + if self.transform: + img = self.transform(img) + + return img, txt_fea, attention_mask.to(torch.int16), data_info + + def __getitem__(self, idx): + for _ in range(20): + try: + data = self.getdata(idx) + return data + except Exception as e: + print(f"Error details: {str(e)}") + idx = random.choice(self.ratio_index[self.closest_ratio]) + raise RuntimeError('Too many bad data.') + diff --git a/diffusion/data/datasets/__init__.py b/diffusion/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25573d1e9f8ba08717f6068d5d402d0429357fc7 --- /dev/null +++ b/diffusion/data/datasets/__init__.py @@ -0,0 +1,3 @@ +from .InternalData import InternalData, 
InternalDataSigma +from .InternalData_ms import InternalDataMS, InternalDataSigma +from .utils import * diff --git a/diffusion/data/datasets/utils.py b/diffusion/data/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..27fae41e9e759e446b3f5f1fad13abd0520aa396 --- /dev/null +++ b/diffusion/data/datasets/utils.py @@ -0,0 +1,134 @@ + +ASPECT_RATIO_2880 = { + '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0], + '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0], + '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0], + '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0], + '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0], + '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0], + '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0], + '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0], + '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0], + '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0] +} + +ASPECT_RATIO_2048 = { + '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0], + '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], + '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], + '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], + '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], + '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], + '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], + '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], + '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0], + '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0] +} + +ASPECT_RATIO_1024 = { + '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], + '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], + '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], + '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], + '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], + '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], + '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], + '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], + '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], + '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], +} + +ASPECT_RATIO_512 = { + '0.25': [256.0, 1024.0], '0.26': [256.0, 
992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], + '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] + } + +ASPECT_RATIO_256 = { + '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0], + '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], + '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], + '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], + '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], + '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], + '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], + '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], + '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0], + '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0] +} + +ASPECT_RATIO_256_TEST = { + '0.25': [128.0, 512.0], '0.28': [128.0, 464.0], + '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], + '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], + '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], + '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], + '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], + '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], + '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], + '2.5': [400.0, 160.0], '3.0': [432.0, 144.0], + '4.0': [512.0, 128.0] +} + +ASPECT_RATIO_512_TEST = { + '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], 
'3.0': [864.0, 288.0], + '4.0': [1024.0, 256.0] + } + +ASPECT_RATIO_1024_TEST = { + '0.25': [512., 2048.], '0.28': [512., 1856.], + '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], + '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], + '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], + '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], + '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], + '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], + '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], + '2.5': [1600., 640.], '3.0': [1728., 576.], + '4.0': [2048., 512.], +} + +ASPECT_RATIO_2048_TEST = { + '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], + '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], + '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], + '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], + '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], + '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], + '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], + '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], + '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0], + '4.0': [4096.0, 1024.0] +} + +ASPECT_RATIO_2880_TEST = { + '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0], + '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0], + '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0], + '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0], + '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0], + '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0], + '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0], + '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0], + '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0], + '4.0': [8192.0, 2048.0], +} + +def get_chunks(lst, n): + for i in range(0, len(lst), n): + yield lst[i:i + n] diff --git a/diffusion/data/transforms.py b/diffusion/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..fe35a149436b030c71c6f14830ae08e008e53177 --- /dev/null +++ b/diffusion/data/transforms.py @@ -0,0 +1,30 @@ +import torchvision.transforms as T + +TRANSFORMS = dict() + + +def register_transform(transform): + name = transform.__name__ + if name in TRANSFORMS: + raise RuntimeError(f'Transform {name} has already registered.') + TRANSFORMS.update({name: transform}) + + +def get_transform(type, resolution): + transform = TRANSFORMS[type](resolution) + transform = T.Compose(transform) + transform.image_size = resolution + return transform + + +@register_transform +def default_train(n_px): + transform = [ + T.Lambda(lambda img: img.convert('RGB')), + T.Resize(n_px), # 
Image.BICUBIC + T.CenterCrop(n_px), + # T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize([.5], [.5]), + ] + return transform diff --git a/diffusion/dpm_solver.py b/diffusion/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..232449c276973b75c7c9c93f8904bf017a42ac39 --- /dev/null +++ b/diffusion/dpm_solver.py @@ -0,0 +1,36 @@ +import torch +from .model import gaussian_diffusion as gd +from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP + + +def DPMS( + model, + condition, + uncondition, + cfg_scale, + model_type='noise', # or "x_start" or "v" or "score" + noise_schedule="linear", + guidance_type='classifier-free', + model_kwargs={}, + diffusion_steps=1000 +): + betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) + + ## 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas) + + ## 2. Convert your discrete-time `model` to the continuous-time + ## noise prediction model. Here is an example for a diffusion model + ## `model` with the noise prediction type ("noise") . + model_fn = model_wrapper( + model, + noise_schedule, + model_type=model_type, + model_kwargs=model_kwargs, + guidance_type=guidance_type, + condition=condition, + unconditional_condition=uncondition, + guidance_scale=cfg_scale, + ) + ## 3. Define dpm-solver and sample by multistep DPM-Solver. + return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") \ No newline at end of file diff --git a/diffusion/iddpm.py b/diffusion/iddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..c9459f4c807d2318a699392d51a86bc10bbe318f --- /dev/null +++ b/diffusion/iddpm.py @@ -0,0 +1,53 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +from diffusion.model.respace import SpacedDiffusion, space_timesteps +from .model import gaussian_diffusion as gd + + +def IDDPM( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + pred_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + snr=False, + return_startx=False, +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + (( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ) + if pred_sigma + else None + ), + loss_type=loss_type, + snr=snr, + return_startx=return_startx, + # rescale_timesteps=rescale_timesteps, + ) \ No newline at end of file diff --git a/diffusion/lcm_scheduler.py b/diffusion/lcm_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..9d69dcedcc33ec494b3b152b8922c2cb3f976bc9 --- /dev/null +++ b/diffusion/lcm_scheduler.py @@ -0,0 
+1,459 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers import ConfigMixin, SchedulerMixin +from diffusers.configuration_utils import register_to_config +from diffusers.utils import BaseOutput + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class LCMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + denoised: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class LCMScheduler(SchedulerMixin, ConfigMixin): + """ + `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. 
+ prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + # _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. 
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. 
+ """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # LCM Timesteps Setting: # Linear Spacing + c = self.config.num_train_timesteps // lcm_origin_steps + lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device) + + def get_scalings_for_boundary_condition_discrete(self, t): + self.sigma_data = 0.5 # Default: 0.5 + + # By dividing 0.1: This is almost a delta function at t=0. + c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2) + c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5) + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timeindex: int, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[LCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # 1. get previous step value + prev_timeindex = timeindex + 1 + if prev_timeindex < len(self.timesteps): + prev_timestep = self.timesteps[prev_timeindex] + else: + prev_timestep = timestep + + # 2. 
compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Different Parameterization: + parameterization = self.config.prediction_type + + if parameterization == "epsilon": # noise-prediction + pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + + elif parameterization == "sample": # x-prediction + pred_x0 = model_output + + elif parameterization == "v_prediction": # v-prediction + pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + + # 4. Denoise model output using boundary conditions + denoised = c_out * pred_x0 + c_skip * sample + + # 5. Sample z ~ N(0, I), For MultiStep Inference + # Noise is not used for one-step sampling. + if len(self.timesteps) > 1: + noise = torch.randn(model_output.shape).to(model_output.device) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + if not return_dict: + return (prev_sample, denoised) + + return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + diff --git a/diffusion/model/__init__.py 
b/diffusion/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0d2755ad3a52d9d304763f96bf8b8c13dbbe76 --- /dev/null +++ b/diffusion/model/__init__.py @@ -0,0 +1 @@ +from .nets import * diff --git a/diffusion/model/builder.py b/diffusion/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..22821d03ef3410325885ad89a289c8ba1c032bac --- /dev/null +++ b/diffusion/model/builder.py @@ -0,0 +1,14 @@ +from mmcv import Registry + +from diffusion.model.utils import set_grad_checkpoint + +MODELS = Registry('models') + + +def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs): + if isinstance(cfg, str): + cfg = dict(type=cfg) + model = MODELS.build(cfg, default_args=kwargs) + if use_grad_checkpoint: + set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step) + return model diff --git a/diffusion/model/diffusion_utils.py b/diffusion/model/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cedd4fa2433f32c34df1157839b423ecb444e403 --- /dev/null +++ b/diffusion/model/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
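The `builder.py` file above exposes an mmcv-style registry. A minimal usage sketch follows; `ToyDenoiser` and its `hidden` argument are hypothetical names used only for illustration, and the registration/build calls follow the standard mmcv `Registry` pattern that `build_model` relies on.

import torch.nn as nn
from diffusion.model.builder import MODELS, build_model

@MODELS.register_module()          # makes the class buildable by its name
class ToyDenoiser(nn.Module):      # hypothetical model, illustration only
    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.net(x)

# `build_model` accepts a config dict or just the registered name; extra
# keyword arguments are forwarded to the constructor via `default_args`.
model = build_model('ToyDenoiser', hidden=128)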
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/diffusion/model/dpm_solver.py b/diffusion/model/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..6381eda3e747c7acbc040601d3db027c9208031a --- /dev/null +++ b/diffusion/model/dpm_solver.py @@ -0,0 +1,1337 @@ +import torch +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. 
The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1. + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1. + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. 
* log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). 
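The conversions listed above all reduce to noise prediction. A self-contained consistency check on toy tensors, assuming the VP relation alpha_t^2 + sigma_t^2 = 1 (which holds for the schedules wrapped by NoiseScheduleVP):

import torch

torch.manual_seed(0)
x0, eps = torch.randn(2, 4), torch.randn(2, 4)   # toy data and noise
alpha_t = torch.tensor(0.8)
sigma_t = torch.sqrt(1.0 - alpha_t ** 2)         # VP: alpha^2 + sigma^2 = 1
x_t = alpha_t * x0 + sigma_t * eps               # noised sample

# "x_start" model outputs x0             ->  noise = (x_t - alpha_t * x0) / sigma_t
eps_from_x0 = (x_t - alpha_t * x0) / sigma_t
# "v" model outputs alpha*eps - sigma*x0 ->  noise = alpha_t * v + sigma_t * x_t
v = alpha_t * eps - sigma_t * x0
eps_from_v = alpha_t * v + sigma_t * x_t
# "score" model outputs -eps / sigma_t   ->  noise = -sigma_t * score
score = -eps / sigma_t
eps_from_score = -sigma_t * score

for e in (eps_from_x0, eps_from_v, eps_from_score):
    assert torch.allclose(e, eps, atol=1e-5)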
+ + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. 
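A quick check of `get_model_input_time` above for a discrete-time model with an assumed N = 1000 training levels: the continuous range [1/N, 1] maps back onto the familiar integer labels 0..999.

import torch

total_N = 1000  # assumed number of discrete training steps

def get_model_input_time(t_continuous, total_N=total_N):
    # Same conversion as in model_wrapper for schedule == 'discrete'.
    return (t_continuous - 1.0 / total_N) * 1000.0

t = torch.tensor([1.0 / total_N, 0.5, 1.0])
print(get_model_input_time(t))  # ~tensor([  0., 499., 999.])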
+ """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. 
+ + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. 
/ t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. 
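The order-allocation rule spelled out in the docstring above fits in a few lines. This helper is a sketch for illustration (the class computes the same lists inline), with two worked examples:

def dpm_solver_fast_orders(steps, order):
    # Spend `steps` NFE on as many order-3 (or order-2) sub-steps as possible,
    # finishing with lower-order steps so the total NFE is exactly `steps`.
    if order == 3:
        K = steps // 3 + 1
        return {0: [3] * (K - 2) + [2, 1],
                1: [3] * (K - 1) + [1],
                2: [3] * (K - 1) + [2]}[steps % 3]
    if order == 2:
        return [2] * (steps // 2) if steps % 2 == 0 else [2] * (steps // 2) + [1]
    return [1] * steps

print(dpm_solver_fast_orders(20, 3))  # [3, 3, 3, 3, 3, 3, 2]  -> 6*3 + 2 = 20 NFE
print(dpm_solver_fast_orders(10, 2))  # [2, 2, 2, 2, 2]        -> 5*2  = 10 NFE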
+ return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. 
/ r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. 
+ phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
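The phi_1/phi_2/phi_3 terms above can be read as h-scaled exponential-integrator coefficients (that reading is an interpretation, not stated in the code); their small-h limits are h, h/2 and h/6, which gives a quick way to sanity-check the recursions:

import torch

h = torch.tensor(1e-3, dtype=torch.float64)

# Noise-prediction branch ("dpmsolver"):
phi_1 = torch.expm1(h)
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
print(phi_1 / h, phi_2 / h, phi_3 / h)  # ~1.0005, ~0.50017, ~0.16671

# The data-prediction branch ("dpmsolver++") uses expm1(-h); the analogous
# limits hold there, with sign flips on phi_1 and phi_3.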
+ """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. 
+ phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. 
The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. 
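`dpm_solver_adaptive` above uses a standard embedded-pair step-size controller: compare a lower- and a higher-order estimate, accept the step when the tolerance-scaled error is at most 1, and rescale the step size by E^(-1/order). Below is a generic sketch of just that accept/rescale logic, detached from the solver (tolerances are the defaults used above); note the real loop additionally caps h so it never steps past lambda_0.

import torch

def accept_and_next_h(x_lower, x_higher, x_prev, h, order,
                      atol=0.0078, rtol=0.05, theta=0.9):
    # Mixed absolute/relative tolerance and per-sample RMS error, as above.
    delta = torch.maximum(atol * torch.ones_like(x_lower),
                          rtol * torch.maximum(x_lower.abs(), x_prev.abs()))
    err = ((x_higher - x_lower) / delta).reshape(x_lower.shape[0], -1)
    E = err.pow(2).mean(dim=-1).sqrt().max()
    accept = bool(E <= 1.0)
    h_next = theta * h * float(E) ** (-1.0 / order)  # shrink if E > 1, grow if E < 1
    return accept, h_next

x_prev = torch.randn(2, 3, 8, 8)
x_lower = x_prev + 1e-3 * torch.randn_like(x_prev)
x_higher = x_lower + 1e-4 * torch.randn_like(x_lower)
print(accept_and_next_h(x_lower, x_higher, x_prev, h=0.05, order=3))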
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. 
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. 
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', + 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', + 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, + solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1)): + t = timesteps[step] + # We only use lower order for steps < 10 + # if lower_order_final and steps < 10: + if lower_order_final: # recommended by Shuchen Xue + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, + solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. 
+ if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, + order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, + device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. 
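+
+ Example (an illustrative sketch only; the keypoint values below are made up):
+ >>> xp = torch.tensor([[0.0, 1.0, 2.0]]) # [C=1, K=3] keypoints on the x-axis
+ >>> yp = torch.tensor([[0.0, 10.0, 20.0]]) # [C=1, K=3] corresponding y values
+ >>> x = torch.tensor([[0.5], [1.5], [3.0]]) # [N=3, C=1] queries; 3.0 lies beyond xp and is extrapolated
+ >>> interpolate_fn(x, xp, yp) # -> approximately [[5.0], [15.0], [30.0]]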
+ """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/diffusion/model/edm_sample.py b/diffusion/model/edm_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f930bd3be0e533f3829d83813479763229941881 --- /dev/null +++ b/diffusion/model/edm_sample.py @@ -0,0 +1,171 @@ +import random +import numpy as np +from tqdm import tqdm + +from diffusion.model.utils import * + + +# ---------------------------------------------------------------------------- +# Proposed EDM sampler (Algorithm 2). + +def edm_sampler( + net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs +): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( + sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. 
+ if i < num_steps - 1: + denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next + + +# ---------------------------------------------------------------------------- +# Generalized ablation sampler, representing the superset of all sampling +# methods discussed in the paper. + +def ablation_sampler( + net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=None, sigma_max=None, rho=7, + solver='heun', discretization='edm', schedule='linear', scaling='none', + epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, +): + assert solver in ['euler', 'heun'] + assert discretization in ['vp', 've', 'iddpm', 'edm'] + assert schedule in ['vp', 've', 'linear'] + assert scaling in ['vp', 'none'] + + # Helper functions for VP & VE noise level schedules. + vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 + vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) + vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * ( + sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma ** 2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) + sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) + sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) + vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == 'vp': + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == 've': + orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == 'iddpm': + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] + else: + assert discretization == 'edm' + sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( + sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + + # Define noise level schedule. 
+ if schedule == 'vp': + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == 've': + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + assert schedule == 'linear' + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == 'vp': + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + assert scaling == 'none' + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s( + t_hat) * S_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to( + torch.float64) + d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s( + t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == 'euler' or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + assert solver == 'heun' + denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to( + torch.float64) + d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv( + t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) + + return x_next diff --git a/diffusion/model/gaussian_diffusion.py b/diffusion/model/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6226d26c90d5d611e9e3834ee9d2fa672735ec --- /dev/null +++ b/diffusion/model/gaussian_diffusion.py @@ -0,0 +1,1041 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import enum +import math + +import numpy as np +import torch as th +import torch.nn.functional as F + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. 
+ """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
+ scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + snr=False, + return_startx=False, + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.snr = snr + self.return_startx = return_startx + + # Use float64 for accuracy. 
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + + if False: + target_resolution = 128 # 1024:128; 512:64; 256:32; + reference_resolution = 64 # Reference resolution (e.g., 64x64) + scaling_factor = (target_resolution / reference_resolution) ** 2 + print('scaling_factor', scaling_factor) + + # Adjust alphas and betas according to the scaling factor + alpha_cumprod_snr_shift = self.alphas_cumprod / (scaling_factor * (1 - self.alphas_cumprod) + self.alphas_cumprod) + alpha_cuspord_rmove1 = np.concatenate([np.ones([1]), alpha_cumprod_snr_shift[:999]]) + alpha_snr_shift = alpha_cumprod_snr_shift / alpha_cuspord_rmove1 + + betas_snr_shift = 1 - alpha_snr_shift + + # Update the class attributes with adjusted values + snr_ref = (self.alphas_cumprod / (1 - self.alphas_cumprod)) + snr_cur = (alpha_cumprod_snr_shift / (1 - alpha_cumprod_snr_shift)) + + self.betas = betas_snr_shift + self.alphas_cumprod = np.cumprod(alpha_snr_shift, axis=0) + + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
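+
+ In closed form this returns (per batch element, with t indexing the noise schedule):
+ x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise
+ Example (illustrative; assumes `diffusion` is an instance of this class):
+ >>> t = th.full((x_start.shape[0],), 10, dtype=th.long, device=x_start.device)
+ >>> x_t = diffusion.q_sample(x_start, t)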
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + else: + model_variance = th.zeros_like(model_output) + model_log_variance = th.zeros_like(model_output) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. 
+ :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). 
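+ With `eta=0.0` (the default) the update is the deterministic DDIM step; `eta=1.0`
+ recovers the ancestral (DDPM-style) sampling variance.
+ Example (illustrative; assumes `diffusion`, `model`, `x`, and `t` already exist):
+ >>> out = diffusion.ddim_sample(model, x, t, eta=0.0)
+ >>> x_prev = out["sample"]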
+ """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
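+ Example (illustrative; assumes `diffusion`, `model`, and a batch `x_start` already exist):
+ >>> t = th.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device)
+ >>> loss = diffusion.training_losses(model, x_start, t)["loss"].mean()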
+ """ + t = timestep + if model_kwargs is None: + model_kwargs = {} + if skip_noise: + x_t = x_start + else: + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + if isinstance(model_output, dict) and model_output.get('x', None) is not None: + output = model_output['x'] + else: + output = model_output + + if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output = th.split(output, C, dim=1)[0] + return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output, model_var_values = th.split(output, C, dim=1) + # Learn the variance using the variational bound, but don't let it affect our mean prediction. + frozen_out = th.cat([output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out, **kwargs: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. 
+ terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert output.shape == target.shape == x_start.shape + if self.snr: + if self.model_mean_type == ModelMeanType.START_X: + pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output) + pred_startx = output + elif self.model_mean_type == ModelMeanType.EPSILON: + pred_noise = output + pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output) + # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2) + # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2) + + t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32] + # best + target = th.where(t > 249, noise, x_start) + output = th.where(t > 249, pred_noise, pred_startx) + loss = (target - output) ** 2 + if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0: + assert 'mask' in model_output + loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1) + mask = model_output['mask'] + unmask = 1 - mask + terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1) + if model_kwargs['mask_loss_coef'] > 0: + terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1) + else: + terms["mse"] = mean_flat(loss) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + if "mae" in terms: + terms["loss"] = terms["loss"] + terms["mae"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
+ """ + t = timestep + if model_kwargs is None: + model_kwargs = {} + if skip_noise: + x_t = x_start + else: + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0] + + if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output = th.split(output, C, dim=1)[0] + return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output, model_var_values = th.split(output, C, dim=1) + # Learn the variance using the variational bound, but don't let it affect our mean prediction. + frozen_out = th.cat([output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out, **kwargs: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert output.shape == target.shape == x_start.shape + if self.snr: + if self.model_mean_type == ModelMeanType.START_X: + pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output) + pred_startx = output + elif self.model_mean_type == ModelMeanType.EPSILON: + pred_noise = output + pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output) + # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2) + # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2) + + t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32] + # best + target = th.where(t > 249, noise, x_start) + output = th.where(t > 249, pred_noise, pred_startx) + loss = (target - output) ** 2 + terms["mse"] = mean_flat(loss) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + if "mae" in terms: + terms["loss"] = terms["loss"] + terms["mae"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. 
+ """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/diffusion/model/llava/__init__.py b/diffusion/model/llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17f8019ecb9ba08eac7a70690c3450a950449a3c --- /dev/null +++ b/diffusion/model/llava/__init__.py @@ -0,0 +1 @@ +from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig \ No newline at end of file diff --git a/diffusion/model/llava/llava_mpt.py b/diffusion/model/llava/llava_mpt.py new file mode 100644 index 0000000000000000000000000000000000000000..c343585a0e77e2517f85df3ec0c3a9ec906deb21 --- /dev/null +++ b/diffusion/model/llava/llava_mpt.py @@ -0,0 +1,280 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +import math + +from transformers import AutoConfig, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor + +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from diffusion.model.llava.mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel + + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + + +class LlavaMPTConfig(MPTConfig): + model_type = "llava_mpt" + + +class LlavaMPTModel(MPTModel): + config_class = LlavaMPTConfig + + def __init__(self, config: MPTConfig, mm_vision_tower=None, mm_hidden_size=None): + super(LlavaMPTModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + # HACK: for FSDP + self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)] + # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) + + if hasattr(config, "use_mm_proj"): + self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model) + + def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False): + self.config.mm_vision_tower = vision_tower + + image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + if not hasattr(self, 'vision_tower'): + vision_tower = CLIPVisionModel.from_pretrained(vision_tower) + else: + vision_tower = self.vision_tower[0] + vision_tower.requires_grad_(False) + vision_tower = vision_tower.to(torch.float16) + self.vision_tower = [vision_tower] + + vision_config = vision_tower.config + num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 + + self.config.use_mm_proj = True + self.config.mm_hidden_size = vision_config.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + + if not hasattr(self, 'mm_projector'): + self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.d_model) + + 
if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items() if 'mm_projector' in k}) + + return dict( + image_processor=image_processor, + image_token_len=num_patches, + vision_config=vision_config + ) + + def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): + + # HACK: replace back original embeddings for LLaVA pretraining + orig_embeds_params = getattr(self, 'orig_embeds_params', None) + # if orig_embeds_params is not None: + # orig_embeds_params = orig_embeds_params[0] + # with torch.no_grad(): + # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data + + inputs_embeds = self.wte(input_ids) + + vision_tower = getattr(self, 'vision_tower', None) + if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: + # TODO: this is a modified multimodal LLM -- Haotian Liu + vision_tower = vision_tower[0] # HACK: for FSDP + with torch.no_grad(): + if type(images) is list: + # variable length images + image_features = [] + for image in images: + image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True) + select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) + select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer] + image_feature = select_hidden_state[:, 1:] + image_features.append(image_feature) + else: + image_forward_outs = vision_tower(images, output_hidden_states=True) + select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) + select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer] + image_features = select_hidden_state[:, 1:] + if type(images) is list: + image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features] + else: + image_features = self.mm_projector(image_features) + dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = self.mm_projector(dummy_image_features) + + new_input_embeds = [] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() + new_input_embeds.append(cur_input_embeds) + continue + if vision_tower.config.use_im_start_end: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum(): + raise ValueError("The number of image start tokens and image end tokens should be the same.") + image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0] + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device) + num_patches = cur_image_features.shape[0] + if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token: + raise ValueError("The image end token should follow the image start token.") + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0) + else: + cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) + cur_image_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches: + raise ValueError("The number of image patch tokens should be the same as the number of image patches.") + masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0] + mask_index_start = masked_indices[0] + if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): + raise ValueError("The image patch tokens should be consecutive.") + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0) + else: + cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0) + new_input_embeds.append(cur_new_input_embeds) + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(LlavaMPTModel, self).forward(input_ids=None, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, tok_emb=inputs_embeds) + + +class LlavaMPTForCausalLM(MPTForCausalLM): + config_class = LlavaMPTConfig + supports_gradient_checkpointing = True + + def __init__(self, config): + super(MPTForCausalLM, self).__init__(config) + + if not config.tie_word_embeddings: + raise ValueError('MPTForCausalLM only supports tied word embeddings') + self.transformer = LlavaMPTModel(config) + self.logit_scale = None + if config.logit_scale is not None: + logit_scale = config.logit_scale + if isinstance(logit_scale, str): + if logit_scale == 'inv_sqrt_d_model': + logit_scale = 1 / math.sqrt(config.d_model) + else: + raise 
ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") + self.logit_scale = logit_scale + + def get_model(self): + return self.transformer + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlavaMPTModel): + module.gradient_checkpointing = value + + def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): + return_dict = return_dict if return_dict is not None else self.config.return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, images=images) + logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight) + if self.logit_scale is not None: + if self.logit_scale == 0: + warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') + logits *= self.logit_scale + loss = None + if labels is not None: + labels = torch.roll(labels, shifts=-1) + labels[:, -1] = -100 + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) + return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + if inputs_embeds is not None: + raise NotImplementedError('inputs_embeds is not implemented for MPT yet') + attention_mask = kwargs['attention_mask'].bool() + if attention_mask[:, -1].sum() != attention_mask.shape[0]: + raise NotImplementedError('MPT does not support generation with right padding.') + if self.transformer.attn_uses_sequence_id and self.training: + sequence_id = torch.zeros_like(input_ids[:1]) + else: + sequence_id = None + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if self.transformer.prefix_lm: + prefix_mask = torch.ones_like(attention_mask) + if kwargs.get('use_cache') == False: + raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') + else: + prefix_mask = None + return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} + + def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device, + tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None): + vision_config = self.get_model().vision_tower[0].config + vision_config.use_im_start_end = mm_use_im_start_end + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + 
self.resize_token_embeddings(len(tokenizer)) + vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if tune_mm_mlp_adapter: + self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['transformer.wte.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.") + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + +AutoConfig.register("llava_mpt", LlavaMPTConfig) +AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) diff --git a/diffusion/model/llava/mpt/attention.py b/diffusion/model/llava/mpt/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca1069cd14ca055d918fa623d7da5efb4c5fd89 --- /dev/null +++ b/diffusion/model/llava/mpt/attention.py @@ -0,0 +1,276 @@ +"""Attention layers.""" +import math +import warnings +from typing import Optional +import torch +import torch.nn as nn +from einops import rearrange +from torch import nn +from .norm import LPLayerNorm + +def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): + if original_is_causal and num_query_tokens != num_key_tokens: + if num_query_tokens != 1: + raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') + else: + return False + return original_is_causal + +def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): + q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads) + v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads) + min_val = torch.finfo(q.dtype).min + (b, _, s_q, d) = q.shape + s_k = k.size(-1) + if softmax_scale is None: + softmax_scale = 1 / math.sqrt(d) + attn_weight = q.matmul(k) * softmax_scale + if attn_bias is not None: + if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): + raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.') + 
attn_weight = attn_weight + attn_bias + if key_padding_mask is not None: + if attn_bias is not None: + warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) + if is_causal: + s = max(s_q, s_k) + causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) + causal_mask = causal_mask.tril() + causal_mask = causal_mask.to(torch.bool) + causal_mask = ~causal_mask + causal_mask = causal_mask[-s_q:, -s_k:] + attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p: + attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True) + out = attn_weight.matmul(v) + out = rearrange(out, 'b h s d -> b s (h d)') + if needs_weights: + return (out, attn_weight) + return (out, None) + +def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): + for tensor in tensors: + if tensor.dtype not in valid_dtypes: + raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') + if not tensor.is_cuda: + raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') + +def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): + try: + from flash_attn import bert_padding, flash_attn_interface + except: + raise RuntimeError('Please install flash-attn==1.0.3.post0') + check_valid_inputs(query, key, value) + if attn_bias is not None: + raise NotImplementedError(f'attn_bias not implemented for flash attn.') + (batch_size, seqlen) = query.shape[:2] + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1):] + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) + query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) + key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) + value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + if multiquery: + key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) + value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) + dropout_p = dropout_p if training else 0.0 + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) + output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) + return (output, None) + +def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, 
needs_weights=False, multiquery=False): + try: + from flash_attn import flash_attn_triton + except: + raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202') + check_valid_inputs(query, key, value) + if dropout_p: + raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.') + if needs_weights: + raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') + if key_padding_mask is not None: + warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + (b_size, s_k) = key_padding_mask.shape[:2] + if attn_bias is None: + attn_bias = query.new_zeros(b_size, 1, 1, s_k) + attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) + query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) + key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) + value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) + if multiquery: + key = key.expand(*key.shape[:2], n_heads, key.size(-1)) + value = value.expand(*value.shape[:2], n_heads, value.size(-1)) + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) + output = attn_output.view(*attn_output.shape[:2], -1) + return (output, None) + +class MultiheadAttention(nn.Module): + """Multi-head self attention. + + Using the torch or triton attention implementation enables the user to also use + additive bias. + """ + + def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None): + super().__init__() + self.attn_impl = attn_impl + self.clip_qkv = clip_qkv + self.qk_ln = qk_ln + self.d_model = d_model + self.n_heads = n_heads + self.softmax_scale = softmax_scale + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) + self.attn_dropout_p = attn_pdrop + self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device) + fuse_splits = (d_model, 2 * d_model) + self.Wqkv._fused = (0, fuse_splits) + if self.qk_ln: + layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm + self.q_ln = layernorm_class(self.d_model, device=device) + self.k_ln = layernorm_class(self.d_model, device=device) + if self.attn_impl == 'flash': + self.attn_fn = flash_attn_fn + elif self.attn_impl == 'triton': + self.attn_fn = triton_flash_attn_fn + warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') + elif self.attn_impl == 'torch': + self.attn_fn = scaled_multihead_dot_product_attention + if torch.cuda.is_available(): + warnings.warn('Using `attn_impl: torch`. 
If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.') + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) + self.out_proj._is_residual = True + + def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.chunk(3, dim=2) + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):] + (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights) + return (self.out_proj(context), attn_weights, past_key_value) + +class MultiQueryAttention(nn.Module): + """Multi-Query self attention. + + Using the torch or triton attention implementation enables the user to also use + additive bias. + """ + + def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None): + super().__init__() + self.attn_impl = attn_impl + self.clip_qkv = clip_qkv + self.qk_ln = qk_ln + self.d_model = d_model + self.n_heads = n_heads + self.head_dim = d_model // n_heads + self.softmax_scale = softmax_scale + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.head_dim) + self.attn_dropout_p = attn_pdrop + self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) + fuse_splits = (d_model, d_model + self.head_dim) + self.Wqkv._fused = (0, fuse_splits) + if self.qk_ln: + layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm + self.q_ln = layernorm_class(d_model, device=device) + self.k_ln = layernorm_class(self.head_dim, device=device) + if self.attn_impl == 'flash': + self.attn_fn = flash_attn_fn + elif self.attn_impl == 'triton': + self.attn_fn = triton_flash_attn_fn + warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') + elif self.attn_impl == 'torch': + self.attn_fn = scaled_multihead_dot_product_attention + if torch.cuda.is_available(): + warnings.warn('Using `attn_impl: torch`. 
If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.') + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) + self.out_proj._is_residual = True + + def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2) + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):] + (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True) + return (self.out_proj(context), attn_weights, past_key_value) + +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + if (prefix_lm or not causal) or use_sequence_id: + return (1, n_heads, seq_len, seq_len) + return (1, n_heads, 1, seq_len) + elif prefix_lm or use_sequence_id: + return (1, 1, seq_len, seq_len) + return None + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + +def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + (device, dtype) = (attn_bias.device, attn_bias.dtype) + attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype)) + return attn_bias + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + +def gen_slopes(n_heads, alibi_bias_max=8, device=None): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) + m = m.mul(alibi_bias_max / _n_heads) + slopes = 1.0 / torch.pow(2, m) + if _n_heads != n_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] + return slopes.view(1, n_heads, 1, 1) + +def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len) + if full: + alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1) + alibi_bias = alibi_bias.abs().mul(-1) + slopes = gen_slopes(n_heads, alibi_bias_max, device=device) + alibi_bias = alibi_bias * slopes + return alibi_bias.to(dtype=dtype) +ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention} \ No newline at end of file diff --git a/diffusion/model/llava/mpt/blocks.py b/diffusion/model/llava/mpt/blocks.py new file mode 100644 index 
0000000000000000000000000000000000000000..04493aa4c03ef1b14ec539c9af8e9c38e8befc8b --- /dev/null +++ b/diffusion/model/llava/mpt/blocks.py @@ -0,0 +1,41 @@ +"""GPT Blocks used for the GPT Model.""" +from typing import Dict, Optional, Tuple +import torch +import torch.nn as nn +from .attention import ATTN_CLASS_REGISTRY +from .norm import NORM_CLASS_REGISTRY + +class MPTMLP(nn.Module): + + def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): + super().__init__() + self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) + self.act = nn.GELU(approximate='none') + self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) + self.down_proj._is_residual = True + + def forward(self, x): + return self.down_proj(self.act(self.up_proj(x))) + +class MPTBlock(nn.Module): + + def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs): + del kwargs + super().__init__() + norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] + attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] + self.norm_1 = norm_class(d_model, device=device) + self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device) + self.norm_2 = norm_class(d_model, device=device) + self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) + self.resid_attn_dropout = nn.Dropout(resid_pdrop) + self.resid_ffn_dropout = nn.Dropout(resid_pdrop) + + def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + a = self.norm_1(x) + (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) + x = x + self.resid_attn_dropout(b) + m = self.norm_2(x) + n = self.ffn(m) + x = x + self.resid_ffn_dropout(n) + return (x, past_key_value) \ No newline at end of file diff --git a/diffusion/model/llava/mpt/configuration_mpt.py b/diffusion/model/llava/mpt/configuration_mpt.py new file mode 100644 index 0000000000000000000000000000000000000000..35d1269cd4b599799d6df7953a8d0c30b33d1e65 --- /dev/null +++ b/diffusion/model/llava/mpt/configuration_mpt.py @@ -0,0 +1,118 @@ +"""A HuggingFace-style model configuration.""" +from typing import Dict, Optional, Union +from transformers import PretrainedConfig +attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} +init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'} + +class MPTConfig(PretrainedConfig): + model_type = 'mpt' + + def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, 
resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): + """The MPT configuration class. + + Args: + d_model (int): The size of the embedding dimension of the model. + n_heads (int): The number of attention heads. + n_layers (int): The number of layers in the model. + expansion_ratio (int): The ratio of the up/down scale in the MLP. + max_seq_len (int): The maximum sequence length of the model. + vocab_size (int): The size of the vocabulary. + resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. + emb_pdrop (float): The dropout probability for the embedding layer. + learned_pos_emb (bool): Whether to use learned positional embeddings + attn_config (Dict): A dictionary used to configure the model's attention module: + attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention + attn_pdrop (float): The dropout probability for the attention layers. + attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. + qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. + clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to + this value. + softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, + use the default scale of ``1/sqrt(d_keys)``. + prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an + extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix + can attend to one another bi-directionally. Tokens outside the prefix use causal attention. + attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. + When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates + which sub-sequence each token belongs to. + Defaults to ``False`` meaning any provided `sequence_id` will be ignored. + alibi (bool): Whether to use the alibi bias instead of position embeddings. + alibi_bias_max (int): The maximum value of the alibi bias. + init_device (str): The device to use for parameter initialization. + logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. + no_bias (bool): Whether to use bias in all layers. + verbose (int): The verbosity level. 0 is silent. + embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. + norm_type (str): choose type of norm to use + multiquery_attention (bool): Whether to use multiquery attention implementation. + use_cache (bool): Whether or not the model should return the last key/values attentions + init_config (Dict): A dictionary used to configure the model initialization: + init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', + 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or + 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. + emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. + emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution + used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. + init_std (float): The standard deviation of the normal distribution used to initialize the model, + if using the baseline_ parameter initialization scheme. + init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. + fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. + init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. + --- + See llmfoundry.models.utils.param_init_fns.py for info on other param init config options + """ + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.attn_config = attn_config + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.verbose = verbose + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.use_cache = use_cache + self.init_config = init_config + if 'name' in kwargs: + del kwargs['name'] + if 'loss_fn' in kwargs: + del kwargs['loss_fn'] + super().__init__(**kwargs) + self._validate_config() + + def _set_config_defaults(self, config, config_defaults): + for (k, v) in config_defaults.items(): + if k not in config: + config[k] = v + return config + + def _validate_config(self): + self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) + self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) + if self.d_model % self.n_heads != 0: + raise ValueError('d_model must be divisible by n_heads') + if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): + raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") + if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') + if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('alibi only implemented with torch and triton attention.') + if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') + if self.embedding_fraction > 1 or self.embedding_fraction <= 0: + raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') + if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': + raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 
'inv_sqrt_d_model'.") + if self.init_config.get('name', None) is None: + raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") + if not self.learned_pos_emb and (not self.attn_config['alibi']): + raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') \ No newline at end of file diff --git a/diffusion/model/llava/mpt/modeling_mpt.py b/diffusion/model/llava/mpt/modeling_mpt.py new file mode 100644 index 0000000000000000000000000000000000000000..017b4d33867594ee31db0b0c177ed47af317bd73 --- /dev/null +++ b/diffusion/model/llava/mpt/modeling_mpt.py @@ -0,0 +1,308 @@ +"""A simple, flexible implementation of a GPT model. + +Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py +""" +import math +import warnings +from typing import List, Optional, Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from .attention import attn_bias_shape, build_attn_bias +from .blocks import MPTBlock +from .norm import NORM_CLASS_REGISTRY +from .configuration_mpt import MPTConfig +from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_ +Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + +from transformers.utils import logging +logger = logging.get_logger(__name__) + +class MPTPreTrainedModel(PreTrainedModel): + config_class = MPTConfig + base_model_prefix = 'model' + +class MPTModel(MPTPreTrainedModel): + + def __init__(self, config: MPTConfig): + config._validate_config() + super().__init__(config) + self.attn_impl = config.attn_config['attn_impl'] + self.prefix_lm = config.attn_config['prefix_lm'] + self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id'] + self.alibi = config.attn_config['alibi'] + self.alibi_bias_max = config.attn_config['alibi_bias_max'] + if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): + norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) + raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).') + norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] + self.embedding_fraction = config.embedding_fraction + self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device) + if not self.alibi: + self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device) + self.emb_drop = nn.Dropout(config.emb_pdrop) + self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)]) + self.norm_f = norm_class(config.d_model, device=config.init_device) + if config.init_device != 'meta': + self.apply(self.param_init_fn) + self.is_causal = not self.prefix_lm + self._attn_bias_initialized = False + self.attn_bias = None + self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id) + if config.no_bias: + for module in self.modules(): + if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter): + if config.verbose: + warnings.warn(f'Removing bias ({module.bias}) from {module}.') + module.register_parameter('bias', None) + if config.verbose and config.verbose > 2: + print(self) + if 'verbose' not in 
self.config.init_config: + self.config.init_config['verbose'] = self.config.verbose + if self.config.init_config['verbose'] > 1: + init_fn_name = self.config.init_config['name'] + warnings.warn(f'Using {init_fn_name} initialization.') + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, value): + self.wte = value + + @torch.no_grad() + def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None): + if not self._attn_bias_initialized: + if self.attn_bias_shape: + self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype) + self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max) + self._attn_bias_initialized = True + if self.attn_impl == 'flash': + return (self.attn_bias, attention_mask) + if self.attn_bias is not None: + self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) + attn_bias = self.attn_bias + if self.prefix_lm: + assert isinstance(attn_bias, torch.Tensor) + assert isinstance(prefix_mask, torch.Tensor) + attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) + if self.attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) + attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + if attention_mask is not None: + s_k = attention_mask.shape[-1] + if attn_bias is None: + attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) + else: + attn_bias = attn_bias[:, :, :, -s_k:] + if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: + raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val) + return (attn_bias, None) + + def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): + (s_k, s_q) = attn_bias.shape[-2:] + if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: + raise ValueError('attn_bias does not match the expected shape. 
' + f'The last two dimensions should both be {self.config.max_seq_len} ' + f'but are {s_k} and {s_q}.') + seq_len = prefix_mask.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + attn_bias = attn_bias[..., :seq_len, :seq_len] + causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len) + prefix = prefix_mask.view(-1, 1, 1, seq_len) + cannot_attend = ~torch.logical_or(causal, prefix.bool()) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor): + seq_len = sequence_id.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + attn_bias = attn_bias[..., :seq_len, :seq_len] + cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, tok_emb: Optional[torch.FloatTensor]=None): + return_dict = return_dict if return_dict is not None else self.config.return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if attention_mask is not None: + attention_mask = attention_mask.bool() + if prefix_mask is not None: + prefix_mask = prefix_mask.bool() + if not return_dict: + raise NotImplementedError('return_dict False is not implemented yet for MPT') + if output_attentions: + raise NotImplementedError('output_attentions is not implemented yet for MPT') + if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training: + raise NotImplementedError('MPT does not support training with left padding.') + if self.prefix_lm and prefix_mask is None: + raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') + if self.training: + if self.attn_uses_sequence_id and sequence_id is None: + raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.') + elif self.attn_uses_sequence_id is False and sequence_id is not None: + warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. 
If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.') + if input_ids is not None: + S = input_ids.size(1) + assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' + tok_emb = self.wte(input_ids) + else: + assert tok_emb is not None + S = tok_emb.size(1) + if self.alibi: + x = tok_emb + else: + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') + past_position = past_key_values[0][0].size(1) + if S + past_position > self.config.max_seq_len: + raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') + pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0) + if attention_mask is not None: + pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0) + pos_emb = self.wpe(pos) + x = tok_emb + pos_emb + if self.embedding_fraction == 1: + x = self.emb_drop(x) + else: + x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction) + assert isinstance(self.emb_drop, nn.Module) + x = self.emb_drop(x_shrunk) + (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) + if use_cache and past_key_values is None: + past_key_values = [() for _ in range(self.config.n_layers)] + all_hidden_states = () if output_hidden_states else None + for (b_idx, block) in enumerate(self.blocks): + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + past_key_value = past_key_values[b_idx] if past_key_values is not None else None + if self.gradient_checkpointing and self.training: + (x, past_key_value) = torch.utils.checkpoint.checkpoint( + block, + x, past_key_value, attn_bias, attention_mask, self.is_causal + ) + else: + (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) + if past_key_values is not None: + past_key_values[b_idx] = past_key_value + x = self.norm_f(x) + return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states) + + def param_init_fn(self, module): + init_fn_name = self.config.init_config['name'] + MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config) + + def fsdp_wrap_fn(self, module): + return isinstance(module, MPTBlock) + + def activation_checkpointing_fn(self, module): + return isinstance(module, MPTBlock) + +class MPTForCausalLM(MPTPreTrainedModel): + + def __init__(self, config: MPTConfig): + super().__init__(config) + if not config.tie_word_embeddings: + raise ValueError('MPTForCausalLM only supports tied word embeddings') + self.transformer = MPTModel(config) + self.logit_scale = None + if config.logit_scale is not None: + logit_scale = config.logit_scale + if isinstance(logit_scale, str): + if logit_scale == 'inv_sqrt_d_model': + logit_scale = 1 / math.sqrt(config.d_model) + else: + 
raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") + self.logit_scale = logit_scale + + def get_input_embeddings(self): + return self.transformer.wte + + def set_input_embeddings(self, value): + self.transformer.wte = value + + def get_output_embeddings(self): + return self.transformer.wte + + def set_output_embeddings(self, new_embeddings): + self.transformer.wte = new_embeddings + + def set_decoder(self, decoder): + self.transformer = decoder + + def get_decoder(self): + return self.transformer + + def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None): + return_dict = return_dict if return_dict is not None else self.config.return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) + logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight) + if self.logit_scale is not None: + if self.logit_scale == 0: + warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') + logits *= self.logit_scale + loss = None + if labels is not None: + labels = torch.roll(labels, shifts=-1) + labels[:, -1] = -100 + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) + return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) + + def param_init_fn(self, module): + init_fn_name = self.config.init_config['name'] + MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config) + + def fsdp_wrap_fn(self, module): + return isinstance(module, MPTBlock) + + def activation_checkpointing_fn(self, module): + return isinstance(module, MPTBlock) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + if inputs_embeds is not None: + raise NotImplementedError('inputs_embeds is not implemented for MPT yet') + attention_mask = kwargs['attention_mask'].bool() + if attention_mask[:, -1].sum() != attention_mask.shape[0]: + raise NotImplementedError('MPT does not support generation with right padding.') + if self.transformer.attn_uses_sequence_id and self.training: + sequence_id = torch.zeros_like(input_ids[:1]) + else: + sequence_id = None + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if self.transformer.prefix_lm: + prefix_mask = torch.ones_like(attention_mask) + if kwargs.get('use_cache') == False: + raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') + else: + prefix_mask = None + return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', 
True)} + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + """Used by HuggingFace generate when using beam search with kv-caching. + + See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 + for an example in transformers. + """ + reordered_past = [] + for layer_past in past_key_values: + reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] + return reordered_past \ No newline at end of file diff --git a/diffusion/model/llava/mpt/norm.py b/diffusion/model/llava/mpt/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..bec4a4ca3304c2188312387743a49b75015542be --- /dev/null +++ b/diffusion/model/llava/mpt/norm.py @@ -0,0 +1,56 @@ +import torch + +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + +class LPLayerNorm(torch.nn.LayerNorm): + + def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): + super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) + +def rms_norm(x, weight=None, eps=1e-05): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + if weight is not None: + return output * weight + return output + +class RMSNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): + super().__init__() + self.eps = eps + if weight: + self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) + else: + self.register_parameter('weight', None) + + def forward(self, x): + return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) + +class LPRMSNorm(RMSNorm): + + def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): + super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) + + def forward(self, x): + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + with torch.autocast(enabled=False, device_type=x.device.type): + return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) +NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} \ No newline at end of file diff --git a/diffusion/model/llava/mpt/param_init_fns.py b/diffusion/model/llava/mpt/param_init_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..418b83ca2363288046f4b48b1d706c5607341fb5 --- /dev/null +++ b/diffusion/model/llava/mpt/param_init_fns.py @@ -0,0 +1,181 @@ +import math +import warnings +from 
collections.abc import Sequence +from functools import partial +from typing import Optional, Tuple, Union +import torch +from torch import nn +from .norm import NORM_CLASS_REGISTRY + +def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): + del kwargs + if verbose > 1: + warnings.warn(f"Initializing network using module's reset_parameters attribute") + if hasattr(module, 'reset_parameters'): + module.reset_parameters() + +def fused_init_helper_(module: nn.Module, init_fn_): + _fused = getattr(module, '_fused', None) + if _fused is None: + raise RuntimeError(f'Internal logic error') + (dim, splits) = _fused + splits = (0, *splits, module.weight.size(dim)) + for (s, e) in zip(splits[:-1], splits[1:]): + slice_indices = [slice(None)] * module.weight.ndim + slice_indices[dim] = slice(s, e) + init_fn_(module.weight[slice_indices]) + +def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + del kwargs + if verbose > 1: + warnings.warn(f'If model has bias parameters they are initialized to 0.') + init_div_is_residual = init_div_is_residual + if init_div_is_residual is False: + div_is_residual = 1.0 + elif init_div_is_residual is True: + div_is_residual = math.sqrt(2 * n_layers) + elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int): + div_is_residual = init_div_is_residual + elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): + div_is_residual = float(init_div_is_residual) + else: + div_is_residual = 1.0 + raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') + if init_div_is_residual is not False: + if verbose > 1: + warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.') + if isinstance(module, nn.Linear): + if hasattr(module, '_fused'): + fused_init_helper_(module, init_fn_) + else: + init_fn_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + if init_div_is_residual is not False and getattr(module, '_is_residual', False): + with torch.no_grad(): + module.weight.div_(div_is_residual) + elif isinstance(module, nn.Embedding): + if emb_init_std is not None: + std = emb_init_std + if std == 0: + warnings.warn(f'Embedding layer initialized to 0.') + emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) + if verbose > 1: + warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') + elif emb_init_uniform_lim is not None: + lim = emb_init_uniform_lim + if isinstance(lim, Sequence): + if len(lim) > 2: + raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.') + if lim[0] == lim[1]: + warnings.warn(f'Embedding layer initialized to {lim[0]}.') + else: + if lim == 0: + warnings.warn(f'Embedding layer initialized to 0.') + lim = [-lim, lim] + (a, b) = lim + emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) + if verbose > 1: + warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') + else: + emb_init_fn_ = init_fn_ + emb_init_fn_(module.weight) + elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): + if verbose > 1: + warnings.warn(f'Norm weights are set to 1. 
If norm layer has a bias it is initialized to 0.') + if hasattr(module, 'weight') and module.weight is not None: + torch.nn.init.ones_(module.weight) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.MultiheadAttention): + if module._qkv_same_embed_dim: + assert module.in_proj_weight is not None + assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) + assert d_model is not None + _d = d_model + splits = (0, _d, 2 * _d, 3 * _d) + for (s, e) in zip(splits[:-1], splits[1:]): + init_fn_(module.in_proj_weight[s:e]) + else: + assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) + assert module.in_proj_weight is None + init_fn_(module.q_proj_weight) + init_fn_(module.k_proj_weight) + init_fn_(module.v_proj_weight) + if module.in_proj_bias is not None: + torch.nn.init.zeros_(module.in_proj_bias) + if module.bias_k is not None: + torch.nn.init.zeros_(module.bias_k) + if module.bias_v is not None: + torch.nn.init.zeros_(module.bias_v) + init_fn_(module.out_proj.weight) + if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): + with torch.no_grad(): + module.out_proj.weight.div_(div_is_residual) + if module.out_proj.bias is not None: + torch.nn.init.zeros_(module.out_proj.bias) + else: + for _ in module.parameters(recurse=False): + raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.') + +def _normal_init_(std, mean=0.0): + return partial(torch.nn.init.normal_, mean=mean, std=std) + +def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + del kwargs + init_fn_ = _normal_init_(std=std) + if verbose > 1: + warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') + generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + del kwargs + if init_std is None: + raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") + _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + del kwargs + std = math.sqrt(2 / (5 * d_model)) + _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def 
neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + """From section 2.3.1 of GPT-NeoX-20B: + + An Open-Source Autoregressive Language Model — Black et al. (2022) + see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151 + and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py + """ + del kwargs + residual_div = n_layers / math.sqrt(10) + if verbose > 1: + warnings.warn(f'setting init_div_is_residual to {residual_div}') + small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): + del kwargs + if verbose > 1: + warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') + kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) + generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): + del kwargs + if verbose > 1: + warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') + kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) + generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): + del kwargs + xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) + if verbose > 1: + warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}') + generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + +def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: 
Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): + xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) + if verbose > 1: + warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}') + generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) +MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} \ No newline at end of file diff --git a/diffusion/model/nets/PixArt.py b/diffusion/model/nets/PixArt.py new file mode 100644 index 0000000000000000000000000000000000000000..fddd3dc299f467d65c4e1f0fad452f87414e70f5 --- /dev/null +++ b/diffusion/model/nets/PixArt.py @@ -0,0 +1,315 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math +import torch +import torch.nn as nn +import os +import numpy as np +from timm.models.layers import DropPath +from timm.models.vision_transformer import PatchEmbed, Mlp + +from diffusion.model.builder import MODELS +from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer +from diffusion.utils.logger import get_root_logger + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, + sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + self.sampling = sampling + self.sr_ratio = sr_ratio + + def forward(self, x, y, t, mask=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +############################################################################# +# Core PixArt Model # +################################################################################# +@MODELS.register_module() +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=1.0, + config=None, + model_max_length=120, + qk_norm=False, + kv_compress_config=None, + **kwargs, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.depth = depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + num_patches = self.x_embedder.num_patches + self.base_size = input_size // self.patch_size + # Will use fixed sin-cos embedding: + self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length) + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.kv_compress_config = kv_compress_config + if kv_compress_config is None: + self.kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + sampling=self.kv_compress_config['sampling'], + sr_ratio=int( + self.kv_compress_config['scale_factor'] + ) if i in self.kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + self.initialize_weights() + + if config: + logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + logger.warning(f"position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}") + logger.warning(f"kv compress config: {self.kv_compress_config}") + else: + print(f'Warning: position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}') + print(f"kv compress config: {self.kv_compress_config}") + + + def forward(self, x, timestep, y, mask=None, 
data_info=None, **kwargs): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + pos_embed = self.pos_embed.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep.to(x.dtype)) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, mask) + return model_out.chunk(2, dim=1)[0] + + def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): + """ + Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, y, mask, kwargs) + model_out = model_out['x'] if isinstance(model_out, dict) else model_out + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), + pe_interpolation=self.pe_interpolation, base_size=self.base_size + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + 
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + @property + def dtype(self): + return next(self.parameters()).dtype + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = to_2tuple(grid_size) + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# PixArt Configs # +################################################################################# +@MODELS.register_module() +def PixArt_XL_2(**kwargs): + return PixArt(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) diff --git a/diffusion/model/nets/PixArtMS.py b/diffusion/model/nets/PixArtMS.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9ffaf21cec1991d66ed9eda35a704c5bee0c2f --- /dev/null +++ b/diffusion/model/nets/PixArtMS.py @@ -0,0 +1,293 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import torch +import torch.nn as nn +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +from diffusion.model.builder import MODELS +from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder +from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class PixArtMSBlock(nn.Module): + """ + A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None, + sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + + def forward(self, x, y, t, mask=None, HW=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +############################################################################# +# Core PixArt Model # +################################################################################# +@MODELS.register_module() +class PixArtMS(PixArt): + """ + Diffusion model with a Transformer backbone. 
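+    Compared with the base PixArt model, this multi-scale variant rebuilds the 2D sin-cos position embedding at runtime from the actual latent height and width, and, when micro_condition=True, additionally conditions the timestep embedding on the image size and aspect ratio via SizeEmbedder.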
+ """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=1., + config=None, + model_max_length=120, + micro_condition=False, + qk_norm=False, + kv_compress_config=None, + **kwargs, + ): + super().__init__( + input_size=input_size, + patch_size=patch_size, + in_channels=in_channels, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + class_dropout_prob=class_dropout_prob, + learn_sigma=learn_sigma, + pred_sigma=pred_sigma, + drop_path=drop_path, + pe_interpolation=pe_interpolation, + config=config, + model_max_length=model_max_length, + qk_norm=qk_norm, + kv_compress_config=kv_compress_config, + **kwargs, + ) + self.h = self.w = 0 + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True) + self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length) + self.micro_conditioning = micro_condition + if self.micro_conditioning: + self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed + self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + if kv_compress_config is None: + kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtMSBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + sampling=kv_compress_config['sampling'], + sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + self.initialize() + + def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): + """ + Forward pass of PixArt. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + bs = x.shape[0] + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + pos_embed = torch.from_numpy( + get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation, + base_size=self.base_size + ) + ).unsqueeze(0).to(x.device).to(self.dtype) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep) # (N, D) + + if self.micro_conditioning: + c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype) + csize = self.csize_embedder(c_size, bs) # (N, D) + ar = self.ar_embedder(ar, bs) # (N, D) + t = t + torch.cat([csize, ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs) # (N, T, D) # support grad checkpoint + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + return x + + def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs): + """ + The DPM-Solver sampler does not need the variance prediction, so only the first half of the output channels is returned. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs) + return model_out.chunk(2, dim=1)[0] + + def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, mask=None, **kwargs): + """ + Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
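+        The first half of the batch is duplicated so the conditional and unconditional branches share one forward pass; the guided prediction is eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps), applied to the first 3 output channels, while the remaining (variance) channels are passed through unchanged.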
+ """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, y, mask, data_info=data_info, **kwargs) + model_out = model_out['x'] if isinstance(model_out, dict) else model_out + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + assert self.h * self.w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p)) + return imgs + + def initialize(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + if self.micro_conditioning: + nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +################################################################################# +# PixArt Configs # +################################################################################# +@MODELS.register_module() +def PixArtMS_XL_2(**kwargs): + return PixArtMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) diff --git a/diffusion/model/nets/PixArt_blocks.py b/diffusion/model/nets/PixArt_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..01cd8a3e966233370f025fadc9f94c4260d77884 --- /dev/null +++ b/diffusion/model/nets/PixArt_blocks.py @@ -0,0 +1,441 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import xformers.ops +from einops import rearrange +from timm.models.vision_transformer import Mlp, Attention as Attention_ + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model*2) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(d_model, d_model) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class AttentionKVCompress(Attention_): + """Multi-head Attention block with KV token compression and qk norm.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + sampling='conv', + sr_ratio=1, + qk_norm=False, + **block_kwargs, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + """ + super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs) + + self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every'] + self.sr_ratio = sr_ratio + if sr_ratio > 1 and sampling == 'conv': + # Avg Conv Init. 
+ self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio) + self.sr.weight.data.fill_(1/sr_ratio**2) + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(dim) + if qk_norm: + self.q_norm = nn.LayerNorm(dim) + self.k_norm = nn.LayerNorm(dim) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def downsample_2d(self, tensor, H, W, scale_factor, sampling=None): + if sampling is None or scale_factor == 1: + return tensor + B, N, C = tensor.shape + + if sampling == 'uniform_every': + return tensor[:, ::scale_factor], int(N // scale_factor) + + tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2) + new_H, new_W = int(H / scale_factor), int(W / scale_factor) + new_N = new_H * new_W + + if sampling == 'ave': + tensor = F.interpolate( + tensor, scale_factor=1 / scale_factor, mode='nearest' + ).permute(0, 2, 3, 1) + elif sampling == 'uniform': + tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1) + elif sampling == 'conv': + tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1) + tensor = self.norm(tensor) + else: + raise ValueError + + return tensor.reshape(B, new_N, C).contiguous(), new_N + + def forward(self, x, mask=None, HW=None, block_id=None): + B, N, C = x.shape + new_N = N + if HW is None: + H = W = int(N ** 0.5) + else: + H, W = HW + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) + dtype = q.dtype + q = self.q_norm(q) + k = self.k_norm(k) + + # KV compression + if self.sr_ratio > 1: + k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling) + v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling) + + q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype) + k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + + use_fp32_attention = getattr(self, 'fp32_attention', False) # necessary for NAN loss + if use_fp32_attention: + q, k, v = q.float(), k.float(), v.float() + + attn_bias = None + if mask is not None: + attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) + attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf')) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + x = x.view(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +################################################################################# +# AMP attention with fp32 softmax to fix loss NaN problem during training # +################################################################################# +class Attention(Attention_): + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + use_fp32_attention = getattr(self, 'fp32_attention', False) + if use_fp32_attention: + q, k = q.float(), k.float() + with torch.cuda.amp.autocast(enabled=not use_fp32_attention): + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of PixArt. 
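+    The conditioning vector c is mapped by adaLN_modulation to a shift/scale pair that modulates the normalized tokens before the linear projection to patch_size**2 * out_channels values per token.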
+ """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) + self.out_channels = out_channels + + def forward(self, x, t): + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MaskFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DecoderLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, decoder_hidden_size): + super().__init__() + self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_decoder(x), shift, scale) + x = self.linear(x) + return x + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs//s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120): + super().__init__() + self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class CaptionEmbedderDoubleBr(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120): + super().__init__() + self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) + self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5) + self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5) + self.uncond_prob = uncond_prob + + def token_drop(self, global_caption, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption) + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return global_caption, caption + + def forward(self, caption, train, force_drop_ids=None): + assert caption.shape[2: ] == self.y_embedding.shape + global_caption = caption.mean(dim=2).squeeze() + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids) + y_embed = self.proj(global_caption) + return y_embed, caption \ No newline at end of file diff --git a/diffusion/model/nets/__init__.py b/diffusion/model/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..837498bfc42e284fd9f7e081a66a87250a395009 --- /dev/null +++ b/diffusion/model/nets/__init__.py @@ -0,0 +1,2 @@ +from .PixArt import PixArt, PixArt_XL_2 +from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock \ No newline at end of file diff --git a/diffusion/model/respace.py b/diffusion/model/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..61cd9dfb329741399f6ab7d0d3e7394bcb2c3112 --- /dev/null +++ b/diffusion/model/respace.py @@ -0,0 +1,134 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. 
+ For example, if there are 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {desired_count} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process.
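+    Illustrative usage (a sketch; the remaining GaussianDiffusion kwargs, e.g. the full beta schedule and loss/variance settings, are assumed to be supplied by the caller): + >>> kept = space_timesteps(1000, "ddim25") # keep 25 evenly strided steps out of the 1000 training steps + >>> sampler = SpacedDiffusion(use_timesteps=kept, betas=betas) # plus the other base-diffusion kwargs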
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def training_losses_diffusers( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, timestep, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype) + new_ts = map_tensor[timestep] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, timestep=new_ts, **kwargs) diff --git a/diffusion/model/sa_solver.py b/diffusion/model/sa_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..e51cdc8c8b941f68682b9bac576d0d35a496ea7a --- /dev/null +++ b/diffusion/model/sa_solver.py @@ -0,0 +1,1149 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Thanks to DPM-Solver for their code base""" + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. 
For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. 
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. 
+ self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + def edm_sigma(self, t): + return self.marginal_std(t) / self.marginal_alpha(t) + + def edm_inverse_sigma(self, edmsigma): + alpha = 1 / (edmsigma ** 2 + 1).sqrt() + sigma = alpha * edmsigma + lambda_t = torch.log(alpha / sigma) + t = self.inverse_lambda(lambda_t) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Thanks to DPM-Solver for their code base""" + """Create a wrapper function for the noise prediction model. + SA-Solver needs to solve the continuous-time diffusion SDEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). 
And we use `model_fn` for SA-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t[0] * output) / sigma_t[0] + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t[0] * output + sigma_t[0] * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t[0] * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. 
or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class SASolver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="data_prediction", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995 + ): + """ + Construct a SA-Solver + The default value for algorithm_type is "data_prediction" and we recommend not to change it to + "noise_prediction". For details, please see Appendix A.2.4 in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["data_prediction", "noise_prediction"] + + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + self.predict_x0 = algorithm_type == "data_prediction" + + self.sigma_min = float(self.noise_schedule.edm_sigma(torch.tensor([1e-3]))) + self.sigma_max = float(self.noise_schedule.edm_sigma(torch.tensor([1]))) + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, order, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = lambda_T + torch.linspace(torch.tensor(0.).cpu().item(), + (lambda_0 - lambda_T).cpu().item() ** (1. / order), N + 1).pow( + order).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time': + t = torch.linspace(t_T ** (1. / order), t_0 ** (1. 
/ order), N + 1).pow(order).to(device) + return t + elif skip_type == 'karras': + sigma_min = max(0.002, self.sigma_min) + sigma_max = min(80, self.sigma_max) + sigma_steps = torch.linspace(sigma_max ** (1. / 7), sigma_min ** (1. / 7), N + 1).pow(7).to(device) + t = self.noise_schedule.edm_inverse_sigma(sigma_steps) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time' or 'karras'".format(skip_type)) + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def get_coefficients_exponential_negative(self, order, interval_start, interval_end): + """ + Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end + For calculating the coefficient of gradient terms after the lagrange interpolation, + see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + For noise_prediction formula. + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + if order == 0: + return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1) + elif order == 1: + return torch.exp(-interval_end) * ( + (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1)) + elif order == 2: + return torch.exp(-interval_end) * ( + (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - ( + interval_end ** 2 + 2 * interval_end + 2)) + elif order == 3: + return torch.exp(-interval_end) * ( + (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp( + interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6)) + + def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau): + """ + Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end + For calculating the coefficient of gradient terms after the lagrange interpolation, + see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + For data_prediction formula. 
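+ As a quick sanity check, the order == 0 case has the closed form
+ int_a^b exp((1 + tau^2) x) dx = (exp((1 + tau^2) b) - exp((1 + tau^2) a)) / (1 + tau^2),
+ which is exactly what the order == 0 branch below returns once exp(interval_end_cov) is factored out.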
+ """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + # after change of variable(cov) + interval_end_cov = (1 + tau ** 2) * interval_end + interval_start_cov = (1 + tau ** 2) * interval_start + + if order == 0: + return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / ( + (1 + tau ** 2)) + elif order == 1: + return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2) + elif order == 2: + return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - ( + interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3) + elif order == 3: + return torch.exp(interval_end_cov) * ( + (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - ( + interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4) + + def lagrange_polynomial_coefficient(self, order, lambda_list): + """ + Calculate the coefficient of lagrange polynomial + For lagrange interpolation + """ + assert order in [0, 1, 2, 3] + assert order == len(lambda_list) - 1 + if order == 0: + return [[1]] + elif order == 1: + return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], + [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]] + elif order == 2: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2]) / denominator1, + lambda_list[1] * lambda_list[2] / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2]) / denominator2, + lambda_list[0] * lambda_list[2] / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1]) / denominator3, + lambda_list[0] * lambda_list[1] / denominator3] + ] + elif order == 3: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * ( + lambda_list[0] - lambda_list[3]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * ( + lambda_list[1] - lambda_list[3]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * ( + lambda_list[2] - lambda_list[3]) + denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * ( + lambda_list[3] - lambda_list[2]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, + (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator1, + (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, + (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator2, + (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * 
lambda_list[ + 3]) / denominator3, + (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3], + + [1 / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[ + 2]) / denominator4, + (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4] + + ] + + def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau): + """ + Calculate the coefficient of gradients. + """ + assert order in [1, 2, 3, 4] + assert order == len(lambda_list), 'the length of lambda list must be equal to the order' + coefficients = [] + lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list) + for i in range(order): + coefficient = 0 + for j in range(order): + if self.predict_x0: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive( + order - 1 - j, interval_start, interval_end, tau) + else: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative( + order - 1 - j, interval_start, interval_end) + coefficients.append(coefficient) + assert len(coefficients) == order, 'the length of coefficients does not match the order' + return coefficients + + def adams_bashforth_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Predictor, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_prev_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def adams_moulton_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) 
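+ # Note: unlike the SA-Predictor above, the SA-Corrector interpolates through the current time t as well:
+ # t_list below appends t to t_prev_list, so lambda_list includes lambda_t itself.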
+ lambda_list = [] + t_list = t_prev_list + [t] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def adams_bashforth_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_prev_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. 
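+ # The commented-out lines below are the tau = 0 (ODE) limit of the two active corrections that follow;
+ # substituting tau = 0 into the active expressions recovers them.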
+ # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda( + t_prev_list[-2])) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda( + t_prev_list[-2])) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def adams_moulton_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + t_list = t_prev_list + [t] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. 
+ # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def sample_few_steps(self, x, tau, steps=5, t_start=None, t_end=None, skip_type='time', skip_order=1, + predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False + ): + """ + For the PC-mode, please refer to the wiki page + https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode + 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations + We recommend use pc_mode='PEC' for NFEs is limited. 'PECE' mode is only for test with sufficient NFEs. + """ + + skip_first_step = False + skip_final_step = True + lower_order_final = True + denoise_to_zero = False + + assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE' + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + + device = x.device + intermediates = [] + with torch.no_grad(): + assert steps >= max(predictor_order, corrector_order - 1) + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order, + device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. 
+ step = 0 + t = timesteps[step] + noise = torch.randn_like(x) + t_prev_list = [t] + # do not evaluate if skip_first_step + if skip_first_step: + if self.predict_x0: + alpha_t = self.noise_schedule.marginal_alpha(t) + sigma_t = self.noise_schedule.marginal_std(t) + model_prev_list = [(1 - sigma_t) / alpha_t * x] + else: + model_prev_list = [x] + else: + model_prev_list = [self.model_fn(x, t)] + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # determine the first several values + for step in tqdm(range(1, max(predictor_order, corrector_order - 1))): + + t = timesteps[step] + predictor_order_used = min(predictor_order, step) + corrector_order_used = min(corrector_order, step + 1) + noise = torch.randn_like(x) + # predictor step + x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + # evaluation step + model_x = self.model_fn(x_p, t) + + # update model_list + model_prev_list.append(model_x) + # corrector step + if corrector_order > 0: + x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + else: + x = x_p + + # evaluation step if correction and mode = pece + if corrector_order > 0: + if pc_mode == 'PECE': + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + + for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)): + if lower_order_final: + predictor_order_used = min(predictor_order, steps - step + 1) + corrector_order_used = min(corrector_order, steps - step + 2) + + else: + predictor_order_used = predictor_order + corrector_order_used = corrector_order + t = timesteps[step] + noise = torch.randn_like(x) + + # predictor step + if skip_final_step and step == steps and not denoise_to_zero: + x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=0, + model_prev_list=model_prev_list, + t_prev_list=t_prev_list, noise=noise, t=t) + else: + x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, + t_prev_list=t_prev_list, noise=noise, t=t) + + # evaluation step + # do not evaluate if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_x = self.model_fn(x_p, t) + + # update model_list + # do not update if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_prev_list.append(model_x) + + # corrector step + # do not correct if skip_final_step and step = steps + if corrector_order > 0: + if not skip_final_step or step < steps: + x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, + t_prev_list=t_prev_list, noise=noise, t=t) + else: + x = x_p + else: + x = x_p + + # evaluation step if mode = pece and step != steps + if corrector_order > 0: + if pc_mode == 'PECE' and step < steps: + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + del model_prev_list[0] 
+ + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + def sample_more_steps(self, x, tau, steps=20, t_start=None, t_end=None, skip_type='time', skip_order=1, + predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False + ): + """ + For the PC-mode, please refer to the wiki page + https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode + 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations + We recommend use pc_mode='PEC' for NFEs is limited. 'PECE' mode is only for test with sufficient NFEs. + """ + + skip_first_step = False + skip_final_step = False + lower_order_final = True + denoise_to_zero = True + + assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE' + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + + device = x.device + intermediates = [] + with torch.no_grad(): + assert steps >= max(predictor_order, corrector_order - 1) + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order, + device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + noise = torch.randn_like(x) + t_prev_list = [t] + # do not evaluate if skip_first_step + if skip_first_step: + if self.predict_x0: + alpha_t = self.noise_schedule.marginal_alpha(t) + sigma_t = self.noise_schedule.marginal_std(t) + model_prev_list = [(1 - sigma_t) / alpha_t * x] + else: + model_prev_list = [x] + else: + model_prev_list = [self.model_fn(x, t)] + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # determine the first several values + for step in tqdm(range(1, max(predictor_order, corrector_order - 1))): + + t = timesteps[step] + predictor_order_used = min(predictor_order, step) + corrector_order_used = min(corrector_order, step + 1) + noise = torch.randn_like(x) + # predictor step + x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, noise=noise, + t=t) + # evaluation step + model_x = self.model_fn(x_p, t) + + # update model_list + model_prev_list.append(model_x) + # corrector step + if corrector_order > 0: + x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, noise=noise, + t=t) + else: + x = x_p + + # evaluation step if mode = pece + if corrector_order > 0: + if pc_mode == 'PECE': + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + + for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)): + if lower_order_final: + predictor_order_used = min(predictor_order, steps - step + 1) + corrector_order_used = min(corrector_order, steps - step + 2) + + else: + 
predictor_order_used = predictor_order + corrector_order_used = corrector_order + t = timesteps[step] + noise = torch.randn_like(x) + + # predictor step + if skip_final_step and step == steps and not denoise_to_zero: + x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=0, + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + else: + x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + + # evaluation step + # do not evaluate if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_x = self.model_fn(x_p, t) + + # update model_list + # do not update if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_prev_list.append(model_x) + + # corrector step + # do not correct if skip_final_step and step = steps + if corrector_order > 0: + if not skip_final_step or step < steps: + x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + else: + x = x_p + else: + x = x_p + + # evaluation step if mode = pece and step != steps + if corrector_order > 0: + if pc_mode == 'PECE' and step < steps: + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + del model_prev_list[0] + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + def sample(self, mode, x, tau, steps, t_start=None, t_end=None, skip_type='time', skip_order=1, predictor_order=3, + corrector_order=4, pc_mode='PEC', return_intermediate=False + ): + """ + For the PC-mode, please refer to the wiki page + https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode + 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations + We recommend use pc_mode='PEC' for NFEs is limited. 'PECE' mode is only for test with sufficient NFEs. + + 'few_steps' mode is recommended. The differences between 'few_steps' and 'more_steps' are as below: + 1) 'few_steps' do not correct at final step and do not denoise to zero, while 'more_steps' do these two. + Thus the NFEs for 'few_steps' = steps, NFEs for 'more_steps' = steps + 2 + For most of the experiments and tasks, we find these two operations do not have much help to sample quality. + 2) 'few_steps' use a rescaling trick as in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + We find it will slightly improve the sample quality especially in few steps. 
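+ A minimal end-to-end sketch of how this class is typically wired up (here `model`, `betas`,
+ `condition`, `uncond_condition`, the latent shape and the constant-tau choice are placeholders,
+ and the exact NoiseScheduleVP arguments depend on its constructor defined earlier in this file):
+ ``
+ ns = NoiseScheduleVP('discrete', betas=betas)
+ model_fn = model_wrapper(model, ns, model_type="noise",
+                          guidance_type="classifier-free",
+                          condition=condition,
+                          unconditional_condition=uncond_condition,
+                          guidance_scale=4.5)
+ solver = SASolver(model_fn, ns, algorithm_type="data_prediction")
+ x_T = torch.randn(batch_size, 4, 64, 64, device=device)
+ x_0 = solver.sample('few_steps', x_T, tau=lambda t: 1.0, steps=25,
+                     t_start=1.0, t_end=1e-3, skip_type='time')
+ ``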
+ """ + assert mode in ['few_steps', 'more_steps'], "mode must be either 'few_steps' or 'more_steps'" + if mode == 'few_steps': + return self.sample_few_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type, + skip_order=skip_order, predictor_order=predictor_order, + corrector_order=corrector_order, pc_mode=pc_mode, + return_intermediate=return_intermediate) + else: + return self.sample_more_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type, + skip_order=skip_order, predictor_order=predictor_order, + corrector_order=corrector_order, pc_mode=pc_mode, + return_intermediate=return_intermediate) + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
+ """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/diffusion/model/t5.py b/diffusion/model/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..19cb6f93378cf2419c14ac03174949819e1a5826 --- /dev/null +++ b/diffusion/model/t5.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +import os +import re +import html +import urllib.parse as ul + +import ftfy +import torch +from bs4 import BeautifulSoup +from transformers import T5EncoderModel, AutoTokenizer +from huggingface_hub import hf_hub_download + +class T5Embedder: + + available_models = ['t5-v1_1-xxl'] + bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + + def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, + t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + if t5_model_kwargs is None: + t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} + if use_offload_folder is not None: + t5_model_kwargs['offload_folder'] = use_offload_folder + t5_model_kwargs['device_map'] = { + 'shared': self.device, + 'encoder.embed_tokens': self.device, + 'encoder.block.0': self.device, + 'encoder.block.1': self.device, + 'encoder.block.2': self.device, + 'encoder.block.3': self.device, + 'encoder.block.4': self.device, + 'encoder.block.5': self.device, + 'encoder.block.6': self.device, + 'encoder.block.7': self.device, + 'encoder.block.8': self.device, + 'encoder.block.9': self.device, + 'encoder.block.10': self.device, + 'encoder.block.11': self.device, + 'encoder.block.12': 'disk', + 'encoder.block.13': 'disk', + 'encoder.block.14': 'disk', + 'encoder.block.15': 'disk', + 'encoder.block.16': 'disk', + 'encoder.block.17': 'disk', + 'encoder.block.18': 'disk', + 'encoder.block.19': 'disk', + 'encoder.block.20': 'disk', + 'encoder.block.21': 'disk', + 'encoder.block.22': 'disk', + 'encoder.block.23': 'disk', + 'encoder.final_layer_norm': 'disk', + 'encoder.dropout': 'disk', + } + else: + t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') + self.dir_or_name = dir_or_name + tokenizer_path, path = dir_or_name, dir_or_name + if local_cache: + cache_dir = os.path.join(self.cache_dir, dir_or_name) + tokenizer_path, path = cache_dir, cache_dir + elif dir_or_name in self.available_models: + cache_dir = os.path.join(self.cache_dir, dir_or_name) + for filename in [ + 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin' + ]: + hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + tokenizer_path, path = cache_dir, cache_dir + else: + cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') + for filename in [ + 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + ]: + hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + tokenizer_path = cache_dir + + print(tokenizer_path) + self.tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_path) + self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + + text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] + text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] + + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=text_tokens_and_mask['input_ids'].to(self.device), + attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), + )['last_hidden_state'].detach() + return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device) + + def text_preprocessing(self, text): + if self.use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = self.clean_caption(text) + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." 
+ caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() diff --git a/diffusion/model/timestep_sampler.py b/diffusion/model/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ec6a583841ea7e39eb02d5027a5cf2f52890195b --- /dev/null +++ b/diffusion/model/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. 
+ """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs, device=local_ts.device) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs, device=local_losses.device) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/diffusion/model/utils.py b/diffusion/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9de176aed8c04823490db42d30509bf96140b75 --- /dev/null +++ b/diffusion/model/utils.py @@ -0,0 +1,512 @@ +import os +import sys +import torch.nn as nn +from torch.utils.checkpoint import checkpoint, checkpoint_sequential +import torch.nn.functional as F +import torch +import torch.distributed as dist +import re +import math +from collections.abc import Iterable +from itertools import repeat +from torchvision import transforms as T +import random +from PIL import Image + + +def _ntuple(n): + def parse(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) + +def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): + assert isinstance(model, nn.Module) + + def set_attr(module): + module.grad_checkpointing = True + module.fp32_attention = use_fp32_attention + module.grad_checkpointing_step = gc_step + model.apply(set_attr) + + +def auto_grad_checkpoint(module, *args, **kwargs): + if getattr(module, 'grad_checkpointing', False): + if isinstance(module, Iterable): + gc_step = module[0].grad_checkpointing_step + return checkpoint_sequential(module, gc_step, *args, **kwargs) + else: + return checkpoint(module, *args, **kwargs) + return module(*args, **kwargs) + + +def checkpoint_sequential(functions, step, input, *args, **kwargs): + + # Hack for keyword-only parameter in a python 2.7-compliant way + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs: + raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + + def run_function(start, end, functions): + def forward(input): + for j in range(start, end + 1): + input = functions[j](input, *args) + return input + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = list(functions.children()) + + # the last chunk has to be non-volatile + end = -1 + segment = len(functions) // step + for start in range(0, step * (segment - 1), step): + end = start + step - 1 + input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) + return run_function(end + 1, len(functions) - 1, functions)(input) + + +def window_partition(x, window_size): + """ + Partition into 
non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size, k_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
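+ Shape sketch with illustrative numbers: for a 14x14 query/key grid and dim = 64, q has shape
+ (B, 196, 64), rel_pos_h and rel_pos_w each have shape (27, 64) (2 * 14 - 1 relative positions),
+ and the returned attn has shape (B, 196, 196).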
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + +def mean_flat(tensor): + return tensor.mean(dim=list(range(1, tensor.ndim))) + + +################################################################################# +# Token Masking and Unmasking # +################################################################################# +def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0): + """ + Get the binary mask for the input sequence. + Args: + - batch: batch size + - length: sequence length + - mask_ratio: ratio of tokens to mask + - data_info: dictionary with info for reconstruction + return: + mask_dict with following keys: + - mask: binary mask, 0 is keep, 1 is remove + - ids_keep: indices of tokens to keep + - ids_restore: indices to restore the original order + """ + assert mask_type in ['random', 'fft', 'laplacian', 'group'] + mask = torch.ones([batch, length], device=device) + len_keep = int(length * (1 - mask_ratio)) - extra_len + + if mask_type == 'random' or mask_type == 'group': + noise = torch.rand(batch, length, device=device) # noise in [0, 1] + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_removed = ids_shuffle[:, len_keep:] + + elif mask_type in ['fft', 'laplacian']: + if 'strength' in data_info: + strength = data_info['strength'] + + else: + N = data_info['N'][0] + img = data_info['ori_img'] + # 获取原图的尺寸信息 + _, C, H, W = img.shape + if mask_type == 'fft': + # 对图片进行reshape,将其变为patch (3, H/N, N, W/N, N) + reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N)) + fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5)) + # 取绝对值并求和获取频率强度 + strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape((batch, -1,)) + elif type == 'laplacian': + laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(1, 1, 3, 3) + laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1) + # 对图片进行reshape,将其变为patch (3, H/N, N, W/N, N) + reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N) + laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C) + strength = laplacian_response.sum(dim=[1, 2, 3]).reshape((batch, -1,)) + + # 对频率强度进行归一化,然后使用torch.multinomial进行采样 + probabilities = strength / (strength.max(dim=1)[0][:, None]+1e-5) + ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False) + ids_keep = ids_shuffle[:, :len_keep] + ids_restore = torch.argsort(ids_shuffle, dim=1) + ids_removed = ids_shuffle[:, len_keep:] + + mask[:, :len_keep] = 0 + mask = torch.gather(mask, dim=1, index=ids_restore) + + return {'mask': mask, + 'ids_keep': ids_keep, + 'ids_restore': ids_restore, + 'ids_removed': ids_removed} + + +def mask_out_token(x, ids_keep, ids_removed=None): + """ + Mask out the tokens specified by ids_keep. 
+ Args: + - x: input sequence, [N, L, D] + - ids_keep: indices of tokens to keep + return: + - x_masked: masked sequence + """ + N, L, D = x.shape # batch, length, dim + x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + if ids_removed is not None: + x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D)) + return x_remain, x_masked + else: + return x_remain + + +def mask_tokens(x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + +def unmask_tokens(x, ids_restore, mask_token): + # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D] + mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x = torch.cat([x, mask_tokens], dim=1) + x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + return x + + +# Parse 'None' to None and others to float value +def parse_float_none(s): + assert isinstance(s, str) + return None if s == 'None' else float(s) + + +#---------------------------------------------------------------------------- +# Parse a comma separated list of numbers or ranges and return a list of ints. +# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + +def parse_int_list(s): + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + + +def init_processes(fn, args): + """ Initialize the distributed environment. """ + os.environ['MASTER_ADDR'] = args.master_address + os.environ['MASTER_PORT'] = str(random.randint(2000, 6000)) + print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}') + print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}') + torch.cuda.set_device(args.local_rank) + dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size) + fn(args) + if args.global_size > 1: + cleanup() + + +def mprint(*args, **kwargs): + """ + Print only from rank 0. + """ + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def cleanup(): + """ + End DDP training. + """ + dist.barrier() + mprint("Done!") + dist.barrier() + dist.destroy_process_group() + + +#---------------------------------------------------------------------------- +# logging info. +class Logger(object): + """ + Redirect stderr to stdout, optionally print stdout to a file, + and optionally force flushing on both stdout and the file. 
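+ Minimal hypothetical usage:
+ ``
+ with Logger(file_name="train.log", file_mode="a"):
+     print("mirrored to the console and to train.log")
+ ``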
+ """ + + def __init__(self, file_name=None, file_mode="w", should_flush=True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def write(self, text): + """Write text to stdout (and a file) and optionally flush.""" + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self): + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self): + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + + +class StackedRandomGenerator: + def __init__(self, device, seeds): + super().__init__() + self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) + + def randn_like(self, input): + return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) + + +def prepare_prompt_ar(prompt, ratios, device='cpu', show=True): + # get aspect_ratio or ar + aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt) + ars = re.findall(r"--ar\s+(\d+:\d+)", prompt) + custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt) + if show: + print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw) + prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0] + if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show: + print("Wrong prompt format. Set to default ar: 1. change your prompt into format '--ar h:w or --hw h:w' for correct generating") + if len(aspect_ratios) != 0: + ar = float(aspect_ratios[0].split(':')[0]) / float(aspect_ratios[0].split(':')[1]) + elif len(ars) != 0: + ar = float(ars[0].split(':')[0]) / float(ars[0].split(':')[1]) + else: + ar = 1. 
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + if len(custom_hw) != 0: + custom_hw = [float(custom_hw[0].split(':')[0]), float(custom_hw[0].split(':')[1])] + else: + custom_hw = ratios[closest_ratio] + default_hw = ratios[closest_ratio] + prompt_show = f'prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}' + return prompt_clean, prompt_show, torch.tensor(default_hw, device=device)[None], torch.tensor([float(closest_ratio)], device=device)[None], torch.tensor(custom_hw, device=device)[None] + + +def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int): + orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int) + custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int) + + if (orig_hw != custom_hw).all(): + ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1]) + resized_width = int(orig_hw[1] * ratio) + resized_height = int(orig_hw[0] * ratio) + + transform = T.Compose([ + T.Resize((resized_height, resized_width)), + T.CenterCrop(custom_hw.tolist()) + ]) + return transform(samples) + else: + return samples + + +def resize_and_crop_img(img: Image, new_width, new_height): + orig_width, orig_height = img.size + + ratio = max(new_width/orig_width, new_height/orig_height) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + img = img.resize((resized_width, resized_height), Image.LANCZOS) + + left = (resized_width - new_width)/2 + top = (resized_height - new_height)/2 + right = (resized_width + new_width)/2 + bottom = (resized_height + new_height)/2 + + img = img.crop((left, top, right, bottom)) + + return img + + + +def mask_feature(emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index + else: + masked_feature = emb * mask[:, None, :, None] + return masked_feature, emb.shape[2] \ No newline at end of file diff --git a/diffusion/sa_sampler.py b/diffusion/sa_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..09372337f0e5868c5cfc5081418d2bf5907c7a59 --- /dev/null +++ b/diffusion/sa_sampler.py @@ -0,0 +1,94 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np + +from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver +from .model import gaussian_diffusion as gd + + +class SASolverSampler(object): + def __init__(self, model, + noise_schedule="linear", + diffusion_steps=1000, + device='cpu', + ): + super().__init__() + self.model = model + self.device = device + to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) + betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) + alphas = 1.0 - betas + self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0))) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + model_kwargs={}, + # this has to come in the same format as the 
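prepare_prompt_ar strips the `--ar`/`--aspect_ratio`/`--hw` flags out of a prompt and snaps the requested ratio to the closest predefined bin. A hedged sketch with a toy ratio table standing in for the project's real aspect-ratio bins:

# Illustrative only: a toy ratio table; the real bins live in the project's config.
ratios = {'0.5': [704, 1408], '1.0': [1024, 1024], '2.0': [1408, 704]}

prompt = "a corgi surfing at sunset --ar 9:16"
clean, shown, default_hw, ar, custom_hw = prepare_prompt_ar(prompt, ratios, device='cpu', show=False)
print(clean.strip())    # 'a corgi surfing at sunset'
print(ar)               # tensor([[0.5000]]) -> 9:16 snapped to the closest bin (0.5)
print(default_hw)       # tensor([[ 704, 1408]])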
conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + device = self.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + self.model, + ns, + model_type="noise", + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + model_kwargs=model_kwargs, + ) + + sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction") + + tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0 + + x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False) + + return x.to(device), None \ No newline at end of file diff --git a/diffusion/sa_solver_diffusers.py b/diffusion/sa_solver_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..8e57e9365bf4659136711a9cc2d9566bc16bafe0 --- /dev/null +++ b/diffusion/sa_solver_diffusers.py @@ -0,0 +1,856 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2309.05019 +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union, Callable + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
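For orientation, a hedged sketch of how SASolverSampler might be driven; `model`, `cond` and `uncond` are placeholders for the trained denoiser and its text embeddings, and the latent shape is an assumption, not something fixed by this file:

import torch
from diffusion.sa_sampler import SASolverSampler

sampler = SASolverSampler(model, noise_schedule="linear", diffusion_steps=1000, device='cuda')
samples, _ = sampler.sample(
    S=25,                                  # number of SA-Solver steps
    batch_size=4,
    shape=(4, 32, 32),                     # latent C, H, W (assumed)
    conditioning=cond,                     # e.g. text-encoder embeddings
    unconditional_conditioning=uncond,     # null-prompt embeddings
    unconditional_guidance_scale=4.5,
    model_kwargs={},                       # extra kwargs forwarded to the denoiser, if any
)
print(samples.shape)                       # (4, 4, 32, 32) latents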
+ Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class SASolverScheduler(SchedulerMixin, ConfigMixin): + """ + `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + predictor_order (`int`, defaults to 2): + The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for guided + sampling, and `predictor_order=3` for unconditional sampling. + corrector_order (`int`, defaults to 2): + The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided + sampling, and `corrector_order=3` for unconditional sampling. + predictor_corrector_mode (`str`, defaults to `PEC`): + The predictor-corrector mode can be `PEC` or 'PECE'. It is recommended to use `PEC` mode for fast + sampling, and `PECE` for high-quality sampling (PECE needs around twice model evaluations as PEC). + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `data_prediction`): + Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction` + with `solver_order=2` for guided sampling like in Stable Diffusion. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Default = True. 
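A quick, illustrative sanity check of the cosine schedule produced by betas_for_alpha_bar: up to the max_beta clipping at the final step, the cumulative product of (1 - beta) tracks alpha_bar(t) divided by alpha_bar(0):

import math
import torch
from diffusion.sa_solver_diffusers import betas_for_alpha_bar

betas = betas_for_alpha_bar(1000)                  # cosine transform, max_beta=0.999
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def alpha_bar(t):
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

i = 500
print(float(alphas_cumprod[i]))
print(alpha_bar((i + 1) / 1000) / alpha_bar(0.0))  # should agree to a few decimal places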
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + predictor_order: int = 2, + corrector_order: int = 2, + predictor_corrector_mode: str = 'PEC', + prediction_type: str = "epsilon", + tau_func: Callable = lambda t: 1 if t >= 200 and t <= 800 else 0, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "data_prediction", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
+ self.betas = ( + torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if algorithm_type not in ["data_prediction", "noise_prediction"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.timestep_list = [None] * max(predictor_order, corrector_order - 1) + self.model_outputs = [None] * max(predictor_order, corrector_order - 1) + + self.tau_func = tau_func + self.predict_x0 = algorithm_type == "data_prediction" + self.lower_order_nums = 0 + self.last_sample = None + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * max(self.config.predictor_order, self.config.corrector_order - 1) + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of 
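A small sketch of constructing the scheduler with the settings recommended in the docstring above and inspecting the discretisation chosen by set_timesteps (exact values depend on rounding):

from diffusion.sa_solver_diffusers import SASolverScheduler

# PEC mode with order-2 predictor/corrector, as suggested for guided sampling.
scheduler = SASolverScheduler(
    beta_schedule="linear",
    predictor_order=2,
    corrector_order=2,
    predictor_corrector_mode="PEC",
    prediction_type="epsilon",
)
scheduler.set_timesteps(num_inference_steps=25, device="cpu")
print(scheduler.timesteps[:3])   # descending, roughly tensor([999, 959, 919]) with "linspace"
print(len(scheduler.timesteps))  # 25 (fewer only if duplicate timesteps were dropped)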
Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + + # SA-Solver_data_prediction needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["data_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["noise_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." 
+ ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def get_coefficients_exponential_negative(self, order, interval_start, interval_end): + """ + Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + if order == 0: + return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1) + elif order == 1: + return torch.exp(-interval_end) * ( + (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1)) + elif order == 2: + return torch.exp(-interval_end) * ( + (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - ( + interval_end ** 2 + 2 * interval_end + 2)) + elif order == 3: + return torch.exp(-interval_end) * ( + (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp( + interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6)) + + def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau): + """ + Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + # after change of variable(cov) + interval_end_cov = (1 + tau ** 2) * interval_end + interval_start_cov = (1 + tau ** 2) * interval_start + + if order == 0: + return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / ( + (1 + tau ** 2)) + elif order == 1: + return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2) + elif order == 2: + return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - ( + interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3) + elif order == 3: + return torch.exp(interval_end_cov) * ( + (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - ( + interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4) + + def lagrange_polynomial_coefficient(self, order, lambda_list): + """ + Calculate the coefficient of lagrange polynomial + """ + + assert order in [0, 1, 2, 3] + assert order == len(lambda_list) - 1 + if order == 0: + return [[1]] + elif order == 1: + return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], + [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]] + elif order == 2: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2]) / denominator1, + lambda_list[1] * lambda_list[2] / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2]) / denominator2, + lambda_list[0] * lambda_list[2] / denominator2], + + [1 / denominator3, 
+ (-lambda_list[0] - lambda_list[1]) / denominator3, + lambda_list[0] * lambda_list[1] / denominator3] + ] + elif order == 3: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * ( + lambda_list[0] - lambda_list[3]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * ( + lambda_list[1] - lambda_list[3]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * ( + lambda_list[2] - lambda_list[3]) + denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * ( + lambda_list[3] - lambda_list[2]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, + (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator1, + (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, + (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator2, + (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[ + 3]) / denominator3, + (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3], + + [1 / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[ + 2]) / denominator4, + (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4] + + ] + + def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau): + assert order in [1, 2, 3, 4] + assert order == len(lambda_list), 'the length of lambda list must be equal to the order' + coefficients = [] + lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list) + for i in range(order): + coefficient = 0 + for j in range(order): + if self.predict_x0: + + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive( + order - 1 - j, interval_start, interval_end, tau) + else: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative( + order - 1 - j, interval_start, interval_end) + coefficients.append(coefficient) + assert len(coefficients) == order, 'the length of coefficients does not match the order' + return coefficients + + def stochastic_adams_bashforth_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + noise: torch.FloatTensor, + order: int, + tau: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the SA-Predictor. + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of SA-Predictor at this timestep. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. 
+ """ + + assert noise is not None + timestep_list = self.timestep_list + model_output_list = self.model_outputs + s0, t = self.timestep_list[-1], prev_timestep + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + gradient_part = torch.zeros_like(sample) + h = lambda_t - lambda_s0 + lambda_list = [] + + for i in range(order): + lambda_list.append(self.lambda_t[timestep_list[-(i + 1)]]) + + gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau) + + x = sample + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[ + timestep_list[-2]]) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[ + timestep_list[-2]]) + + for i in range(order): + if self.predict_x0: + + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_output_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def stochastic_adams_moulton_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + last_noise: torch.FloatTensor, + this_sample: torch.FloatTensor, + order: int, + tau: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the SA-Corrector. + + Args: + this_model_output (`torch.FloatTensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.FloatTensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.FloatTensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The order of SA-Corrector at this step. + + Returns: + `torch.FloatTensor`: + The corrected sample tensor at the current timestep. 
+ """ + + assert last_noise is not None + timestep_list = self.timestep_list + model_output_list = self.model_outputs + s0, t = self.timestep_list[-1], this_timestep + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + gradient_part = torch.zeros_like(this_sample) + h = lambda_t - lambda_s0 + t_list = timestep_list + [this_timestep] + lambda_list = [] + for i in range(order): + lambda_list.append(self.lambda_t[t_list[-(i + 1)]]) + + model_prev_list = model_output_list + [this_model_output] + + gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau) + + x = last_sample + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * last_noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the SA-Solver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = ( + step_index > 0 and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + + if use_corrector: + current_tau = self.tau_func(self.timestep_list[-1]) + sample = self.stochastic_adams_moulton_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + last_noise=self.last_noise, + this_sample=sample, + order=self.this_corrector_order, + tau=current_tau, + ) + + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + + if self.config.lower_order_final: + this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - step_index) + this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - step_index + 1) + else: + this_predictor_order = self.config.predictor_order + this_corrector_order = self.config.corrector_order + + self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1) # warmup for multistep + self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2) # warmup for multistep + assert self.this_predictor_order > 0 + assert self.this_corrector_order > 0 + + self.last_sample = sample + self.last_noise = noise + + current_tau = self.tau_func(self.timestep_list[-1]) + prev_sample = self.stochastic_adams_bashforth_update( + model_output=model_output_convert, + prev_timestep=prev_timestep, + sample=sample, + noise=noise, + order=self.this_predictor_order, + tau=current_tau, + ) + + if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1): + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/diffusion/utils/__init__.py b/diffusion/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusion/utils/checkpoint.py b/diffusion/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..51a34370bc87aa9580b087fb0c0864c22d22ab41 --- /dev/null +++ b/diffusion/utils/checkpoint.py @@ -0,0 +1,84 @@ +import os +import re +import torch + +from diffusion.utils.logger import get_root_logger + + +def save_checkpoint(work_dir, + epoch, + model, + model_ema=None, + optimizer=None, + lr_scheduler=None, + keep_last=False, + step=None, + ): + os.makedirs(work_dir, exist_ok=True) + state_dict = dict(state_dict=model.state_dict()) + if model_ema is not None: + state_dict['state_dict_ema'] = model_ema.state_dict() + if optimizer is not None: + state_dict['optimizer'] = optimizer.state_dict() + if lr_scheduler is not None: + state_dict['scheduler'] = lr_scheduler.state_dict() + if epoch is not None: + state_dict['epoch'] = epoch + file_path = os.path.join(work_dir, f"epoch_{epoch}.pth") + if step is not None: + file_path = file_path.split('.pth')[0] + f"_step_{step}.pth" + logger = get_root_logger() + torch.save(state_dict, file_path) + logger.info(f'Saved checkpoint of epoch {epoch} to {file_path.format(epoch)}.') + if keep_last: + for i in range(epoch): + previous_ckgt = file_path.format(i) + if os.path.exists(previous_ckgt): + os.remove(previous_ckgt) + + +def load_checkpoint(checkpoint, + model, + model_ema=None, + optimizer=None, + lr_scheduler=None, + load_ema=False, + resume_optimizer=True, + resume_lr_scheduler=True, + max_length=120, + ): + assert isinstance(checkpoint, str) + ckpt_file = checkpoint + checkpoint = torch.load(ckpt_file, map_location="cpu") + + state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed'] + for key in state_dict_keys: + if key in checkpoint['state_dict']: + del checkpoint['state_dict'][key] + if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']: + del checkpoint['state_dict_ema'][key] + break + + if load_ema: + state_dict = checkpoint['state_dict_ema'] + else: + state_dict = checkpoint.get('state_dict', checkpoint) # to be compatible with the official checkpoint + + null_embed = 
torch.load(f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth', map_location='cpu') + state_dict['y_embedder.y_embedding'] = null_embed['uncond_prompt_embeds'][0] + + missing, unexpect = model.load_state_dict(state_dict, strict=False) + if model_ema is not None: + model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False) + if optimizer is not None and resume_optimizer: + optimizer.load_state_dict(checkpoint['optimizer']) + if lr_scheduler is not None and resume_lr_scheduler: + lr_scheduler.load_state_dict(checkpoint['scheduler']) + logger = get_root_logger() + if optimizer is not None: + epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*.pth', ckpt_file).group()[0]) + logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, ' + f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.') + return epoch, missing, unexpect + logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.') + return missing, unexpect diff --git a/diffusion/utils/data_sampler.py b/diffusion/utils/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..30b49ea319dec3b688ebeed61bfe91e421649b93 --- /dev/null +++ b/diffusion/utils/data_sampler.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Sequence +from torch.utils.data import BatchSampler, Sampler, Dataset +from random import shuffle, choice +from copy import deepcopy +from diffusion.utils.logger import get_root_logger + + +class AspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. + """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + aspect_ratios: dict, + drop_last: bool = False, + config=None, + valid_num=0, # take as valid aspect-ratio when sample number >= valid_num + **kwargs) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.ratio_nums_gt = kwargs.get('ratio_nums', None) + self.config = config + assert self.ratio_nums_gt + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios.keys()} + self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num] + logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + logger.warning(f"Using valid_num={valid_num} in config file. 
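A hedged sketch of the checkpoint helpers; `model`, `model_ema`, `optimizer` and `lr_scheduler` are placeholders for the objects built by the training script, and note that load_checkpoint additionally expects the null-embedding file referenced above to exist:

from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint

save_checkpoint('output/checkpoints', epoch=10, model=model, model_ema=model_ema,
                optimizer=optimizer, lr_scheduler=lr_scheduler, step=20000)

# Resume later from the saved file (epoch/optimizer/scheduler state are restored):
epoch, missing, unexpected = load_checkpoint(
    'output/checkpoints/epoch_10_step_20000.pth', model,
    model_ema=model_ema, optimizer=optimizer, lr_scheduler=lr_scheduler,
    load_ema=False, max_length=120,
)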
Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") + + def __iter__(self) -> Sequence[int]: + for idx in self.sampler: + data_info = self.dataset.get_data_info(idx) + height, width = data_info['height'], data_info['width'] + ratio = height / width + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the buckets + for bucket in self._aspect_ratio_buckets.values(): + while len(bucket) > 0: + if len(bucket) <= self.batch_size: + if not self.drop_last: + yield bucket[:] + bucket = [] + else: + yield bucket[:self.batch_size] + bucket = bucket[self.batch_size:] + + +class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Assign samples to each bucket + self.ratio_nums_gt = kwargs.get('ratio_nums', None) + assert self.ratio_nums_gt + self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()} + self.original_buckets = {} + self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000] + self.all_available_keys = deepcopy(self.current_available_bucket_keys) + self.exhausted_bucket_keys = [] + self.total_batches = len(self.sampler) // self.batch_size + self._aspect_ratio_count = {} + for k in self.all_available_keys: + self._aspect_ratio_count[float(k)] = 0 + self.original_buckets[float(k)] = [] + logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log')) + logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}") + + def __iter__(self) -> Sequence[int]: + i = 0 + for idx in self.sampler: + data_info = self.dataset.get_data_info(idx) + height, width = data_info['height'], data_info['width'] + ratio = height / width + closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))) + if closest_ratio not in self.all_available_keys: + continue + if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]: + self._aspect_ratio_count[closest_ratio] += 1 + self._aspect_ratio_buckets[closest_ratio].append(idx) + self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket + if not self.current_available_bucket_keys: + self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, [] + + if closest_ratio not in self.current_available_bucket_keys: + continue + key = closest_ratio + bucket = self._aspect_ratio_buckets[key] + if len(bucket) == self.batch_size: + yield bucket[:self.batch_size] + del bucket[:self.batch_size] + i += 1 + self.exhausted_bucket_keys.append(key) + self.current_available_bucket_keys.remove(key) + + for _ in range(self.total_batches - i): + key = choice(self.all_available_keys) + bucket = self._aspect_ratio_buckets[key] + if len(bucket) >= self.batch_size: + yield bucket[:self.batch_size] + del bucket[:self.batch_size] + + # If a bucket is exhausted + if not bucket: + self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:]) + shuffle(self._aspect_ratio_buckets[key]) + else: + self._aspect_ratio_buckets[key] = 
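A hedged sketch of plugging AspectRatioBatchSampler into a DataLoader; `dataset` is a placeholder that must implement get_data_info(idx) returning 'height' and 'width', and the ratio table and sample counts are illustrative stand-ins:

from torch.utils.data import DataLoader, RandomSampler

aspect_ratios = {'0.5': [704, 1408], '1.0': [1024, 1024], '2.0': [1408, 704]}
batch_sampler = AspectRatioBatchSampler(
    sampler=RandomSampler(dataset),
    dataset=dataset,
    batch_size=16,
    aspect_ratios=aspect_ratios,
    drop_last=True,
    ratio_nums={'0.5': 12000, '1.0': 50000, '2.0': 8000},
    valid_num=10000,   # buckets with fewer samples than this are skipped
)
loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)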
deepcopy(self.original_buckets[key][:]) + shuffle(self._aspect_ratio_buckets[key]) diff --git a/diffusion/utils/dist_utils.py b/diffusion/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c365c3383873fb9ab978c399ac2d02a2d836bcb9 --- /dev/null +++ b/diffusion/utils/dist_utils.py @@ -0,0 +1,314 @@ +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" +import os +import pickle +import shutil + +import gc +import mmcv +import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info + + +def is_distributed(): + return get_world_size() > 1 + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + local_rank = int(os.getenv('LOCAL_RANK', 0)) + return local_rank + + +def is_master(): + return get_rank() == 0 + + +def is_local_master(): + return get_local_rank() == 0 + + +def get_local_proc_group(group_size=8): + world_size = get_world_size() + if world_size <= group_size or group_size == 1: + return None + assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).' + process_groups = getattr(get_local_proc_group, 'process_groups', dict()) + if group_size not in process_groups: + num_groups = dist.get_world_size() // group_size + groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)] + process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]}) + get_local_proc_group.process_groups = process_groups + + group_idx = get_rank() // group_size + process_groups = get_local_proc_group.process_groups.get(group_size)[group_idx] + return process_groups + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + to_device = torch.device("cuda") + # to_device = torch.device("cpu") + + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(to_device) + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]).to(to_device) + size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to(to_device)) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device) + tensor = torch.cat((tensor, padding), 
dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def broadcast(data, **kwargs): + if get_world_size() == 1: + return data + data = [data] + dist.broadcast_object_list(data, **kwargs) + return data[0] + + +def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True): + rank, world_size = get_dist_info() + if tmpdir is None: + tmpdir = './tmp' + if rank == 0: + mmcv.mkdir_or_exist(tmpdir) + synchronize() + # dump the part result to the dir + mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl')) + synchronize() + # collect all parts + if collect_by_master and rank != 0: + return None + else: + # load results of all parts from tmp dir + results = [] + for i in range(world_size): + part_file = os.path.join(tmpdir, f'part_{i}.pkl') + results.append(mmcv.load(part_file)) + if not collect_by_master: + synchronize() + # remove tmp dir + if rank == 0: + shutil.rmtree(tmpdir) + return results + +def all_gather_tensor(tensor, group_size=None, group=None): + if group_size is None: + group_size = get_world_size() + if group_size == 1: + output = [tensor] + else: + output = [torch.zeros_like(tensor) for _ in range(group_size)] + dist.all_gather(output, tensor, group=group) + return output + + +def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None): + world_size = get_world_size() + if world_size == 1: + if not concat: + return [feat] + return feat + num_samples, *feat_dim = feat.size() + # padding to max number of samples + feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim)) + feat_padding[:num_samples] = feat + # gather + feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size) + for r, num in enumerate(num_samples_list): + feat_gather[r] = feat_gather[r][:num] + if concat: + feat_gather = torch.cat(feat_gather) + return feat_gather + + +class GatherLayer(torch.autograd.Function): + '''Gather tensors from all process, supporting backward propagation. 
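A hedged sketch of the gather/reduce helpers inside a training loop; it assumes torch.distributed has already been initialised (e.g. via torchrun) with a CUDA/NCCL backend:

import torch
from diffusion.utils.dist_utils import all_gather, reduce_dict, get_rank

local_stats = {'loss': torch.tensor(0.123, device='cuda'),
               'grad_norm': torch.tensor(1.7, device='cuda')}
averaged = reduce_dict(local_stats, average=True)       # rank 0 ends up with the averages
all_paths = all_gather([f'sample_{get_rank()}.png'])    # one picklable object per rank
if get_rank() == 0:
    print(averaged, all_paths)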
+ ''' + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device) + ctx.num_samples_list = all_gather_tensor(num_samples) + output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): # tuple(output)'s grad + input, = ctx.saved_tensors + num_samples_list = ctx.num_samples_list + rank = get_rank() + start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1]) + grads = torch.cat(grads) + if is_distributed(): + dist.all_reduce(grads) + grad_out = torch.zeros_like(input) + grad_out[:] = grads[start:end] + return grad_out, None, None + + +class GatherLayerWithGroup(torch.autograd.Function): + '''Gather tensors from all process, supporting backward propagation. + ''' + + @staticmethod + def forward(ctx, input, group, group_size): + ctx.save_for_backward(input) + ctx.group_size = group_size + output = all_gather_tensor(input, group=group, group_size=group_size) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): # tuple(output)'s grad + input, = ctx.saved_tensors + grads = torch.stack(grads) + if is_distributed(): + dist.all_reduce(grads) + grad_out = torch.zeros_like(input) + grad_out[:] = grads[get_rank() % ctx.group_size] + return grad_out, None, None + + +def gather_layer_with_group(data, group=None, group_size=None): + if group_size is None: + group_size = get_world_size() + output = GatherLayer.apply(data, group, group_size) + return output + +from typing import Union +import math +# from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm + +@torch.no_grad() +def clip_grad_norm_( + self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0 +) -> None: + self._lazy_init() + self._wait_for_previous_optim_step() + assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance" + self._assert_state(TrainingState_.IDLE) + + max_norm = float(max_norm) + norm_type = float(norm_type) + # Computes the max norm for this shard's gradients and sync's across workers + local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda() # type: ignore[arg-type] + if norm_type == math.inf: + total_norm = local_norm + dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group) + else: + total_norm = local_norm ** norm_type + dist.all_reduce(total_norm, group=self.process_group) + total_norm = total_norm ** (1.0 / norm_type) + + clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6) + if clip_coef < 1: + # multiply by clip_coef, aka, (max_norm/total_norm). + for p in self.params_with_grad: + assert p.grad is not None + p.grad.detach().mul_(clip_coef.to(p.grad.device)) + return total_norm + + +def flush(): + gc.collect() + torch.cuda.empty_cache() diff --git a/diffusion/utils/logger.py b/diffusion/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..bf3e2165902e9e841070c6cfb92eab3732c7ca5b --- /dev/null +++ b/diffusion/utils/logger.py @@ -0,0 +1,99 @@ +import logging +import os +import torch.distributed as dist +from datetime import datetime +from .dist_utils import is_local_master +from mmcv.utils.logging import logger_initialized + + +def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'): + """Get root logger. + + Args: + log_file (str, optional): File path of log. Defaults to None. 
+ log_level (int, optional): The level of the logger. + Defaults to logging.INFO. + name (str): logger name + Returns: + :obj:`logging.Logger`: The obtained logger + """ + if log_file is None: + log_file = '/dev/null' + logger = get_logger(name=name, log_file=log_file, log_level=log_level) + return logger + + +def get_logger(name, log_file=None, log_level=logging.INFO): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + logger.propagate = False # disable root logger to avoid duplicate logging + + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + file_handler = logging.FileHandler(log_file, 'w') + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + # only rank0 for each node will print logs + log_level = log_level if is_local_master() else logging.ERROR + logger.setLevel(log_level) + + logger_initialized[name] = True + + return logger + +def rename_file_with_creation_time(file_path): + # Get the file's creation time + creation_time = os.path.getctime(file_path) + creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S') + + # Build the new file name + dir_name, file_name = os.path.split(file_path) + name, ext = os.path.splitext(file_name) + new_file_name = f"{name}_{creation_time_str}{ext}" + new_file_path = os.path.join(dir_name, new_file_name) + + # Rename the file + os.rename(file_path, new_file_path) + print(f"File renamed to: {new_file_path}") diff --git a/diffusion/utils/lr_scheduler.py b/diffusion/utils/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..44ac4419b4125ae1ce21f03f2beb5dde67bfb5f2 --- /dev/null +++ b/diffusion/utils/lr_scheduler.py @@ -0,0 +1,84 @@ +from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR +import math + +from diffusion.utils.logger import get_root_logger + + +def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio): + if not config.get('lr_schedule_args', None): + config.lr_schedule_args = dict() + if config.get('lr_warmup_steps', None): + config['num_warmup_steps'] = config.get('lr_warmup_steps') # for compatibility
with old version + + logger = get_root_logger() + logger.info( + f'Lr schedule: {config.lr_schedule}, ' + ",".join( + [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.') + if config.lr_schedule == 'cosine': + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + **config.lr_schedule_args, + num_training_steps=(len(train_dataloader) * config.num_epochs), + ) + elif config.lr_schedule == 'constant': + lr_scheduler = get_constant_schedule_with_warmup( + optimizer=optimizer, + **config.lr_schedule_args, + ) + elif config.lr_schedule == 'cosine_decay_to_constant': + assert lr_scale_ratio >= 1 + lr_scheduler = get_cosine_decay_to_constant_with_warmup( + optimizer=optimizer, + **config.lr_schedule_args, + final_lr=1 / lr_scale_ratio, + num_training_steps=(len(train_dataloader) * config.num_epochs), + ) + else: + raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.') + return lr_scheduler + + +def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + final_lr: float = 0.0, + num_decay: float = 0.667, + num_cycles: float = 0.5, + last_epoch: int = -1 + ): + """ + Create a schedule with a cosine annealing lr followed by a constant lr. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The number of total training steps. + final_lr (`int`): + The final constant lr after cosine decay. + num_decay (`int`): + The + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + num_decay_steps = int(num_training_steps * num_decay) + if current_step > num_decay_steps: + return final_lr + + progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) * ( + 1 - final_lr) + final_lr + + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/diffusion/utils/misc.py b/diffusion/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..85256a91a7a1bdee238b4044f1fc1754d3111633 --- /dev/null +++ b/diffusion/utils/misc.py @@ -0,0 +1,386 @@ +import collections +import datetime +import os +import random +import subprocess +import time +from multiprocessing import JoinableQueue, Process + +import numpy as np +import torch +import torch.distributed as dist +from mmcv import Config +from mmcv.runner import get_dist_info + +from diffusion.utils.dist_utils import get_rank +from diffusion.utils.logger import get_root_logger + +os.environ["MOX_SILENT_MODE"] = "1" # mute moxing log + + +def read_config(file): + # solve config loading conflict when multi-processes + import time + while True: + config = Config.fromfile(file) + if len(config) == 0: + time.sleep(0.1) + continue + break + return config + + +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to prevent some potential bugs. + + Args: + seed (int, Optional): The seed. Default to None. 
+ device (str): The device where the seed will be put on. + Default to 'cuda'. + + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2 ** 31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +class SimpleTimer: + def __init__(self, num_tasks, log_interval=1, desc="Process"): + self.num_tasks = num_tasks + self.desc = desc + self.count = 0 + self.log_interval = log_interval + self.start_time = time.time() + self.logger = get_root_logger() + + def log(self, n=1): + self.count += n + if (self.count % self.log_interval) == 0 or self.count == self.num_tasks: + time_elapsed = time.time() - self.start_time + avg_time = time_elapsed / self.count + eta_sec = avg_time * (self.num_tasks - self.count) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + elapsed_str = str(datetime.timedelta(seconds=int(time_elapsed))) + log_info = f"{self.desc} [{self.count}/{self.num_tasks}], elapsed_time:{elapsed_str}," \ + f" avg_time: {avg_time}, eta: {eta_str}." + self.logger.info(log_info) + + +class DebugUnderflowOverflow: + """ + This debug class helps detect and understand where the model starts getting very large or very small, and more + importantly `nan` or `inf` weight and activation elements. + There are 2 working modes: + 1. Underflow/overflow detection (default) + 2. Specific batch absolute min/max tracing without detection + Mode 1: Underflow/overflow detection + To activate the underflow/overflow detection, initialize the object with the model : + ```python + debug_overflow = DebugUnderflowOverflow(model) + ``` + then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or + output elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this + event, each frame reporting + 1. the fully qualified module name plus the class name whose `forward` was run + 2. the absolute min and max value of all elements for each module weights, and the inputs and output + For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16 mixed precision : + ``` + Detected inf/nan during batch_number=0 + Last 21 forward frames: + abs min abs max metadata + [...] 
+ encoder.block.2.layer.1.DenseReluDense.wi_0 Linear + 2.17e-07 4.50e+00 weight + 1.79e-06 4.65e+00 input[0] + 2.68e-06 3.70e+01 output + encoder.block.2.layer.1.DenseReluDense.wi_1 Linear + 8.08e-07 2.66e+01 weight + 1.79e-06 4.65e+00 input[0] + 1.27e-04 2.37e+02 output + encoder.block.2.layer.1.DenseReluDense.wo Linear + 1.01e-06 6.44e+00 weight + 0.00e+00 9.74e+03 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense + 1.79e-06 4.65e+00 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.dropout Dropout + 3.18e-04 6.27e+04 input[0] + 0.00e+00 inf output + ``` + You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value + was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which + renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than + 64K, and we get an overlow. + As you can see it's the previous frames that we need to look into when the numbers start going into very large for + fp16 numbers. + The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed. + By default the last 21 frames are printed. You can change the default to adjust for your needs. For example : + ```python + debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100) + ``` + To validate that you have set up this debugging feature correctly, and you intend to use it in a training that may + take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in the next + section. + Mode 2. Specific batch absolute min/max tracing without detection + The second work mode is per-batch tracing with the underflow/overflow detection feature turned off. + Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a + given batch, and only do that for batches 1 and 3. Then you instantiate this class as : + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3]) + ``` + And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed. + This is helpful if you know that the program starts misbehaving after a certain batch number, so you can + fast-forward right to that area. + Early stopping: + You can also specify the batch number after which to stop the training, with : + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3) + ``` + This feature is mainly useful in the tracing mode, but you can use it for any mode. + **Performance**: + As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the + training down. Therefore remember to turn it off once the debugging needs have been met. + Args: + model (`nn.Module`): + The model to debug. 
+ max_frames_to_save (`int`, *optional*, defaults to 21): + How many frames back to record + trace_batch_nums (`List[int]`, *optional*, defaults to `[]`): + Which batch numbers to trace (turns detection off) + abort_after_batch_num (`int`, *optional*): + Whether to abort after a certain batch number has finished + """ + + def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None): + self.model = model + self.trace_batch_nums = trace_batch_nums + self.abort_after_batch_num = abort_after_batch_num + + # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence + self.frames = collections.deque([], max_frames_to_save) + self.frame = [] + self.batch_number = 0 + self.total_calls = 0 + self.detected_overflow = False + self.prefix = " " + + self.analyse_model() + + self.register_forward_hook() + + def save_frame(self, frame=None): + if frame is not None: + self.expand_frame(frame) + self.frames.append("\n".join(self.frame)) + self.frame = [] # start a new frame + + def expand_frame(self, line): + self.frame.append(line) + + def trace_frames(self): + print("\n".join(self.frames)) + self.frames = [] + + def reset_saved_frames(self): + self.frames = [] + + def dump_saved_frames(self): + print(f"\nDetected inf/nan during batch_number={self.batch_number}") + print(f"Last {len(self.frames)} forward frames:") + print(f"{'abs min':8} {'abs max':8} metadata") + print("\n".join(self.frames)) + print("\n\n") + self.frames = [] + + def analyse_model(self): + # extract the fully qualified module names, to be able to report at run time. e.g.: + # encoder.block.2.layer.0.SelfAttention.o + # + # for shared weights only the first shared module name will be registered + self.module_names = {m: name for name, m in self.model.named_modules()} + # self.longest_module_name = max(len(v) for v in self.module_names.values()) + + def analyse_variable(self, var, ctx): + if torch.is_tensor(var): + self.expand_frame(self.get_abs_min_max(var, ctx)) + if self.detect_overflow(var, ctx): + self.detected_overflow = True + elif var is None: + self.expand_frame(f"{'None':>17} {ctx}") + else: + self.expand_frame(f"{'not a tensor':>17} {ctx}") + + def batch_start_frame(self): + self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***") + self.expand_frame(f"{'abs min':8} {'abs max':8} metadata") + + def batch_end_frame(self): + self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n") + + def create_frame(self, module, input, output): + self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}") + + # params + for name, p in module.named_parameters(recurse=False): + self.analyse_variable(p, name) + + # inputs + if isinstance(input, tuple): + for i, x in enumerate(input): + self.analyse_variable(x, f"input[{i}]") + else: + self.analyse_variable(input, "input") + + # outputs + if isinstance(output, tuple): + for i, x in enumerate(output): + # possibly a tuple of tuples + if isinstance(x, tuple): + for j, y in enumerate(x): + self.analyse_variable(y, f"output[{i}][{j}]") + else: + self.analyse_variable(x, f"output[{i}]") + else: + self.analyse_variable(output, "output") + + self.save_frame() + + def register_forward_hook(self): + self.model.apply(self._register_forward_hook) + + def _register_forward_hook(self, module): + module.register_forward_hook(self.forward_hook) + + def forward_hook(self, module, input, output): + # - input is a tuple of packed inputs (could be
non-Tensors) + # - output could be a Tensor or a tuple of Tensors and non-Tensors + + last_frame_of_batch = False + + trace_mode = True if self.batch_number in self.trace_batch_nums else False + if trace_mode: + self.reset_saved_frames() + + if self.total_calls == 0: + self.batch_start_frame() + self.total_calls += 1 + + # count batch numbers - the very first forward hook of the batch will be called when the + # batch completes - i.e. it gets called very last - we know this batch has finished + if module == self.model: + self.batch_number += 1 + last_frame_of_batch = True + + self.create_frame(module, input, output) + + # if last_frame_of_batch: + # self.batch_end_frame() + + if trace_mode: + self.trace_frames() + + if last_frame_of_batch: + self.batch_start_frame() + + if self.detected_overflow and not trace_mode: + self.dump_saved_frames() + + # now we can abort, as it's pointless to continue running + raise ValueError( + "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. " + "Please scroll up above this traceback to see the activation values prior to this event." + ) + + # abort after certain batch if requested to do so + if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num: + raise ValueError( + f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg" + ) + + @staticmethod + def get_abs_min_max(var, ctx): + abs_var = var.abs() + return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}" + + @staticmethod + def detect_overflow(var, ctx): + """ + Report whether the tensor contains any `nan` or `inf` entries. + This is useful for detecting overflows/underflows and best to call right after the function that did some math that + modified the tensor in question. + This function contains a few other helper features that you can enable and tweak directly if you want to track + various other things. 
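+ For example, calling detect_overflow(hidden_states, "encoder.block.2.output") right after the op that produced hidden_states prints a short message and returns True when any element is nan or inf; the ctx string is only used to label the printout.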
+ Args: + var: the tensor variable to check + ctx: the message to print as a context + Return: + `True` if `inf` or `nan` was detected, `False` otherwise + """ + detected = False + if torch.isnan(var).any().item(): + detected = True + print(f"{ctx} has nans") + if torch.isinf(var).any().item(): + detected = True + print(f"{ctx} has infs") + if var.dtype == torch.float32 and torch.ge(var.abs(), 65535).any().item(): + detected = True + print(f"{ctx} has overflow values {var.abs().max().item()}.") + # if needed to monitor large elements can enable the following + if 0: # and detected: + n100 = var[torch.ge(var.abs(), 100)] + if n100.numel() > 0: + print(f"{ctx}: n100={n100.numel()}") + n1000 = var[torch.ge(var.abs(), 1000)] + if n1000.numel() > 0: + print(f"{ctx}: n1000={n1000.numel()}") + n10000 = var[torch.ge(var.abs(), 10000)] + if n10000.numel() > 0: + print(f"{ctx}: n10000={n10000.numel()}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})") + + return detected diff --git a/diffusion/utils/optimizer.py b/diffusion/utils/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb7eb9f41a833997f8515e712fa28d4af91bf0a --- /dev/null +++ b/diffusion/utils/optimizer.py @@ -0,0 +1,246 @@ +import math + +from mmcv import Config +from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \ + OPTIMIZERS +from mmcv.utils import _BatchNorm, _InstanceNorm +from torch.nn import GroupNorm, LayerNorm + +from .logger import get_root_logger + +from typing import Tuple, Optional, Callable + +import torch +from torch.optim.optimizer import Optimizer +from came_pytorch import CAME + + +def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256): + assert rule in ['linear', 'sqrt'] + logger = get_root_logger() + # scale by world size + if rule == 'sqrt': + scale_ratio = math.sqrt(effective_bs / base_batch_size) + elif rule == 'linear': + scale_ratio = effective_bs / base_batch_size + optimizer_cfg['lr'] *= scale_ratio + logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).') + return scale_ratio + + +@OPTIMIZER_BUILDERS.register_module() +class MyOptimizerConstructor(DefaultOptimizerConstructor): + + def add_params(self, params, module, prefix='', is_dcn_module=None): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) 
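+ # These multipliers come from paramwise_cfg; for illustration (values are hypothetical, not taken from this repo's configs), + # paramwise_cfg = dict(norm_decay_mult=0., custom_keys={'pos_embed': dict(lr_mult=0.1, decay_mult=0.)}) + # would disable weight decay for norm layers and shrink the lr/decay of any parameter whose qualified name contains 'pos_embed'.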
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + + for name, param in module.named_parameters(recurse=False): + base_lr = self.base_lr + if name == 'bias' and not (is_norm or is_dcn_module): + base_lr *= bias_lr_mult + + # apply weight decay policies + base_wd = self.base_wd + if self.base_wd is not None: + # norm decay + if is_norm: + base_wd *= norm_decay_mult + # bias lr and decay + elif name == 'bias' and not is_dcn_module: + # TODO: current bias_decay_mult will have affect on DCN + base_wd *= bias_decay_mult + + param_group = {'params': [param]} + if not param.requires_grad: + param_group['requires_grad'] = False + params.append(param_group) + continue + if bypass_duplicate and self._is_in(param_group, params): + logger = get_root_logger() + logger.warn(f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}') + continue + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in custom_keys: + if isinstance(key, tuple): + scope, key_name = key + else: + scope, key_name = None, key + if scope is not None and scope not in f'{prefix}': + continue + if key_name in f'{prefix}.{name}': + is_custom = True + if 'lr_mult' in custom_keys[key]: + # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}': + # param_group['lr'] = self.base_lr + # else: + param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult'] + elif 'lr' not in param_group: + param_group['lr'] = base_lr + if self.base_wd is not None: + if 'decay_mult' in custom_keys[key]: + param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult'] + elif 'weight_decay' not in param_group: + param_group['weight_decay'] = base_wd + + if not is_custom: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if base_lr != self.base_lr: + param_group['lr'] = base_lr + if base_wd != self.base_wd: + param_group['weight_decay'] = base_wd + params.append(param_group) + + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) + + +def build_optimizer(model, optimizer_cfg): + # default parameter-wise config + logger = get_root_logger() + + if hasattr(model, 'module'): + model = model.module + # set optimizer constructor + optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor') + # parameter-wise setting: cancel weight decay for some specific modules + custom_keys = dict() + for name, module in model.named_modules(): + if hasattr(module, 'zero_weight_decay'): + custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay}) + + paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys))) + given_cfg = optimizer_cfg.get('paramwise_cfg') + if given_cfg: + paramwise_cfg.merge_from_dict(dict(cfg=given_cfg)) + optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg + # build optimizer + optimizer = mm_build_optimizer(model, optimizer_cfg) + + weight_decay_groups = dict() + lr_groups = dict() + for group in optimizer.param_groups: + if not group.get('requires_grad', True): continue + lr_groups.setdefault(group['lr'], []).append(group) + weight_decay_groups.setdefault(group['weight_decay'], []).append(group) + + learnable_count, fix_count = 0, 0 + 
for p in model.parameters(): + if p.requires_grad: + learnable_count += 1 + else: + fix_count += 1 + fix_info = f"{learnable_count} are learnable, {fix_count} are fix" + lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()]) + wd_info = "Weight decay group: " + ", ".join( + [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()]) + opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}." + logger.info(opt_info) + + return optimizer + + +@OPTIMIZERS.register_module() +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + ): + assert lr > 0. + assert all([0. <= beta <= 1. for beta in betas]) + + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + + super().__init__(params, defaults) + + @staticmethod + def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): + # stepweight decay + p.data.mul_(1 - lr * wd) + + # weight update + update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_() + p.add_(update, alpha=-lr) + + # decay the momentum running average coefficient + exp_avg.lerp_(grad, 1 - beta2) + + @staticmethod + def exists(val): + return val is not None + + @torch.no_grad() + def step( + self, + closure: Optional[Callable] = None + ): + + loss = None + if self.exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: self.exists(p.grad), group['params']): + + grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \ + self.state[p] + + # init state - exponential moving average of gradient values + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + self.update_fn( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2 + ) + + return loss + + +@OPTIMIZERS.register_module() +class CAMEWrapper(CAME): + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..a12bcf234acfbda722e17098e12179d98fffb067 --- /dev/null +++ b/environment.yml @@ -0,0 +1,25 @@ +name: PixArt +channels: + - pytorch + - nvidia +dependencies: + - python >= 3.8 + - pytorch >= 1.13 + - torchvision + - pytorch-cuda=11.7 + - pip: + - timm==0.6.12 + - diffusers + - accelerate + - mmcv==1.7.0 + - diffusers + - accelerate==0.15.0 + - tensorboard + - transformers==4.26.1 + - sentencepiece~=0.1.97 + - ftfy~=6.1.1 + - beautifulsoup4~=4.11.1 + - opencv-python + - bs4 + - einops + - xformers \ No newline at end of file diff --git a/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py b/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py new file mode 100644 index 0000000000000000000000000000000000000000..a59d1e1525e47d87105b24f5ee6280a4fc2f3f14 --- /dev/null +++ b/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py @@ -0,0 +1,30 @@ +_base_ = ['/workspace/PixArt-alpha/configs/PixArt_xl2_internal.py'] +data_root = '/workspace' + +image_list_json = ['data_info.json',] + +data = dict(type='InternalData', root='/workspace/pixart-pokemon', image_list_json=image_list_json, transform='default_train', load_vae_feat=True) +image_size = 512 + +# model setting +model = 'PixArt_XL_2' +fp32_attention = True 
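+# The absolute paths below assume the /workspace checkout and the features extracted in notebooks/train.ipynb; adjust load_from and vae_pretrained for your own layout.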
+load_from = "/workspace/PixArt-alpha/output/pretrained_models/PixArt-XL-2-512x512.pth" +vae_pretrained = "output/pretrained_models/sd-vae-ft-ema" +pe_interpolation = 1.0 + +# training setting +use_fsdp=False # if use FSDP mode +num_workers=10 +train_batch_size = 38 # 32 +num_epochs = 200 # 3 +gradient_accumulation_steps = 1 +grad_checkpointing = True +gradient_clip = 0.01 +optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10) +lr_schedule_args = dict(num_warmup_steps=1000) + +eval_sampling_steps = 200 +log_interval = 20 +save_model_steps=100 +work_dir = 'output/debug' diff --git a/notebooks/convert-checkpoint-to-diffusers.ipynb b/notebooks/convert-checkpoint-to-diffusers.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3a0399873012cc748963c39e88d0444974e876e6 --- /dev/null +++ b/notebooks/convert-checkpoint-to-diffusers.ipynb @@ -0,0 +1,74 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "2878bb5d-33a3-4a5b-b15c-c832c700129b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/PixArt-alpha\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.\n", + " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" + ] + } + ], + "source": [ + "%cd PixArt-alpha" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "7dd2d98c-3f8f-40f1-a9e1-bc916774afb3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of transformer parameters: 610856096\n" + ] + } + ], + "source": [ + "!python tools/convert_pixart_alpha_to_diffusers.py \\\n", + " --orig_ckpt_path \"/workspace/PixArt-alpha/output/trained_model/checkpoints/epoch_5_step_110.pth\" \\\n", + " --dump_path \"/workspace/PixArt-alpha/output/diffusers_trained\" \\\n", + " --only_transformer=True \\\n", + " --image_size 512 \\\n", + " --multi_scale_train=False\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/infer.ipynb b/notebooks/infer.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6fba5169ae0b19e711d6c1ffe88250143130abd0 --- /dev/null +++ b/notebooks/infer.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8b2458c4-c461-4ddc-af94-fcd837357da4", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import PixArtAlphaPipeline\n", + "import torch\n", + "from diffusers import Transformer2DModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81a5bc0f-682b-4ff9-92e9-43b68b3df8fc", + "metadata": {}, + "outputs": [], + "source": [ + "# for comparison\n", + "\n", + "orig_pipe = pipe = PixArtAlphaPipeline.from_pretrained(\"PixArt-alpha/PixArt-XL-2-512x512\", torch_dtype=torch.float16)\n", + "orig_pipe = orig_pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efc07821-5479-4ca3-a2c6-114ac484fd1e", + "metadata": {}, + "outputs": 
[], + "source": [ + "transformer = Transformer2DModel.from_pretrained(\"/workspace/PixArt-alpha/output/diffusers_trained/transformer\", torch_dtype=torch.float16)\n", + "pipe = PixArtAlphaPipeline.from_pretrained(\"PixArt-alpha/PixArt-XL-2-512x512\", torch_dtype=torch.float16, transformer=transformer)\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "57da873b-2c13-463b-b558-ee69522ccefc", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d69c7683773c4c25914764800ec1ef4f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/20 [00:00" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = \"A green pokemon on white background\"\n", + "image = pipe(prompt=prompt).images[0]\n", + "image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/train.ipynb b/notebooks/train.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..24d5b56ceee922736c6a7c6bf0dc97ce515dd2be --- /dev/null +++ b/notebooks/train.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c423d2a1-475e-482e-b759-f16456fd6707", + "metadata": {}, + "source": [ + "# Install" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0440d6a7-78b9-49e9-98a2-9a5ed75e1a2f", + "metadata": {}, + "outputs": [], + "source": [ + "!git clone https://github.com/kopyl/PixArt-alpha.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0abadf51-a7e3-4091-bb02-0bdd8d28fb73", + "metadata": {}, + "outputs": [], + "source": [ + "%cd PixArt-alpha" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4df1af24-f439-485d-a946-966dbf16c49b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "!pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117\n", + "!pip install -r requirements.txt\n", + "!pip install wandb" + ] + }, + { + "cell_type": "markdown", + "id": "d44474fd-0b92-48fc-b4cf-142b59d3917c", + "metadata": {}, + "source": [ + "## Download model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06b1c1c9-f8b1-4719-8564-2383eac9ff28", + "metadata": {}, + "outputs": [], + "source": [ + "!python tools/download.py --model_names \"PixArt-XL-2-512x512.pth\"" + ] + }, + { + "cell_type": "markdown", + "id": "f298a89c-d2a5-4da7-8304-c1390da0ba58", + "metadata": {}, + "source": [ + "## Make dataset out of Hugginggface dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e17b8883-0a5c-4fa3-a7d0-e8ee95e42027", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from tqdm.notebook import tqdm\n", + "from datasets import load_dataset\n", + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92957b2c-6765-48ee-9296-d6739066d74d", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = load_dataset(\"lambdalabs/pokemon-blip-captions\")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "0095cdda-c31a-48ee-a115-076a5fc393c3", + "metadata": {}, + "outputs": [], + "source": [ + "root_dir = \"/workspace/pixart-pokemon\"\n", + "images_dir = \"images\"\n", + "captions_dir = \"captions\"\n", + "\n", + "images_dir_absolute = os.path.join(root_dir, images_dir)\n", + "captions_dir_absolute = os.path.join(root_dir, captions_dir)\n", + "\n", + "if not os.path.exists(root_dir):\n", + " os.makedirs(os.path.join(root_dir, images_dir))\n", + "\n", + "if not os.path.exists(os.path.join(root_dir, images_dir)):\n", + " os.makedirs(os.path.join(root_dir, images_dir))\n", + "if not os.path.exists(os.path.join(root_dir, captions_dir)):\n", + " os.makedirs(os.path.join(root_dir, captions_dir))\n", + "\n", + "image_format = \"png\"\n", + "json_name = \"partition/data_info.json\"\n", + "if not os.path.exists(os.path.join(root_dir, \"partition\")):\n", + " os.makedirs(os.path.join(root_dir, \"partition\"))\n", + "\n", + "absolute_json_name = os.path.join(root_dir, json_name)\n", + "data_info = []\n", + "\n", + "order = 0\n", + "for item in tqdm(dataset[\"train\"]): \n", + " image = item[\"image\"]\n", + " image.save(f\"{images_dir_absolute}/{order}.{image_format}\")\n", + " with open(f\"{captions_dir_absolute}/{order}.txt\", \"w\") as text_file:\n", + " text_file.write(item[\"text\"])\n", + " \n", + " width, height = 512, 512\n", + " ratio = 1\n", + " data_info.append({\n", + " \"height\": height,\n", + " \"width\": width,\n", + " \"ratio\": ratio,\n", + " \"path\": f\"images/{order}.{image_format}\",\n", + " \"prompt\": item[\"text\"],\n", + " })\n", + " \n", + " order += 1\n", + "\n", + "with open(absolute_json_name, \"w\") as json_file:\n", + " json.dump(data_info, json_file)" + ] + }, + { + "cell_type": "markdown", + "id": "25be1c03", + "metadata": {}, + "source": [ + "## Extract features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f07a4f5-1873-48bf-86d0-9304942de5d3", + "metadata": {}, + "outputs": [], + "source": [ + "!python /workspace/PixArt-alpha/tools/extract_features.py \\\n", + " --img_size 512 \\\n", + " --json_path \"/workspace/pixart-pokemon/partition/data_info.json\" \\\n", + " --t5_save_root \"/workspace/pixart-pokemon/caption_feature_wmask\" \\\n", + " --vae_save_root \"/workspace/pixart-pokemon/img_vae_features\" \\\n", + " --pretrained_models_dir \"/workspace/PixArt-alpha/output/pretrained_models\" \\\n", + " --dataset_root \"/workspace/pixart-pokemon\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fc653d0", + "metadata": {}, + "outputs": [], + "source": [ + "!wandb login REPLACE_THIS_WITH_YOUR_AUTH_TOKEN_OF_WANDB" + ] + }, + { + "cell_type": "markdown", + "id": "2cf1fd1a", + "metadata": {}, + "source": [ + "## Train model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea0e9dab-17bc-45ed-9c81-b670bbb8de47", + "metadata": {}, + "outputs": [], + "source": [ + "!python -m torch.distributed.launch \\\n", + " train_scripts/train.py \\\n", + " /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \\\n", + " --work-dir output/trained_model \\\n", + " --report_to=\"wandb\" \\\n", + " --loss_report_name=\"train_loss\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements.txt b/requirements.txt index 0dfd7bb961dc6e421df0288fe0deebad55f5393c..24488a9032e0963c56ed134288b26c87efc1b5c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,20 @@ git+https://github.com/huggingface/diffusers +mmcv==1.7.0 +timm==0.6.12 +accelerate==0.25.0 +tensorboard +tensorboardX +transformers==4.36.1 +sentencepiece~=0.1.99 +ftfy +beautifulsoup4 +protobuf==3.20.2 +gradio==4.1.1 +yapf==0.40.1 +opencv-python +bs4 +einops +xformers==0.0.19 +optimum +peft==0.6.2 +came-pytorch \ No newline at end of file diff --git a/scripts/diffusers_patches.py b/scripts/diffusers_patches.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f374de9b2f03f167684486166bca2ea6138a52 --- /dev/null +++ b/scripts/diffusers_patches.py @@ -0,0 +1,541 @@ +import torch +from diffusers import ImagePipelineOutput, PixArtAlphaPipeline, AutoencoderKL, Transformer2DModel, \ + DPMSolverMultistepScheduler +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.embeddings import PixArtAlphaTextProjection, PatchEmbed +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import retrieve_timesteps +from typing import Callable, List, Optional, Tuple, Union + +from diffusers.utils import deprecate +from torch import nn +from transformers import T5Tokenizer, T5EncoderModel + +ASPECT_RATIO_2048_BIN = { + "0.25": [1024.0, 4096.0], + "0.26": [1024.0, 3968.0], + "0.27": [1024.0, 3840.0], + "0.28": [1024.0, 3712.0], + "0.32": [1152.0, 3584.0], + "0.33": [1152.0, 3456.0], + "0.35": [1152.0, 3328.0], + "0.4": [1280.0, 3200.0], + "0.42": [1280.0, 3072.0], + "0.48": [1408.0, 2944.0], + "0.5": [1408.0, 2816.0], + "0.52": [1408.0, 2688.0], + "0.57": [1536.0, 2688.0], + "0.6": [1536.0, 2560.0], + "0.68": [1664.0, 2432.0], + "0.72": [1664.0, 2304.0], + "0.78": [1792.0, 2304.0], + "0.82": [1792.0, 2176.0], + "0.88": [1920.0, 2176.0], + "0.94": [1920.0, 2048.0], + "1.0": [2048.0, 2048.0], + "1.07": [2048.0, 1920.0], + "1.13": [2176.0, 1920.0], + "1.21": [2176.0, 1792.0], + "1.29": [2304.0, 1792.0], + "1.38": [2304.0, 1664.0], + "1.46": [2432.0, 1664.0], + "1.67": [2560.0, 1536.0], + "1.75": [2688.0, 1536.0], + "2.0": [2816.0, 1408.0], + "2.09": [2944.0, 1408.0], + "2.4": [3072.0, 1280.0], + "2.5": [3200.0, 1280.0], + "2.89": [3328.0, 1152.0], + "3.0": [3456.0, 1152.0], + "3.11": [3584.0, 1152.0], + "3.62": [3712.0, 1024.0], + "3.75": [3840.0, 1024.0], + "3.88": [3968.0, 1024.0], + "4.0": [4096.0, 1024.0] +} + +ASPECT_RATIO_256_BIN = { + "0.25": [128.0, 512.0], + "0.28": [128.0, 464.0], + "0.32": [144.0, 448.0], + "0.33": [144.0, 432.0], + "0.35": [144.0, 416.0], + "0.4": [160.0, 400.0], + "0.42": [160.0, 384.0], + "0.48": [176.0, 368.0], + "0.5": [176.0, 352.0], + "0.52": [176.0, 336.0], + "0.57": [192.0, 336.0], + "0.6": [192.0, 320.0], + "0.68": [208.0, 304.0], + "0.72": [208.0, 288.0], + "0.78": [224.0, 288.0], + "0.82": [224.0, 272.0], + "0.88": [240.0, 272.0], + "0.94": [240.0, 256.0], + "1.0": [256.0, 256.0], + "1.07": [256.0, 240.0], + "1.13": [272.0, 240.0], + "1.21": [272.0, 224.0], + "1.29": [288.0, 224.0], + "1.38": [288.0, 208.0], + "1.46": [304.0, 208.0], + "1.67": [320.0, 192.0], + "1.75": [336.0, 192.0], + "2.0": [352.0, 176.0], + "2.09": [368.0, 176.0], + "2.4": [384.0, 
160.0], + "2.5": [400.0, 160.0], + "3.0": [432.0, 144.0], + "4.0": [512.0, 128.0] +} + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + + +def pipeline_pixart_alpha_call( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, +) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. 
If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latents. 
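+ # With use_resolution_binning enabled, height/width here are already the binned values; the latents prepared below + # have shape (batch_size * num_images_per_prompt, latent_channels, height // vae_scale_factor, width // vae_scale_factor), + # e.g. roughly (N, 4, 64, 64) for a 512x512 request, assuming the usual 4 latent channels and a VAE scale factor of 8.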
+ latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if 
not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +class PixArtSigmaPipeline(PixArtAlphaPipeline): + r""" + tmp Pipeline for text-to-image generation using PixArt-Sigma. + """ + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: Transformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__(tokenizer, text_encoder, vae, transformer, scheduler) + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + +def pixart_sigma_init_patched_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.config.norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + elif self.config.norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim ** 0.5) + self.proj_out = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + + # PixArt-Sigma blocks. 
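# Illustrative sketch (for reference only; the 2048px numbers are an assumed example, not part of the patch):
# the PatchEmbed configured above turns the latent grid into (sample_size / patch_size) ** 2 tokens.
sample_size, patch_size = 256, 2              # e.g. a 2048px checkpoint: latent sample_size = 2048 / 8
num_tokens = (sample_size // patch_size) ** 2
print(num_tokens)                             # 16384 tokens per image flow through the transformer_blocks above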
+ self.adaln_single = None + self.use_additional_conditions = False + if self.config.norm_type == "ada_norm_single": + # TODO(Sayak, PVP) clean this, PixArt-Sigma doesn't use additional_conditions anymore + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) diff --git a/scripts/inference.py b/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e14ef89de6353b25ef5bbadd5edf9758b752c49a --- /dev/null +++ b/scripts/inference.py @@ -0,0 +1,225 @@ +import os +import sys +from pathlib import Path +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) +import warnings +warnings.filterwarnings("ignore") # ignore warning +import re +import argparse +from datetime import datetime +from tqdm import tqdm +import torch +from torchvision.utils import save_image +from diffusers.models import AutoencoderKL +from transformers import T5EncoderModel, T5Tokenizer + +from diffusion.model.utils import prepare_prompt_ar +from diffusion import IDDPM, DPMS, SASolverSampler +from tools.download import find_model +from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2 +from diffusion.data.datasets import get_chunks +from diffusion.data.datasets.utils import * + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--image_size', default=1024, type=int) + parser.add_argument('--version', default='sigma', type=str) + parser.add_argument( + "--pipeline_load_from", default='output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers', + type=str, help="Download for loading text_encoder, " + "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + ) + parser.add_argument('--txt_file', default='asset/samples.txt', type=str) + parser.add_argument('--model_path', default='output/pretrained_models/PixArt-XL-2-1024x1024.pth', type=str) + parser.add_argument('--sdvae', action='store_true', help='sd vae') + parser.add_argument('--bs', default=1, type=int) + parser.add_argument('--cfg_scale', default=4.5, type=float) + parser.add_argument('--sampling_algo', default='dpm-solver', type=str, choices=['iddpm', 'dpm-solver', 'sa-solver']) + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--dataset', default='custom', type=str) + parser.add_argument('--step', default=-1, type=int) + parser.add_argument('--save_name', default='test_sample', type=str) + + return parser.parse_args() + + +def set_env(seed=0): + torch.manual_seed(seed) + torch.set_grad_enabled(False) + for _ in range(30): + torch.randn(1, 4, args.image_size, args.image_size) + +@torch.inference_mode() +def visualize(items, bs, sample_steps, cfg_scale): + + for chunk in tqdm(list(get_chunks(items, bs)), unit='batch'): + + prompts = [] + if bs == 1: + save_path = os.path.join(save_root, f"{prompts[0][:100]}.jpg") + if os.path.exists(save_path): + continue + prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(chunk[0], base_ratios, device=device, show=False) # ar for aspect ratio + if args.image_size == 1024: + latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8) + else: + hw = torch.tensor([[args.image_size, args.image_size]], dtype=torch.float, 
device=device).repeat(bs, 1) + ar = torch.tensor([[1.]], device=device).repeat(bs, 1) + latent_size_h, latent_size_w = latent_size, latent_size + prompts.append(prompt_clean.strip()) + else: + hw = torch.tensor([[args.image_size, args.image_size]], dtype=torch.float, device=device).repeat(bs, 1) + ar = torch.tensor([[1.]], device=device).repeat(bs, 1) + for prompt in chunk: + prompts.append(prepare_prompt_ar(prompt, base_ratios, device=device, show=False)[0].strip()) + latent_size_h, latent_size_w = latent_size, latent_size + + caption_token = tokenizer(prompts, max_length=max_sequence_length, padding="max_length", truncation=True, + return_tensors="pt").to(device) + caption_embs = text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0] + emb_masks = caption_token.attention_mask + + caption_embs = caption_embs[:, None] + null_y = null_caption_embs.repeat(len(prompts), 1, 1)[:, None] + print(f'finish embedding') + + with torch.no_grad(): + + if args.sampling_algo == 'iddpm': + # Create sampling noise: + n = len(prompts) + z = torch.randn(n, 4, latent_size_h, latent_size_w, device=device).repeat(2, 1, 1, 1) + model_kwargs = dict(y=torch.cat([caption_embs, null_y]), + cfg_scale=cfg_scale, data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + diffusion = IDDPM(str(sample_steps)) + # Sample images: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + elif args.sampling_algo == 'dpm-solver': + # Create sampling noise: + n = len(prompts) + z = torch.randn(n, 4, latent_size_h, latent_size_w, device=device) + model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + dpm_solver = DPMS(model.forward_with_dpmsolver, + condition=caption_embs, + uncondition=null_y, + cfg_scale=cfg_scale, + model_kwargs=model_kwargs) + samples = dpm_solver.sample( + z, + steps=sample_steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + elif args.sampling_algo == 'sa-solver': + # Create sampling noise: + n = len(prompts) + model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + sa_solver = SASolverSampler(model.forward_with_dpmsolver, device=device) + samples = sa_solver.sample( + S=25, + batch_size=n, + shape=(4, latent_size_h, latent_size_w), + eta=1, + conditioning=caption_embs, + unconditional_conditioning=null_y, + unconditional_guidance_scale=cfg_scale, + model_kwargs=model_kwargs, + )[0] + + samples = samples.to(weight_dtype) + samples = vae.decode(samples / vae.config.scaling_factor).sample + torch.cuda.empty_cache() + # Save images: + os.umask(0o000) # file permission: 666; dir permission: 777 + for i, sample in enumerate(samples): + save_path = os.path.join(save_root, f"{prompts[i][:100]}.jpg") + print("Saving path: ", save_path) + save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1)) + + +if __name__ == '__main__': + args = get_args() + # Setup PyTorch: + seed = args.seed + set_env(seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver'] + + # only support fixed latent size currently + latent_size = args.image_size // 8 + max_sequence_length = {"alpha": 120, "sigma": 300}[args.version] + pe_interpolation = {256: 0.5, 512: 1, 1024: 2} # trick for positional embedding interpolation + micro_condition = True if args.version == 'alpha' 
and args.image_size == 1024 else False + sample_steps_dict = {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25} + sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo] + weight_dtype = torch.float16 + print(f"Inference with {weight_dtype}") + + # model setting + micro_condition = True if args.version == 'alpha' and args.image_size == 1024 else False + if args.image_size in [512, 1024, 2048, 2880]: + model = PixArtMS_XL_2( + input_size=latent_size, + pe_interpolation=pe_interpolation[args.image_size], + micro_condition=micro_condition, + model_max_length=max_sequence_length, + ).to(device) + else: + model = PixArt_XL_2( + input_size=latent_size, + pe_interpolation=pe_interpolation[args.image_size], + model_max_length=max_sequence_length, + ).to(device) + + print("Generating sample from ckpt: %s" % args.model_path) + state_dict = find_model(args.model_path) + if 'pos_embed' in state_dict['state_dict']: + del state_dict['state_dict']['pos_embed'] + missing, unexpected = model.load_state_dict(state_dict['state_dict'], strict=False) + print('Missing keys: ', missing) + print('Unexpected keys', unexpected) + model.eval() + model.to(weight_dtype) + base_ratios = eval(f'ASPECT_RATIO_{args.image_size}_TEST') + + if args.sdvae: + # pixart-alpha vae link: https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/sd-vae-ft-ema + vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema").to(device).to(weight_dtype) + else: + # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae + vae = AutoencoderKL.from_pretrained(f"{args.pipeline_load_from}/vae").to(device).to(weight_dtype) + + tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.pipeline_load_from, subfolder="text_encoder").to(device) + + null_caption_token = tokenizer("", max_length=max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(device) + null_caption_embs = text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0] + + work_dir = os.path.join(*args.model_path.split('/')[:-2]) + work_dir = '/'+work_dir if args.model_path[0] == '/' else work_dir + + # data setting + with open(args.txt_file, 'r') as f: + items = [item.strip() for item in f.readlines()] + + # img save setting + try: + epoch_name = re.search(r'.*epoch_(\d+).*', args.model_path).group(1) + step_name = re.search(r'.*step_(\d+).*', args.model_path).group(1) + except: + epoch_name = 'unknown' + step_name = 'unknown' + img_save_dir = os.path.join(work_dir, 'vis') + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(img_save_dir, exist_ok=True) + + save_root = os.path.join(img_save_dir, f"{datetime.now().date()}_{args.dataset}_epoch{epoch_name}_step{step_name}_scale{args.cfg_scale}_step{sample_steps}_size{args.image_size}_bs{args.bs}_samp{args.sampling_algo}_seed{seed}") + os.makedirs(save_root, exist_ok=True) + visualize(items, args.bs, sample_steps, args.cfg_scale) \ No newline at end of file diff --git a/scripts/interface.py b/scripts/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a8db73f599472312465ceec56aab95bd951fd9 --- /dev/null +++ b/scripts/interface.py @@ -0,0 +1,258 @@ +import argparse +import sys +from pathlib import Path +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) +import os +import random +import 
torch +from torchvision.utils import save_image +from diffusion import IDDPM, DPMS, SASolverSampler +from diffusers.models import AutoencoderKL +from tools.download import find_model +from datetime import datetime +from typing import List, Union +import gradio as gr +import numpy as np +from gradio.components import Textbox, Image +from transformers import T5EncoderModel, T5Tokenizer +import gc + +from diffusion.model.t5 import T5Embedder +from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor +from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2 +from torchvision.utils import _log_api_usage_once, make_grid +from diffusion.data.datasets.utils import * +from asset.examples import examples +from diffusion.utils.dist_utils import flush + + +MAX_SEED = np.iinfo(np.int32).max + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--image_size', default=1024, type=int) + parser.add_argument('--version', default='sigma', type=str) + parser.add_argument('--model_path', default='output/pretrained_models/PixArt-XL-2-1024-MS.pth', type=str) + parser.add_argument('--sdvae', action='store_true', help='sd vae') + parser.add_argument( + "--pipeline_load_from", default='output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers', + type=str, help="Download for loading text_encoder, " + "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + ) + parser.add_argument('--port', default=7788, type=int) + + return parser.parse_args() + + +@torch.no_grad() +def ndarr_image(tensor: Union[torch.Tensor, List[torch.Tensor]], **kwargs,) -> None: + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(save_image) + grid = make_grid(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + return ndarr + + +def set_env(seed=0): + torch.manual_seed(seed) + torch.set_grad_enabled(False) + for _ in range(30): + torch.randn(1, 4, args.image_size, args.image_size) + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +@torch.inference_mode() +def generate_img(prompt, sampler, sample_steps, scale, seed=0, randomize_seed=False): + flush() + gc.collect() + torch.cuda.empty_cache() + + seed = int(randomize_seed_fn(seed, randomize_seed)) + set_env(seed) + + os.makedirs(f'output/demo/online_demo_prompts/', exist_ok=True) + save_promt_path = f'output/demo/online_demo_prompts/tested_prompts{datetime.now().date()}.txt' + with open(save_promt_path, 'a') as f: + f.write(prompt + '\n') + print(prompt) + prompt_clean, prompt_show, hw, ar, custom_hw = prepare_prompt_ar(prompt, base_ratios, device=device) # ar for aspect ratio + prompt_clean = prompt_clean.strip() + if isinstance(prompt_clean, str): + prompts = [prompt_clean] + + caption_token = tokenizer(prompts, max_length=max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(device) + caption_embs = text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0] + emb_masks = caption_token.attention_mask + + caption_embs = caption_embs[:, None] + null_y = null_caption_embs.repeat(len(prompts), 1, 1)[:, None] + + latent_size_h, latent_size_w = int(hw[0, 0]//8), int(hw[0, 1]//8) + # Sample images: + if sampler == 'iddpm': + # Create sampling noise: + n = len(prompts) + z = torch.randn(n, 4, 
latent_size_h, latent_size_w, device=device).repeat(2, 1, 1, 1) + model_kwargs = dict(y=torch.cat([caption_embs, null_y]), + cfg_scale=scale, data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + diffusion = IDDPM(str(sample_steps)) + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + elif sampler == 'dpm-solver': + # Create sampling noise: + n = len(prompts) + z = torch.randn(n, 4, latent_size_h, latent_size_w, device=device) + model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + dpm_solver = DPMS(model.forward_with_dpmsolver, + condition=caption_embs, + uncondition=null_y, + cfg_scale=scale, + model_kwargs=model_kwargs) + samples = dpm_solver.sample( + z, + steps=sample_steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + elif sampler == 'sa-solver': + # Create sampling noise: + n = len(prompts) + model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + sa_solver = SASolverSampler(model.forward_with_dpmsolver, device=device) + samples = sa_solver.sample( + S=sample_steps, + batch_size=n, + shape=(4, latent_size_h, latent_size_w), + eta=1, + conditioning=caption_embs, + unconditional_conditioning=null_y, + unconditional_guidance_scale=scale, + model_kwargs=model_kwargs, + )[0] + + samples = samples.to(weight_dtype) + samples = vae.decode(samples / vae.config.scaling_factor).sample + samples = resize_and_crop_tensor(samples, custom_hw[0,1], custom_hw[0,0]) + display_model_info = f'Model path: {args.model_path},\nBase image size: {args.image_size}, \nSampling Algo: {sampler}' + return ndarr_image(samples, normalize=True, value_range=(-1, 1)), prompt_show, display_model_info, seed + + +if __name__ == '__main__': + from diffusion.utils.logger import get_root_logger + args = get_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + logger = get_root_logger() + + assert args.image_size in [256, 512, 1024, 2048], \ + "We only provide pre-trained models for 256x256, 512x512, 1024x1024 and 2048x2048 resolutions." 
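# Illustrative sketch (for reference only; not part of the patch): what cfg_scale /
# unconditional_guidance_scale controls in the samplers used by generate_img above.
# A scale of 1.0 disables guidance; larger values push samples toward the text condition.
import torch

def apply_cfg(noise_uncond: torch.Tensor, noise_text: torch.Tensor, scale: float) -> torch.Tensor:
    # Combine the unconditional and text-conditioned noise predictions.
    return noise_uncond + scale * (noise_text - noise_uncond)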
+ pe_interpolation = {256: 0.5, 512: 1, 1024: 2, 2048: 4} + latent_size = args.image_size // 8 + max_sequence_length = {"alpha": 120, "sigma": 300}[args.version] + weight_dtype = torch.float16 + micro_condition = True if args.version == 'alpha' and args.image_size == 1024 else False + if args.image_size in [512, 1024, 2048, 2880]: + model = PixArtMS_XL_2( + input_size=latent_size, + pe_interpolation=pe_interpolation[args.image_size], + micro_condition=micro_condition, + model_max_length=max_sequence_length, + ).to(device) + else: + model = PixArt_XL_2( + input_size=latent_size, + pe_interpolation=pe_interpolation[args.image_size], + model_max_length=max_sequence_length, + ).to(device) + state_dict = find_model(args.model_path) + if 'pos_embed' in state_dict['state_dict']: + del state_dict['state_dict']['pos_embed'] + missing, unexpected = model.load_state_dict(state_dict['state_dict'], strict=False) + logger.warning(f'Missing keys: {missing}') + logger.warning(f'Unexpected keys: {unexpected}') + model.to(weight_dtype) + model.eval() + base_ratios = eval(f'ASPECT_RATIO_{args.image_size}_TEST') + + if args.sdvae: + # pixart-alpha vae link: https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/sd-vae-ft-ema + vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema").to(device).to(weight_dtype) + else: + # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae + vae = AutoencoderKL.from_pretrained(f"{args.pipeline_load_from}/vae").to(device).to(weight_dtype) + + tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.pipeline_load_from, subfolder="text_encoder").to(device) + + null_caption_token = tokenizer("", max_length=max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(device) + null_caption_embs = text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0] + + title = f""" + '' Unleashing your Creativity \n '' +
+ + {args.image_size}px +
+ """ + DESCRIPTION = f"""# PixArt-Sigma {args.image_size}px + ## If PixArt-Sigma is helpful, please help to ⭐ the [Github Repo](https://github.com/PixArt-alpha/PixArt-sigma) and recommend it to your friends ��' + #### [PixArt-Sigma {args.image_size}px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma) checkpoint. + #### English prompts ONLY; 提示词仅限英文 + """ + if not torch.cuda.is_available(): + DESCRIPTION += "\n

Running on CPU �� This demo does not work on CPU.

" + + demo = gr.Interface( + fn=generate_img, + inputs=[Textbox(label="Note: If you want to specify a aspect ratio or determine a customized height and width, " + "use --ar h:w (or --aspect_ratio h:w) or --hw h:w. If no aspect ratio or hw is given, all setting will be default.", + placeholder="Please enter your prompt. \n"), + gr.Radio( + choices=["iddpm", "dpm-solver", "sa-solver"], + label=f"Sampler", + interactive=True, + value='dpm-solver', + ), + gr.Slider( + label='Sample Steps', + minimum=1, + maximum=100, + value=14, + step=1 + ), + gr.Slider( + label='Guidance Scale', + minimum=0.1, + maximum=30.0, + value=4.5, + step=0.1 + ), + gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ), + gr.Checkbox(label="Randomize seed", value=True), + ], + outputs=[Image(type="numpy", label="Img"), + Textbox(label="clean prompt"), + Textbox(label="model info"), + gr.Slider(label='seed')], + title=title, + description=DESCRIPTION, + examples=examples + ) + demo.launch(server_name="0.0.0.0", server_port=args.port, debug=True) \ No newline at end of file diff --git a/scripts/style.css b/scripts/style.css new file mode 100644 index 0000000000000000000000000000000000000000..c635eaa8cd6e4bf96c879a5034a4d31c41287e47 --- /dev/null +++ b/scripts/style.css @@ -0,0 +1,9 @@ +/*.gradio-container{width:680px!important}*/ +/* style.css */ +.gradio_group, .gradio_row, .gradio_column { + display: flex; + flex-direction: row; + justify-content: flex-start; + align-items: flex-start; + flex-wrap: wrap; +} \ No newline at end of file diff --git a/tools/convert_pixart_to_diffusers.py b/tools/convert_pixart_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..d715a69ef505da2e34471b8647ba543295d86bcb --- /dev/null +++ b/tools/convert_pixart_to_diffusers.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python +from __future__ import annotations +import argparse +import os +import sys +from pathlib import Path +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel +from scripts.diffusers_patches import pixart_sigma_init_patched_inputs + + +ckpt_id = "PixArt-alpha" +# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125 +interpolation_scale_alpha = {256: 1, 512: 1, 1024: 2} +interpolation_scale_sigma = {256: 0.5, 512: 1, 1024: 2, 2048: 4} + + +def main(args): + interpolation_scale = interpolation_scale_alpha if args.version == "alpha" else interpolation_scale_sigma + all_state_dict = torch.load(args.orig_ckpt_path) + state_dict = all_state_dict.pop("state_dict") + converted_state_dict = {} + + # Patch embeddings. + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") + + # Caption projection. 
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") + + # AdaLN-single LN + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + + if args.micro_condition: + # Resolution. + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop( + "csize_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop( + "csize_embedder.mlp.0.bias" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop( + "csize_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop( + "csize_embedder.mlp.2.bias" + ) + # Aspect ratio. + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop( + "ar_embedder.mlp.0.weight" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop( + "ar_embedder.mlp.0.bias" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop( + "ar_embedder.mlp.2.weight" + ) + converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop( + "ar_embedder.mlp.2.bias" + ) + # Shared norm. + converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight") + converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias") + + for depth in range(28): + # Transformer blocks. + converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( + f"blocks.{depth}.scale_shift_table" + ) + # Attention is all you need 🤘 + + # Self attention. + q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias + # Projection. 
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.attn.proj.bias" + ) + if args.qk_norm: + converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.bias" + ) + + # Feed-forward. + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.fc2.bias" + ) + + # Cross-attention. + q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight") + q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias") + k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0) + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.bias" + ) + + # Final block. 
+ converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table") + + # PixArt XL/2 + # tmp patches for diffusers PixArtSigmaPipeline Implementation + print( + "Changing _init_patched_inputs method of diffusers.models.Transformer2DModel " + "using scripts.diffusers_patches.pixart_sigma_init_patched_inputs") + setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs) + + transformer = Transformer2DModel( + sample_size=args.image_size // 8, + num_layers=28, + attention_head_dim=72, + in_channels=4, + out_channels=8, + patch_size=2, + attention_bias=True, + num_attention_heads=16, + cross_attention_dim=1152, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + caption_channels=4096, + interpolation_scale=interpolation_scale[args.image_size], + ) + transformer.load_state_dict(converted_state_dict, strict=True) + + assert transformer.pos_embed.pos_embed is not None + try: + state_dict.pop("y_embedder.y_embedding") + state_dict.pop("pos_embed") + except: + pass + assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + if args.only_transformer: + transformer.save_pretrained(os.path.join(args.dump_path, "transformer")) + else: + if args.version == "alpha": + # pixart-alpha vae link: https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/sd-vae-ft-ema + vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/PixArt-alpha", subfolder="sd-vae-ft-ema") + elif args.verision == "sigma": + # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae + vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae") + else: + raise ValueError(f"{args.version} is NOT defined. Only alpha or sigma is available") + + scheduler = DPMSolverMultistepScheduler() + + tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder") + + pipeline = PixArtAlphaPipeline( + tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler + ) + + pipeline.save_pretrained(args.dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training." + ) + parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.") + parser.add_argument("--kv_compress", action="store_true", help="If use kv compression during training.") + parser.add_argument( + "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." 
+ ) + parser.add_argument( + "--version", default="alpha", type=str, help="PixArt version to convert", choices=["alpha", "sigma"] + ) + parser.add_argument( + "--image_size", + default=1024, + type=int, + choices=[256, 512, 1024, 2048], + required=False, + help="Image size of pretrained model, 256, 512, 1024, or 2048.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--only_transformer", default=True, type=bool, required=True) + + args = parser.parse_args() + main(args) diff --git a/tools/download.py b/tools/download.py new file mode 100644 index 0000000000000000000000000000000000000000..874977a708e2ae5e435fa88cc8618dac6e9f6de1 --- /dev/null +++ b/tools/download.py @@ -0,0 +1,57 @@ + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for downloading pre-trained PixArt models +""" +from torchvision.datasets.utils import download_url +import torch +import os +import argparse + + +pretrained_models = { + 'PixArt-Sigma-XL-2-512-MS.pth', 'PixArt-Sigma-XL-2-256x256.pth', 'PixArt-Sigma-XL-2-1024-MS.pth' +} + + +def find_model(model_name): + """ + Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints + return download_model(model_name) + else: # Load a custom PixArt checkpoint: + assert os.path.isfile(model_name), f'Could not find PixArt checkpoint at {model_name}' + return torch.load(model_name, map_location=lambda storage, loc: storage) + + +def download_model(model_name): + """ + Downloads a pre-trained PixArt model from the web. 
+ """ + assert model_name in pretrained_models + local_path = f'output/pretrained_models/{model_name}' + if not os.path.isfile(local_path): + os.makedirs('output/pretrained_models', exist_ok=True) + web_path = f'https://huggingface.co/PixArt-alpha/PixArt-Sigma/resolve/main/{model_name}' + download_url(web_path, 'output/pretrained_models/') + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--model_names', nargs='+', type=str, default=pretrained_models) + args = parser.parse_args() + model_names = args.model_names + model_names = set(model_names) + + # Download PixArt checkpoints + for model in model_names: + download_model(model) + print('Done.') diff --git a/tools/extract_features.py b/tools/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..25fc35942927fb547cda53789d425957430420f5 --- /dev/null +++ b/tools/extract_features.py @@ -0,0 +1,354 @@ +import os +from pathlib import Path +import sys +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) +from PIL import Image +import torch +from torchvision import transforms as T +import numpy as np +import json +from tqdm import tqdm +import argparse +import threading +from queue import Queue +from torch.utils.data import DataLoader, RandomSampler +from accelerate import Accelerator +from torchvision.transforms.functional import InterpolationMode +from torchvision.datasets.folder import default_loader +from transformers import T5Tokenizer, T5EncoderModel + +from diffusers.models import AutoencoderKL +from diffusion.data.datasets.InternalData import InternalData +from diffusion.utils.misc import SimpleTimer +from diffusion.utils.data_sampler import AspectRatioBatchSampler +from diffusion.data.builder import DATASETS +from diffusion.data.datasets.utils import * + + +def get_closest_ratio(height: float, width: float, ratios: dict): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return ratios[closest_ratio], float(closest_ratio) + + +@DATASETS.register_module() +class DatasetExtract(InternalData): + def __init__(self, + root, # Notice: need absolute path here + image_list_json=['data_info.json'], + transform=None, + resolution=1024, + load_vae_feat=False, + aspect_ratio_type=None, + start_index=0, + end_index=100_000_000, + multiscale=True, + **kwargs): + self.root = root + self.img_dir_name = 'InternImgs' # need to change to according to your data structure + self.json_dir_name = 'InternData' # need to change to according to your data structure + self.transform = transform + self.load_vae_feat = load_vae_feat + self.resolution = resolution + self.meta_data_clean = [] + self.img_samples = [] + self.txt_feat_samples = [] + self.interpolate_model = InterpolationMode.BICUBIC + if multiscale: + self.aspect_ratio = aspect_ratio_type + assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880] + if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]: + self.interpolate_model = InterpolationMode.LANCZOS + self.ratio_index = {} + self.ratio_nums = {} + for k, v in self.aspect_ratio.items(): + self.ratio_index[float(k)] = [] # used for self.getitem + self.ratio_nums[float(k)] = 0 # used for batch-sampler + + image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json] + for json_file 
in image_list_json: + meta_data = self.load_json(os.path.join(self.root, json_file)) + meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5] + self.meta_data_clean.extend(meta_data_clean) + self.img_samples.extend([os.path.join(self.root.replace(self.json_dir_name, self.img_dir_name), item['path']) for item in meta_data_clean]) + self.img_samples = self.img_samples[start_index: end_index] + + if multiscale: + # scan the dataset for ratio static + for i, info in enumerate(self.meta_data_clean[:len(self.meta_data_clean)//3]): + ori_h, ori_w = info['height'], info['width'] + closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio) + self.ratio_nums[closest_ratio] += 1 + if len(self.ratio_index[closest_ratio]) == 0: + self.ratio_index[closest_ratio].append(i) + + # Set loader and extensions + if self.load_vae_feat: + raise ValueError("No VAE loader here") + self.loader = default_loader + + def __getitem__(self, idx): + data_info = {} + for i in range(20): + try: + img_path = self.img_samples[idx] + img = self.loader(img_path) + if self.transform: + img = self.transform(img) + # Calculate closest aspect ratio and resize & crop image[w, h] + elif isinstance(img, Image.Image): + h, w = (img.size[1], img.size[0]) + assert h, w == (self.meta_data_clean[idx]['height'], self.meta_data_clean[idx]['width']) + closest_size, closest_ratio = get_closest_ratio(h, w, self.aspect_ratio) + closest_size = list(map(lambda x: int(x), closest_size)) + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.Resize(closest_size, interpolation=self.interpolate_model), # Image.BICUBIC or Image.LANCZOS + T.CenterCrop(closest_size), + T.ToTensor(), + T.Normalize([.5], [.5]), + ]) + img = transform(img) + data_info['img_hw'] = torch.tensor([h, w], dtype=torch.float32) + data_info['aspect_ratio'] = closest_ratio + # change the path according to your data structure + return img, img_path.split('/')[-1] # change from 'serial-number-of-dir/serial-number-of-image.png' ---> 'serial-number-of-dir_serial-number-of-image.png' + except Exception as e: + print(f"Error details: {str(e)}") + with open('./failed_files.txt', 'a+') as f: + f.write(self.img_samples[idx] + "\n") + idx = np.random.randint(len(self)) + raise RuntimeError('Too many bad data.') + + def get_data_info(self, idx): + data_info = self.meta_data_clean[idx] + return {'height': data_info['height'], 'width': data_info['width']} + + +def extract_caption_t5_do(q): + while not q.empty(): + item = q.get() + extract_caption_t5_job(item) + q.task_done() + + +def extract_caption_t5_job(item): + global mutex + global t5 + global t5_save_dir + global count + global total_item + + with torch.no_grad(): + # make sure the save path is unique here + save_path = os.path.join(t5_save_dir, f"{Path(item['path']).stem}") + if os.path.exists(save_path + ".npz"): + count += 1 + return + + caption = item[args.caption_label].strip() + if isinstance(caption, str): + caption = [caption] + + try: + mutex.acquire() + caption_token = tokenizer(caption, max_length=args.max_length, padding="max_length", truncation=True, return_tensors="pt").to(device) + caption_emb = text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0] + + mutex.release() + emb_dict = { + 'caption_feature': caption_emb.to(torch.float16).cpu().data.numpy(), + 'attention_mask': caption_token.attention_mask.to(torch.int16).cpu().data.numpy(), + } + os.umask(0o000) # file permission: 666; dir permission: 777 + np.savez_compressed(save_path, 
**emb_dict) + count += 1 + except Exception as e: + print(e) + print(f"CUDA: {os.environ['CUDA_VISIBLE_DEVICES']}, processed: {count}/{total_item}, User Prompt = {args.caption_label}, token length: {args.max_length}, saved at: {t5_save_dir}") + + +def extract_caption_t5(): + global tokenizer + global text_encoder + global t5_save_dir + global count + global total_item + + tokenizer = T5Tokenizer.from_pretrained(args.t5_models_dir, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.t5_models_dir, subfolder="text_encoder", torch_dtype=torch.float16).to(device) + count = 0 + + t5_save_dir = os.path.join(args.t5_save_root, f"{args.caption_label}_caption_features_new".replace('prompt_', '')) + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(t5_save_dir, exist_ok=True) + + train_data_json = json.load(open(args.t5_json_path, 'r')) + train_data = train_data_json[args.start_index: args.end_index] + total_item = len(train_data) + + global mutex + mutex = threading.Lock() + jobs = Queue() + + for item in tqdm(train_data): + jobs.put(item) + + for _ in range(20): + worker = threading.Thread(target=extract_caption_t5_do, args=(jobs,)) + worker.start() + + jobs.join() + + +def extract_img_vae(bs): + print("Starting") + accelerator = Accelerator(mixed_precision='fp16') + vae = AutoencoderKL.from_pretrained(f'{args.vae_models_dir}', torch_dtype=torch.float16).to(device) + print('VAE Loaded') + + vae_save_dir = f'{args.vae_save_root}/img_sdxl_vae_features_{image_resize}resolution_new' + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(vae_save_dir, exist_ok=True) + interpolation = InterpolationMode.BILINEAR + if image_resize in [2048, 2880]: + interpolation = InterpolationMode.LANCZOS + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.Resize(image_resize, interpolation=interpolation), + T.CenterCrop(image_resize), + T.ToTensor(), + T.Normalize([.5], [.5]), + ]) + signature = '' + dataset = DatasetExtract(args.dataset_root, image_list_json=[args.vae_json_file], transform=transform, sample_subset=None, + start_index=args.start_index, end_index=args.end_index, multiscale=False, work_dir=os.path.join(vae_save_dir, signature)) + dataloader = DataLoader(dataset, batch_size=bs, num_workers=13, pin_memory=True) + dataloader = accelerator.prepare(dataloader, ) + + inference(vae, dataloader, signature=signature, work_dir=vae_save_dir) + accelerator.wait_for_everyone() + + return + + +def save_results(results, paths, signature, work_dir): + timer = SimpleTimer(len(results), log_interval=100, desc=f"Saving at {work_dir}") + # save to npy + new_paths = [] + new_folder = signature + save_folder = os.path.join(work_dir, new_folder) + os.makedirs(save_folder, exist_ok=True) + os.umask(0o000) # file permission: 666; dir permission: 777 + for res, p in zip(results, paths): + file_name = p.split('.')[0] + '.npy' + save_path = os.path.join(save_folder, file_name) + if os.path.exists(save_path): + continue + new_paths.append(os.path.join(new_folder, file_name)) + np.save(save_path, res) + timer.log() + # save paths + with open(os.path.join(work_dir, f"VAE-{signature}.txt"), 'a+') as f: + f.write('\n'.join(new_paths)) + + +def inference(vae, dataloader, signature, work_dir): + timer = SimpleTimer(len(dataloader), log_interval=100, desc=f"VAE-Inference") + + for step, batch in enumerate(dataloader): + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=True): + posterior = vae.encode(batch[0]).latent_dist + results = 
torch.cat([posterior.mean, posterior.std], dim=1).detach().cpu().numpy() + path = batch[1] + save_results(results, path, signature=signature, work_dir=work_dir) + timer.log() + + +def extract_img_vae_multiscale(bs=1): + + assert image_resize in [512, 1024, 2048, 2880] + work_dir = f"{os.path.abspath(args.vae_save_root)}/img_sdxl_vae_features_{image_resize}resolution_ms_new" + os.umask(0o000) # file permission: 666; dir permission: 777 + os.makedirs(work_dir, exist_ok=True) + accelerator = Accelerator(mixed_precision='fp16') + vae = AutoencoderKL.from_pretrained(f'{args.vae_models_dir}').to(device) + + signature = '' + + aspect_ratio_type = eval(f"ASPECT_RATIO_{image_resize}") + print(f"Aspect Ratio Here: {aspect_ratio_type}") + dataset = DatasetExtract( + args.dataset_root, image_list_json=[args.vae_json_file], transform=None, sample_subset=None, + aspect_ratio_type=aspect_ratio_type, start_index=args.start_index, end_index=args.end_index, + work_dir=os.path.join(work_dir, signature) + ) + + # create AspectRatioBatchSampler + sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset, batch_size=bs, aspect_ratios=dataset.aspect_ratio, ratio_nums=dataset.ratio_nums) + + # create DataLoader + dataloader = DataLoader(dataset, batch_sampler=sampler, num_workers=13, pin_memory=True) + dataloader = accelerator.prepare(dataloader, ) + + inference(vae, dataloader, signature=signature, work_dir=work_dir) + accelerator.wait_for_everyone() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--run_t5_feature_extract", action='store_true', help="run t5 feature extracting") + parser.add_argument("--run_vae_feature_extract", action='store_true', help="run VAE feature extracting") + parser.add_argument('--start_index', default=0, type=int) + parser.add_argument('--end_index', default=50000000, type=int) + + ### vae feauture extraction + parser.add_argument("--multi_scale", action='store_true', help="multi-scale feature extraction") + parser.add_argument("--img_size", default=512, type=int, help="image scale for VAE feature extraction") + parser.add_argument('--dataset_root', default='pixart-sigma-toy-dataset', type=str) + parser.add_argument('--vae_json_file', type=str) # relative to args.dataset_root + parser.add_argument( + '--vae_models_dir', default='madebyollin/sdxl-vae-fp16-fix', type=str + ) + parser.add_argument( + '--vae_save_root', default='pixart-sigma-toy-dataset/InternData', + type=str + ) + + ### for t5 feature + parser.add_argument("--max_length", default=300, type=int, help="max token length for T5") + parser.add_argument('--t5_json_path', type=str) # absolute path or relative to this project + parser.add_argument( + '--t5_models_dir', default='PixArt-alpha/PixArt-XL-2-1024-MS', type=str + ) + parser.add_argument('--caption_label', default='prompt', type=str) + parser.add_argument('--t5_save_root', default='pixart-sigma-toy-dataset/InternData', type=str) + return parser.parse_args() + + +if __name__ == '__main__': + + args = get_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + image_resize = args.img_size + + # prepare extracted caption t5 features for training + if args.run_t5_feature_extract: + extract_caption_t5() + + # prepare extracted image vae features for training + if args.run_vae_feature_extract: + if args.multi_scale: + assert args.img_size in [512, 1024, 2048, 2880],\ + "Multi Scale VAE feature is not for 256px in PixArt-Sigma." 
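# Illustrative sketch (for reference only; not part of the patch): reusing the cached VAE statistics
# saved by save_results above. Each .npy holds the posterior mean and std stacked along the channel
# dimension; a training loader can re-sample a latent roughly like this. The path and the 0.13025
# SDXL-VAE scaling factor are assumptions for illustration (use whatever config.scale_factor is set to).
import numpy as np
import torch

feat = torch.from_numpy(np.load("img_sdxl_vae_features_512resolution_new/000001.npy"))  # assumed path
mean, std = feat.chunk(2, dim=0)               # (4, h/8, w/8) each
z = mean + std * torch.randn_like(std)         # sample_posterior-style latent
z = z * 0.13025                                # scale by the VAE's scaling factor before training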
+ print('Extracting Multi-scale Image Resolution based on %s' % image_resize) + extract_img_vae_multiscale(bs=1) # recommend bs = 1 for AspectRatioBatchSampler + else: + assert args.img_size == 256,\ + f"Single Scale VAE feature is only for 256px in PixArt-Sigma. NOT for {args.img_size}px" + print('Extracting Single Image Resolution %s' % image_resize) + extract_img_vae(bs=2) + + print("Done") \ No newline at end of file diff --git a/train_scripts/train.py b/train_scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5ddae11ee134cf557ee69ccd5385a20805e3fc06 --- /dev/null +++ b/train_scripts/train.py @@ -0,0 +1,481 @@ +import argparse +import datetime +import os +import sys +import time +import types +import warnings +from pathlib import Path + +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) + +import numpy as np +import torch +from accelerate import Accelerator, InitProcessGroupKwargs +from accelerate.utils import DistributedType +from diffusers.models import AutoencoderKL +from transformers import T5EncoderModel, T5Tokenizer +from mmcv.runner import LogBuffer +from PIL import Image +from torch.utils.data import RandomSampler + +from diffusion import IDDPM, DPMS +from diffusion.data.builder import build_dataset, build_dataloader, set_data_root +from diffusion.model.builder import build_model +from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint +from diffusion.utils.data_sampler import AspectRatioBatchSampler +from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_, flush +from diffusion.utils.logger import get_root_logger, rename_file_with_creation_time +from diffusion.utils.lr_scheduler import build_lr_scheduler +from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow +from diffusion.utils.optimizer import build_optimizer, auto_scale_lr + +warnings.filterwarnings("ignore") # ignore warning + + +def set_fsdp_env(): + os.environ["ACCELERATE_USE_FSDP"] = 'true' + os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP' + os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE' + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock' + + +@torch.inference_mode() +def log_validation(model, step, device, vae=None): + torch.cuda.empty_cache() + model = accelerator.unwrap_model(model).eval() + hw = torch.tensor([[1024, 1024]], dtype=torch.float, device=device).repeat(1, 1) + ar = torch.tensor([[1.]], device=device).repeat(1, 1) + null_y = torch.load(f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth') + null_y = null_y['uncond_prompt_embeds'].to(device) + + # Create sampling noise: + logger.info("Running validation... 
") + image_logs = [] + latents = [] + + for prompt in validation_prompts: + z = torch.randn(1, 4, latent_size, latent_size, device=device) + embed = torch.load(f'output/tmp/{prompt}_{max_length}token.pth', map_location='cpu') + caption_embs, emb_masks = embed['caption_embeds'].to(device), embed['emb_mask'].to(device) + # caption_embs = caption_embs[:, None] + # emb_masks = emb_masks[:, None] + model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + + dpm_solver = DPMS(model.forward_with_dpmsolver, + condition=caption_embs, + uncondition=null_y, + cfg_scale=4.5, + model_kwargs=model_kwargs) + denoised = dpm_solver.sample( + z, + steps=14, + order=2, + skip_type="time_uniform", + method="multistep", + ) + latents.append(denoised) + + torch.cuda.empty_cache() + if vae is None: + vae = AutoencoderKL.from_pretrained(config.vae_pretrained).to(accelerator.device).to(torch.float16) + for prompt, latent in zip(validation_prompts, latents): + latent = latent.to(torch.float16) + samples = vae.decode(latent.detach() / vae.config.scaling_factor).sample + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0] + image = Image.fromarray(samples) + image_logs.append({"validation_prompt": prompt, "images": [image]}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + import wandb + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del vae + flush() + return image_logs + + +def train(): + if config.get('debug_nan', False): + DebugUnderflowOverflow(model) + logger.info('NaN debugger registered. 
Start to detect overflow during training.') + time_start, last_tic = time.time(), time.time() + log_buffer = LogBuffer() + + global_step = start_step + 1 + + load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False) + load_t5_feat = getattr(train_dataloader.dataset, 'load_t5_feat', False) + # Now you train the model + for epoch in range(start_epoch + 1, config.num_epochs + 1): + data_time_start= time.time() + data_time_all = 0 + for step, batch in enumerate(train_dataloader): + if step < skip_step: + global_step += 1 + continue # skip data in the resumed ckpt + if load_vae_feat: + z = batch[0] + else: + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=(config.mixed_precision == 'fp16' or config.mixed_precision == 'bf16')): + posterior = vae.encode(batch[0]).latent_dist + if config.sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + clean_images = z * config.scale_factor + data_info = batch[3] + + if load_t5_feat: + y = batch[1] + y_mask = batch[2] + else: + with torch.no_grad(): + txt_tokens = tokenizer( + batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + y = text_encoder( + txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None] + y_mask = txt_tokens.attention_mask[:, None, None] + + # Sample a random timestep for each image + bs = clean_images.shape[0] + timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=clean_images.device).long() + grad_norm = None + data_time_all += time.time() - data_time_start + with accelerator.accumulate(model): + # Predict the noise residual + optimizer.zero_grad() + loss_term = train_diffusion.training_losses(model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info)) + loss = loss_term['loss'].mean() + accelerator.backward(loss) + if accelerator.sync_gradients: + grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip) + optimizer.step() + lr_scheduler.step() + + lr = lr_scheduler.get_last_lr()[0] + logs = {args.loss_report_name: accelerator.gather(loss).mean().item()} + if grad_norm is not None: + logs.update(grad_norm=accelerator.gather(grad_norm).mean().item()) + log_buffer.update(logs) + if (step + 1) % config.log_interval == 0 or (step + 1) == 1: + t = (time.time() - last_tic) / config.log_interval + t_d = data_time_all / config.log_interval + avg_time = (time.time() - time_start) / (global_step + 1) + eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1)))) + eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1)))) + log_buffer.average() + info = f"Step/Epoch [{global_step}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \ + f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({model.module.h}, {model.module.w}), " + info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()]) + logger.info(info) + last_tic = time.time() + log_buffer.clear() + data_time_all = 0 + logs.update(lr=lr) + accelerator.log(logs, step=global_step) + + global_step += 1 + data_time_start = time.time() + + if global_step % config.save_model_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + os.umask(0o000) + save_checkpoint(os.path.join(config.work_dir, 'checkpoints'), + epoch=epoch, + step=global_step, + model=accelerator.unwrap_model(model), + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + if 
config.visualize and (global_step % config.eval_sampling_steps == 0 or (step + 1) == 1): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + log_validation(model, global_step, device=accelerator.device, vae=vae) + + if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + os.umask(0o000) + save_checkpoint(os.path.join(config.work_dir, 'checkpoints'), + epoch=epoch, + step=global_step, + model=accelerator.unwrap_model(model), + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + accelerator.wait_for_everyone() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Process some integers.") + parser.add_argument("config", type=str, help="config") + parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine") + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument('--resume-from', help='the dir to resume the training') + parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training') + parser.add_argument('--local-rank', type=int, default=-1) + parser.add_argument('--local_rank', type=int, default=-1) + parser.add_argument('--debug', action='store_true') + parser.add_argument( + "--pipeline_load_from", default='output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers', + type=str, help="Download for loading text_encoder, " + "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument("--loss_report_name", type=str, default="loss") + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + config = read_config(args.config) + if args.work_dir is not None: + config.work_dir = args.work_dir + if args.resume_from is not None: + config.load_from = None + config.resume_from = dict( + checkpoint=args.resume_from, + load_ema=False, + resume_optimizer=True, + resume_lr_scheduler=True) + if args.debug: + config.log_interval = 1 + config.train_batch_size = 2 + + os.umask(0o000) + os.makedirs(config.work_dir, exist_ok=True) + + init_handler = InitProcessGroupKwargs() + init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug + # Initialize accelerator and tensorboard logging + if config.use_fsdp: + init_train = 'FSDP' + from accelerate import FullyShardedDataParallelPlugin + from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig + set_fsdp_env() + fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),) + else: + init_train = 'DDP' + fsdp_plugin = None + + even_batches = True + if config.multi_scale: + even_batches=False, + + accelerator = Accelerator( + mixed_precision=config.mixed_precision, + gradient_accumulation_steps=config.gradient_accumulation_steps, + log_with=args.report_to, + project_dir=os.path.join(config.work_dir, "logs"), + fsdp_plugin=fsdp_plugin, + even_batches=even_batches, + kwargs_handlers=[init_handler] + ) + + log_name = 'train_log.log' + if accelerator.is_main_process: + if os.path.exists(os.path.join(config.work_dir, log_name)): + rename_file_with_creation_time(os.path.join(config.work_dir, log_name)) + logger = get_root_logger(os.path.join(config.work_dir, log_name)) + + logger.info(accelerator.state) + config.seed = init_random_seed(config.get('seed', None)) + set_random_seed(config.seed) + + if accelerator.is_main_process: + config.dump(os.path.join(config.work_dir, 'config.py')) + + logger.info(f"Config: \n{config.pretty_text}") + logger.info(f"World_size: {get_world_size()}, seed: {config.seed}") + logger.info(f"Initializing: {init_train} for training") + image_size = config.image_size # @param [256, 512] + latent_size = int(image_size) // 8 + pred_sigma = getattr(config, 'pred_sigma', True) + learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma + max_length = config.model_max_length + kv_compress_config = config.kv_compress_config if config.kv_compress else None + vae = None + if not config.data.load_vae_feat: + vae = AutoencoderKL.from_pretrained(config.vae_pretrained, torch_dtype=torch.float16).to(accelerator.device) + config.scale_factor = vae.config.scaling_factor + tokenizer = text_encoder = None + if not config.data.load_t5_feat: + tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device) + + logger.info(f"vae sacle factor: {config.scale_factor}") + + if config.visualize: + # preparing embeddings for visualization. 
We put it here for saving GPU memory + validation_prompts = [ + "dog", + "portrait photo of a girl, photograph, highly detailed face, depth of field", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + ] + skip = True + for prompt in validation_prompts: + if not (os.path.exists(f'output/tmp/{prompt}_{max_length}token.pth') + and os.path.exists(f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth')): + skip = False + logger.info("Preparing Visualization prompt embeddings...") + break + if accelerator.is_main_process and not skip: + if config.data.load_t5_feat and (tokenizer is None or text_encoder is None): + logger.info(f"Loading text encoder and tokenizer from {args.pipeline_load_from} ...") + tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device) + for prompt in validation_prompts: + txt_tokens = tokenizer( + prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] + torch.save( + {'caption_embeds': caption_emb, 'emb_mask': txt_tokens.attention_mask}, + f'output/tmp/{prompt}_{max_length}token.pth') + null_tokens = tokenizer( + "", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] + torch.save( + {'uncond_prompt_embeds': null_token_emb, 'uncond_prompt_embeds_mask': null_tokens.attention_mask}, + f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth') + if config.data.load_t5_feat: + del tokenizer + del txt_tokens + flush() + + model_kwargs={"pe_interpolation": config.pe_interpolation, "config":config, + "model_max_length": max_length, "qk_norm": config.qk_norm, + "kv_compress_config": kv_compress_config, "micro_condition": config.micro_condition} + + # build models + train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, snr=config.snr_loss) + model = build_model(config.model, + config.grad_checkpointing, + config.get('fp32_attention', False), + input_size=latent_size, + learn_sigma=learn_sigma, + pred_sigma=pred_sigma, + **model_kwargs).train() + logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}") + + if args.load_from is not None: + config.load_from = args.load_from + if config.load_from is not None: + missing, unexpected = load_checkpoint( + config.load_from, model, load_ema=config.get('load_ema', False), max_length=max_length) + logger.warning(f'Missing keys: {missing}') + logger.warning(f'Unexpected keys: {unexpected}') + + # prepare for FSDP clip grad norm calculation + if accelerator.distributed_type == DistributedType.FSDP: + for m in accelerator._models: + m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m) + + # build dataloader + set_data_root(config.data_root) + dataset = build_dataset( + config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type, + real_prompt_ratio=config.real_prompt_ratio, max_length=max_length, 
config=config, + ) + if config.multi_scale: + batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset, + batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True, + ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num) + train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers) + else: + train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True) + + # build optimizer and lr scheduler + lr_scale_ratio = 1 + if config.get('auto_lr', None): + lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps, + config.optimizer, **config.auto_lr) + optimizer = build_optimizer(model, config.optimizer) + lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio) + + timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + + if accelerator.is_main_process: + tracker_config = dict(vars(config)) + try: + accelerator.init_trackers(args.tracker_project_name, tracker_config) + except: + accelerator.init_trackers(f"tb_{timestamp}") + + start_epoch = 0 + start_step = 0 + skip_step = config.skip_step + total_steps = len(train_dataloader) * config.num_epochs + + if config.resume_from is not None and config.resume_from['checkpoint'] is not None: + resume_path = config.resume_from['checkpoint'] + path = os.path.basename(resume_path) + start_epoch = int(path.replace('.pth', '').split("_")[1]) - 1 + start_step = int(path.replace('.pth', '').split("_")[3]) + _, missing, unexpected = load_checkpoint(**config.resume_from, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + max_length=max_length, + ) + + logger.warning(f'Missing keys: {missing}') + logger.warning(f'Unexpected keys: {unexpected}') + # Prepare everything + # There is no specific order to remember, you just need to unpack the + # objects in the same order you gave them to the prepare method. 
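+    # Editorial note (not in the original patch): the model is wrapped first by its
+    # own prepare() call below, then the optimizer, dataloader and LR scheduler are
+    # prepared together; only the unpack order has to match the order the objects
+    # are passed in. Under DDP the wrapped network is reached via `model.module`,
+    # which the logging string in train() relies on (`model.module.h`, `model.module.w`).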
+ model = accelerator.prepare(model) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + train() diff --git a/train_scripts/train_pixart_lcm.py b/train_scripts/train_pixart_lcm.py new file mode 100644 index 0000000000000000000000000000000000000000..f0d84a5800d2d5d96ddcf991b2a536d920aa39d1 --- /dev/null +++ b/train_scripts/train_pixart_lcm.py @@ -0,0 +1,570 @@ +import os +import sys +import types +from pathlib import Path +current_file_path = Path(__file__).resolve() +sys.path.insert(0, str(current_file_path.parent.parent)) +import argparse +import datetime +import time +import warnings +warnings.filterwarnings("ignore") # ignore warning +import torch +import torch.nn as nn +from accelerate import Accelerator, InitProcessGroupKwargs +from accelerate.utils import DistributedType +from diffusers.models import AutoencoderKL +from transformers import T5EncoderModel, T5Tokenizer +from torch.utils.data import RandomSampler +from mmcv.runner import LogBuffer +from copy import deepcopy +import numpy as np +import torch.nn.functional as F +from tqdm import tqdm +from PIL import Image +import gc + +from diffusion import IDDPM +from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint +from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_, flush +from diffusion.data.builder import build_dataset, build_dataloader, set_data_root +from diffusion.model.builder import build_model +from diffusion.utils.logger import get_root_logger +from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow +from diffusion.utils.optimizer import build_optimizer, auto_scale_lr +from diffusion.utils.lr_scheduler import build_lr_scheduler +from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler +from diffusion.lcm_scheduler import LCMScheduler +from torchvision.utils import save_image + + +def set_fsdp_env(): + os.environ["ACCELERATE_USE_FSDP"] = 'true' + os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP' + os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE' + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock' + + +def ema_update(model_dest: nn.Module, model_src: nn.Module, rate): + param_dict_src = dict(model_src.named_parameters()) + for p_name, p_dest in model_dest.named_parameters(): + p_src = param_dict_src[p_name] + assert p_src is not p_dest + p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +# From LCMScheduler.get_scalings_for_boundary_condition_discrete +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) + c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class DDIMSolver: + def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): + # DDIM sampling parameters + step_ratio = timesteps // ddim_timesteps + + self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * 
step_ratio).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) + + def to(self, device): + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + return self + + def ddim_step(self, pred_x0, pred_noise, timestep_index): + alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + +@torch.inference_mode() +def log_validation(model, step, device): + torch.cuda.empty_cache() + model = accelerator.unwrap_model(model).eval() + scheduler = LCMScheduler(beta_start=0.0001, beta_end=0.02, beta_schedule="linear", prediction_type="epsilon") + scheduler.set_timesteps(4, 50) + infer_timesteps = scheduler.timesteps + + hw = torch.tensor([[1024, 1024]], dtype=torch.float, device=device).repeat(1, 1) + ar = torch.tensor([[1.]], device=device).repeat(1, 1) + # Create sampling noise: + logger.info("Running validation... ") + image_logs = [] + + latents = [] + for prompt in validation_prompts: + infer_latents = torch.randn(1, 4, latent_size, latent_size, device=device) + embed = torch.load(f'output/tmp/{prompt}_{max_length}token.pth', map_location='cpu') + caption_embs, emb_masks = embed['caption_embeds'].to(device), embed['emb_mask'].to(device) + model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks) + + # 7. 
LCM MultiStep Sampling Loop: + for i, t in tqdm(list(enumerate(infer_timesteps))): + ts = torch.full((1,), t, device=device, dtype=torch.long) + + # model prediction (v-prediction, eps, x) + model_pred = model(infer_latents, ts, caption_embs, **model_kwargs)[:, :4] + + # compute the previous noisy sample x_t -> x_t-1 + infer_latents, denoised = scheduler.step(model_pred, i, t, infer_latents, return_dict=False) + latents.append(denoised) + torch.cuda.empty_cache() + vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda() + for prompt, latent in zip(validation_prompts, latents): + samples = vae.decode(latent.detach() / vae.config.scaling_factor).sample + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0] + image = Image.fromarray(samples) + image_logs.append({"validation_prompt": prompt, "images": [image]}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + import wandb + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + gc.collect() + torch.cuda.empty_cache() + return image_logs + + +def train(): + if config.get('debug_nan', False): + DebugUnderflowOverflow(model) + logger.info('NaN debugger registered. 
Start to detect overflow during training.') + time_start, last_tic = time.time(), time.time() + log_buffer = LogBuffer() + + start_step = start_epoch * len(train_dataloader) + global_step = 0 + total_steps = len(train_dataloader) * config.num_epochs + + load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False) + load_t5_feat = getattr(train_dataloader.dataset, 'load_t5_feat', False) + + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = model.module.y_embedder.y_embedding.repeat(config.train_batch_size, 1, 1, 1) + + # Now you train the model + for epoch in range(start_epoch + 1, config.num_epochs + 1): + data_time_start= time.time() + data_time_all = 0 + for step, batch in enumerate(train_dataloader): + data_time_all += time.time() - data_time_start + if load_vae_feat: + z = batch[0] + else: + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'): + posterior = vae.encode(batch[0]).latent_dist + if config.sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + latents = z * config.scale_factor + data_info = {'img_hw': batch[3]['img_hw'].to(latents.dtype), 'aspect_ratio': batch[3]['aspect_ratio'].to(latents.dtype),} + if load_t5_feat: + y = batch[1] + y_mask = batch[2] + else: + with torch.no_grad(): + txt_tokens = tokenizer( + batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + y = text_encoder( + txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None] + y_mask = txt_tokens.attention_mask[:, None, None] + + # Sample a random timestep for each image + grad_norm = None + with accelerator.accumulate(model): + # Predict the noise residual + optimizer.zero_grad() + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + topk = config.train_sampling_steps // config.num_ddim_timesteps + index = torch.randint(0, config.num_ddim_timesteps, (bsz,), device=latents.device).long() + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # Get boundary scalings for start_timesteps and (end) timesteps. 
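+                # Editorial note (not in the original patch): scalings_for_boundary_conditions()
+                # returns c_skip(t) = sigma_data^2 / ((t/0.1)^2 + sigma_data^2) and
+                # c_out(t) = (t/0.1) / sqrt((t/0.1)^2 + sigma_data^2), so c_skip -> 1 and
+                # c_out -> 0 as t -> 0, enforcing the consistency-model boundary condition
+                # f(x, 0) = x. append_dims() then broadcasts these per-sample scalars to
+                # shape (bsz, 1, 1, 1) so they can multiply the (bsz, C, H, W) latents.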
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] + c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] + + # Sample a random guidance scale w from U[w_min, w_max] and embed it + # w = (config.w_max - config.w_min) * torch.rand((bsz,)) + config.w_min + w = config.cfg_scale * torch.ones((bsz,)) + w = w.reshape(bsz, 1, 1, 1) + w = w.to(device=latents.device, dtype=latents.dtype) + + # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + _, pred_x_0, noisy_model_input = train_diffusion.training_losses( + model, latents, start_timesteps, + model_kwargs=dict(y=y, mask=y_mask, data_info=data_info), + noise=noise + ) + + model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 + + # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after + # noisy_latents with both the conditioning embedding c and unconditional embedding 0 + # Get teacher model prediction on noisy_latents and conditional embedding + with torch.no_grad(): + with torch.autocast("cuda"): + cond_teacher_output, cond_pred_x0, _ = train_diffusion.training_losses( + model_teacher, latents, start_timesteps, + model_kwargs=dict(y=y, mask=y_mask, data_info=data_info), + noise=noise + ) + + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_teacher_output, uncond_pred_x0, _ = train_diffusion.training_losses( + model_teacher, latents, start_timesteps, + model_kwargs=dict(y=uncond_prompt_embeds, mask=y_mask, data_info=data_info), + noise=noise + ) + + # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + + # Get target LCM prediction on x_prev, w, c, t_n + with torch.no_grad(): + with torch.autocast("cuda", enabled=True): + _, pred_x_0, _ = train_diffusion.training_losses( + model_ema, x_prev.float(), timesteps, + model_kwargs=dict(y=y, mask=y_mask, data_info=data_info), + skip_noise=True + ) + + target = c_skip * x_prev + c_out * pred_x_0 + + # Calculate loss + if config.loss_type == "l2": + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + elif config.loss_type == "huber": + loss = torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + config.huber_c**2) - config.huber_c) + + # Backpropagation on the online student model (`model`) + accelerator.backward(loss) + if accelerator.sync_gradients: + grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + ema_update(model_ema, model, config.ema_decay) + + lr = lr_scheduler.get_last_lr()[0] + logs = {"loss": accelerator.gather(loss).mean().item()} + if grad_norm is not None: + logs.update(grad_norm=accelerator.gather(grad_norm).mean().item()) + log_buffer.update(logs) + if (step + 1) % config.log_interval == 0 or (step + 1) == 1: + t = (time.time() - last_tic) / config.log_interval + t_d = data_time_all / config.log_interval + avg_time = (time.time() - time_start) / (global_step + 1) + eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1)))) + eta_epoch = 
str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1)))) + # avg_loss = sum(loss_buffer) / len(loss_buffer) + log_buffer.average() + info = f"Step/Epoch [{(epoch-1)*len(train_dataloader)+step+1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \ + f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({data_info['img_hw'][0][0].item()}, {data_info['img_hw'][0][1].item()}), " + info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()]) + logger.info(info) + last_tic = time.time() + log_buffer.clear() + data_time_all = 0 + logs.update(lr=lr) + accelerator.log(logs, step=global_step + start_step) + + global_step += 1 + data_time_start= time.time() + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0: + os.umask(0o000) + save_checkpoint(os.path.join(config.work_dir, 'checkpoints'), + epoch=epoch, + step=(epoch - 1) * len(train_dataloader) + step + 1, + model=accelerator.unwrap_model(model), + model_ema=accelerator.unwrap_model(model_ema), + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + if ((epoch - 1) * len(train_dataloader) + step + 1) % config.eval_sampling_steps == 0: + log_validation(model, global_step, device=accelerator.device) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs: + os.umask(0o000) + save_checkpoint(os.path.join(config.work_dir, 'checkpoints'), + epoch=epoch, + step=(epoch - 1) * len(train_dataloader) + step + 1, + model=accelerator.unwrap_model(model), + model_ema=accelerator.unwrap_model(model_ema), + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + synchronize() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Process some integers.") + parser.add_argument("config", type=str, help="config") + parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine") + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument('--resume-from', help='the dir to resume the training') + parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training') + parser.add_argument('--local-rank', type=int, default=-1) + parser.add_argument('--local_rank', type=int, default=-1) + parser.add_argument('--debug', action='store_true') + parser.add_argument( + "--pipeline_load_from", default='output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers', + type=str, help="Download for loading text_encoder, " + "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + ) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + config = read_config(args.config) + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + config.work_dir = args.work_dir + if args.cloud: + config.data_root = '/data/data' + if args.resume_from is not None: + config.load_from = None + config.resume_from = dict( + checkpoint=args.resume_from, + load_ema=False, + resume_optimizer=True, + resume_lr_scheduler=True) + if args.debug: + config.log_interval = 1 + config.train_batch_size = 2 + + os.umask(0o000) + os.makedirs(config.work_dir, exist_ok=True) + + init_handler = InitProcessGroupKwargs() + init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug + # 
Initialize accelerator and tensorboard logging + if config.use_fsdp: + init_train = 'FSDP' + from accelerate import FullyShardedDataParallelPlugin + from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig + set_fsdp_env() + fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),) + else: + init_train = 'DDP' + fsdp_plugin = None + + even_batches = True + if config.multi_scale: + even_batches=False, + + accelerator = Accelerator( + mixed_precision=config.mixed_precision, + gradient_accumulation_steps=config.gradient_accumulation_steps, + log_with="tensorboard", + project_dir=os.path.join(config.work_dir, "logs"), + fsdp_plugin=fsdp_plugin, + even_batches=even_batches, + kwargs_handlers=[init_handler] + ) + + logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log')) + + config.seed = init_random_seed(config.get('seed', None)) + set_random_seed(config.seed) + + if accelerator.is_main_process: + config.dump(os.path.join(config.work_dir, 'config.py')) + + logger.info(f"Config: \n{config.pretty_text}") + logger.info(f"World_size: {get_world_size()}, seed: {config.seed}") + logger.info(f"Initializing: {init_train} for training") + image_size = config.image_size # @param [256, 512] + latent_size = int(image_size) // 8 + pred_sigma = getattr(config, 'pred_sigma', True) + learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma + max_length = config.model_max_length + model_kwargs={"pe_interpolation": config.pe_interpolation, 'config':config, 'model_max_length': max_length} + + # build models + train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, + snr=config.snr_loss, return_startx=True) + model = build_model(config.model, + config.grad_checkpointing, + config.get('fp32_attention', False), + input_size=latent_size, + learn_sigma=learn_sigma, + pred_sigma=pred_sigma, + **model_kwargs).train() + logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}") + + if config.load_from is not None: + if args.load_from is not None: + config.load_from = args.load_from + missing, unexpected = load_checkpoint( + config.load_from, model, load_ema=config.get('load_ema', False), max_length=max_length) + logger.warning(f'Missing keys: {missing}') + logger.warning(f'Unexpected keys: {unexpected}') + + model_ema = deepcopy(model).eval() + model_teacher = deepcopy(model).eval() + + if not config.data.load_vae_feat: + vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda() + + # prepare for FSDP clip grad norm calculation + if accelerator.distributed_type == DistributedType.FSDP: + for m in accelerator._models: + m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m) + tokenizer = text_encoder = None + if not config.data.load_t5_feat: + tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device) + + logger.info(f"vae sacle factor: {config.scale_factor}") + + # build dataloader + set_data_root(config.data_root) + dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type) + if config.multi_scale: + batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset, + batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True, 
+ ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num) + train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers) + else: + train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True) + + # preparing embeddings for visualization. We put it here for saving GPU memory + validation_prompts = [ + "dog", + "portrait photo of a girl, photograph, highly detailed face, depth of field", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + ] + logger.info("Preparing Visulalization prompt embeddings...") + skip = True + for prompt in validation_prompts: + if not os.path.exists(f'output/tmp/{prompt}_{max_length}token.pth'): + skip = False + break + logger.info("Preparing Visualization prompt embeddings...") + if accelerator.is_main_process and not skip: + if config.data.load_t5_feat and (tokenizer is None or text_encoder is None): + logger.info(f"Loading text encoder and tokenizer from {args.pipeline_load_from} ...") + tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer") + text_encoder = T5EncoderModel.from_pretrained( + args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device) + for prompt in validation_prompts: + txt_tokens = tokenizer( + prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ).to(accelerator.device) + caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] + torch.save( + {'caption_embeds': caption_emb, 'emb_mask': txt_tokens.attention_mask}, + f'output/tmp/{prompt}_{max_length}token.pth') + if config.data.load_t5_feat: + del tokenizer + del txt_tokens + flush() + time.sleep(5) + + # build optimizer and lr scheduler + lr_scale_ratio = 1 + if config.get('auto_lr', None): + lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps, + config.optimizer, + **config.auto_lr) + optimizer = build_optimizer(model, config.optimizer) + lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio) + + timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + + if accelerator.is_main_process: + accelerator.init_trackers(f"tb_{timestamp}") + + start_epoch = 0 + if config.resume_from is not None and config.resume_from['checkpoint'] is not None: + start_epoch, missing, unexpected = load_checkpoint(**config.resume_from, + model=model, + model_ema=model_ema, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + + logger.warning(f'Missing keys: {missing}') + logger.warning(f'Unexpected keys: {unexpected}') + + solver = DDIMSolver(train_diffusion.alphas_cumprod, timesteps=config.train_sampling_steps, ddim_timesteps=config.num_ddim_timesteps) + solver.to(accelerator.device) + # Prepare everything + # There is no specific order to remember, you just need to unpack the + # objects in the same order you gave them to the prepare method. + model, model_ema, model_teacher = accelerator.prepare(model, model_ema, model_teacher) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + train()
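
For readers skimming the patch, the following self-contained sketch condenses the consistency-distillation target that train_pixart_lcm.py builds inside its training loop. It is illustrative only: `toy_student`, `toy_teacher` and `toy_ema` are hypothetical stand-ins for the networks (each returning a predicted x_0), classifier-free guidance is faked with a scaled copy instead of real text conditioning, and the DDIM step is simplified to derive the noise estimate from the predicted x_0. The boundary scalings, the t_{n+k} -> t_n step size and the loss shape mirror the script; the concrete numbers (50 DDIM steps, 1000 training steps, w = 4.5) are toy values standing in for the config.

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the student, frozen teacher and EMA ("target") networks.
# Each maps a noisy latent to a predicted clean latent x_0 (conditioning omitted).
def toy_student(x_t, t):
    return 0.9 * x_t

def toy_teacher(x_t, t):
    return 0.9 * x_t

def toy_ema(x_t, t):
    return 0.9 * x_t

def scalings(t, sigma_data=0.5):
    # Same boundary-condition scalings as in the script: c_skip -> 1, c_out -> 0 at t = 0.
    s = t / 0.1
    c_skip = sigma_data ** 2 / (s ** 2 + sigma_data ** 2)
    c_out = s / (s ** 2 + sigma_data ** 2) ** 0.5
    return c_skip.view(-1, 1, 1, 1), c_out.view(-1, 1, 1, 1)

bsz = 2
latents = torch.randn(bsz, 4, 32, 32)                 # stand-in for VAE latents
alphas_cumprod = torch.linspace(0.999, 0.01, 1000)    # toy noise schedule
ddim_timesteps = torch.arange(1, 51) * 20 - 1         # 50 DDIM steps over 1000, as in DDIMSolver
topk = 1000 // 50                                     # mirrors train_sampling_steps // num_ddim_timesteps

index = torch.randint(0, 50, (bsz,))
start_t = ddim_timesteps[index]                       # t_{n+k}
t = (start_t - topk).clamp(min=0)                     # t_n

# Forward-noise the latents to t_{n+k}.
noise = torch.randn_like(latents)
a_start = alphas_cumprod[start_t].view(-1, 1, 1, 1)
noisy = a_start.sqrt() * latents + (1 - a_start).sqrt() * noise

# Online (student) consistency estimate at t_{n+k}.
c_skip_start, c_out_start = scalings(start_t.float())
model_pred = c_skip_start * noisy + c_out_start * toy_student(noisy, start_t)

# Teacher estimate of x_0 with a CFG-style combination (w = cfg_scale), then one
# simplified DDIM step from t_{n+k} back to t_n.
w = 4.5
cond_x0 = toy_teacher(noisy, start_t)
uncond_x0 = 0.5 * cond_x0                             # fake "unconditional" branch
pred_x0 = cond_x0 + w * (cond_x0 - uncond_x0)
pred_noise = (noisy - a_start.sqrt() * pred_x0) / (1 - a_start).sqrt()
a_prev = alphas_cumprod[t].view(-1, 1, 1, 1)
x_prev = a_prev.sqrt() * pred_x0 + (1 - a_prev).sqrt() * pred_noise

# EMA ("target") consistency estimate at t_n; the distillation loss matches the two.
with torch.no_grad():
    c_skip, c_out = scalings(t.float())
    target = c_skip * x_prev + c_out * toy_ema(x_prev, t)

loss = F.mse_loss(model_pred, target)                 # the script's "l2" loss branch
print(f"toy consistency-distillation loss: {loss.item():.4f}")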