diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..269f8828ae13248392cc4301b48102b5b9a7e2ed --- /dev/null +++ b/.gitignore @@ -0,0 +1,167 @@ + +# Created by https://www.toptal.com/developers/gitignore/api/python +# Edit at https://www.toptal.com/developers/gitignore?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so +checkpoints/ +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +pytestdebug.log + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +doc/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# End of https://www.toptal.com/developers/gitignore/api/python + +/shelf/ +/workspace.xml + +dataset +dataset_raw +raw +results +inference/chunks_temp.json +logs +hubert/checkpoint_best_legacy_500.pt +configs/config.json +filelists/test.txt +filelists/train.txt +filelists/val.txt +.idea/ +.vscode/ +.idea/modules.xml +.idea/so-vits-svc.iml +.idea/vcs.xml +.idea/inspectionProfiles/profiles_settings.xml +.idea/inspectionProfiles/Project_Default.xml +pretrain/ +.vscode/launch.json + +trained/**/ diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 0000000000000000000000000000000000000000..fb05db4af7d21f5fc40159754a27f55b77031175 --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,4 @@ +select = ["E", "F", "I"] + +# Never enforce `E501` (line length violations). +ignore = ["E501", "E741"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/README_zh_CN.md b/README_zh_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..71c3a60abbc7e8425056f2108ef9465e14eeb7d5 --- /dev/null +++ b/README_zh_CN.md @@ -0,0 +1,572 @@ +
+LOGO + +# SoftVC VITS Singing Voice Conversion + +[**English**](./README.md) | [**中文简体**](./README_zh_CN.md) + +[![在Google Cloab中打开](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/svc-develop-team/so-vits-svc/blob/4.1-Stable/sovits4_for_colab.ipynb) +[![LICENSE](https://img.shields.io/badge/LICENSE-AGPL3.0-green.svg?style=for-the-badge)](https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/LICENSE) + +本轮限时更新即将结束,仓库将进入Archieve状态,望周知 + +
+ + +#### ✨ 带有 F0 曲线编辑器,角色混合时间轴编辑器的推理端 (Onnx 模型的用途): [MoeVoiceStudio](https://github.com/NaruseMioShirakana/MoeVoiceStudio) + +#### ✨ 改善了交互的一个分支推荐: [34j/so-vits-svc-fork](https://github.com/34j/so-vits-svc-fork) + +#### ✨ 支持实时转换的一个客户端: [w-okada/voice-changer](https://github.com/w-okada/voice-changer) + +**本项目与 Vits 有着根本上的不同。Vits 是 TTS,本项目是 SVC。本项目无法实现 TTS,Vits 也无法实现 SVC,这两个项目的模型是完全不通用的。** + +## 重要通知 + +这个项目是为了让开发者最喜欢的动画角色唱歌而开发的,任何涉及真人的东西都与开发者的意图背道而驰。 + +## 声明 + +本项目为开源、离线的项目,SvcDevelopTeam 的所有成员与本项目的所有开发者以及维护者(以下简称贡献者)对本项目没有控制力。本项目的贡献者从未向任何组织或个人提供包括但不限于数据集提取、数据集加工、算力支持、训练支持、推理等一切形式的帮助;本项目的贡献者不知晓也无法知晓使用者使用该项目的用途。故一切基于本项目训练的 AI 模型和合成的音频都与本项目贡献者无关。一切由此造成的问题由使用者自行承担。 + +此项目完全离线运行,不能收集任何用户信息或获取用户输入数据。因此,这个项目的贡献者不知道所有的用户输入和模型,因此不负责任何用户输入。 + +本项目只是一个框架项目,本身并没有语音合成的功能,所有的功能都需要用户自己训练模型。同时,这个项目没有任何模型,任何二次分发的项目都与这个项目的贡献者无关。 + +## 📏 使用规约 + +# Warning:请自行解决数据集授权问题,禁止使用非授权数据集进行训练!任何由于使用非授权数据集进行训练造成的问题,需自行承担全部责任和后果!与仓库、仓库维护者、svc develop team 无关! + +1. 本项目是基于学术交流目的建立,仅供交流与学习使用,并非为生产环境准备。 +2. 任何发布到视频平台的基于 sovits 制作的视频,都必须要在简介明确指明用于变声器转换的输入源歌声、音频,例如:使用他人发布的视频 / 音频,通过分离的人声作为输入源进行转换的,必须要给出明确的原视频、音乐链接;若使用是自己的人声,或是使用其他歌声合成引擎合成的声音作为输入源进行转换的,也必须在简介加以说明。 +3. 由输入源造成的侵权问题需自行承担全部责任和一切后果。使用其他商用歌声合成软件作为输入源时,请确保遵守该软件的使用条例,注意,许多歌声合成引擎使用条例中明确指明不可用于输入源进行转换! +4. 禁止使用该项目从事违法行为与宗教、政治等活动,该项目维护者坚决抵制上述行为,不同意此条则禁止使用该项目。 +5. 继续使用视为已同意本仓库 README 所述相关条例,本仓库 README 已进行劝导义务,不对后续可能存在问题负责。 +6. 如果将此项目用于任何其他企划,请提前联系并告知本仓库作者,十分感谢。 + +## 📝 模型简介 + +歌声音色转换模型,通过 SoftVC 内容编码器提取源音频语音特征,与 F0 同时输入 VITS 替换原本的文本输入达到歌声转换的效果。同时,更换声码器为 [NSF HiFiGAN](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) 解决断音问题。 + +### 🆕 4.1-Stable 版本更新内容 + ++ 特征输入更换为 [Content Vec](https://github.com/auspicious3000/contentvec) 的第 12 层 Transformer 输出,并兼容 4.0 分支 ++ 更新浅层扩散,可以使用浅层扩散模型提升音质 ++ 增加 whisper 语音编码器的支持 ++ 增加静态/动态声线融合 ++ 增加响度嵌入 ++ 增加特征检索,来自于 [RVC](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) + +### 🆕 关于兼容 4.0 模型的问题 + ++ 可通过修改 4.0 模型的 config.json 对 4.0 的模型进行支持,需要在 config.json 的 model 字段中添加 speech_encoder 字段,具体见下 + +``` + "model": { + ......... + "ssl_dim": 256, + "n_speakers": 200, + "speech_encoder":"vec256l9" + } +``` + +### 🆕 关于浅扩散 +![Diagram](shadowdiffusion.png) + +## 💬 关于 Python 版本问题 + +在进行测试后,我们认为`Python 3.8.9`能够稳定地运行该项目 + +## 📥 预先下载的模型文件 + +#### **必须项** + +**以下编码器需要选择一个使用** + +##### **1. 若使用 contentvec 作为声音编码器(推荐)** + +`vec768l12`与`vec256l9` 需要该编码器 + ++ contentvec :[checkpoint_best_legacy_500.pt](https://ibm.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr) + + 放在`pretrain`目录下 + +或者下载下面的 ContentVec,大小只有 199MB,但效果相同: ++ contentvec :[hubert_base.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt) + + 将文件名改为`checkpoint_best_legacy_500.pt`后,放在`pretrain`目录下 + +```shell +# contentvec +wget -P pretrain/ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -O checkpoint_best_legacy_500.pt +# 也可手动下载放在 pretrain 目录 +``` + +##### **2. 若使用 hubertsoft 作为声音编码器** ++ soft vc hubert:[hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt) + + 放在`pretrain`目录下 + +##### **3. 若使用 Whisper-ppg 作为声音编码器** ++ 下载模型 [medium.pt](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt), 该模型适配`whisper-ppg` ++ 下载模型 [large-v2.pt](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt), 该模型适配`whisper-ppg-large` + + 放在`pretrain`目录下 + +##### **4. 若使用 cnhubertlarge 作为声音编码器** ++ 下载模型 [chinese-hubert-large-fairseq-ckpt.pt](https://huggingface.co/TencentGameMate/chinese-hubert-large/resolve/main/chinese-hubert-large-fairseq-ckpt.pt) + + 放在`pretrain`目录下 + +##### **5. 若使用 dphubert 作为声音编码器** ++ 下载模型 [DPHuBERT-sp0.75.pth](https://huggingface.co/pyf98/DPHuBERT/resolve/main/DPHuBERT-sp0.75.pth) + + 放在`pretrain`目录下 + +##### **6. 若使用 WavLM 作为声音编码器** ++ 下载模型 [WavLM-Base+.pt](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D), 该模型适配`wavlmbase+` + + 放在`pretrain`目录下 + +##### **7. 若使用 OnnxHubert/ContentVec 作为声音编码器** ++ 下载模型 [MoeSS-SUBModel](https://huggingface.co/NaruseMioShirakana/MoeSS-SUBModel/tree/main) + + 放在`pretrain`目录下 + +#### **编码器列表** +- "vec768l12" +- "vec256l9" +- "vec256l9-onnx" +- "vec256l12-onnx" +- "vec768l9-onnx" +- "vec768l12-onnx" +- "hubertsoft-onnx" +- "hubertsoft" +- "whisper-ppg" +- "cnhubertlarge" +- "dphubert" +- "whisper-ppg-large" +- "wavlmbase+" + +#### **可选项(强烈建议使用)** + ++ 预训练底模文件: `G_0.pth` `D_0.pth` + + 放在`logs/44k`目录下 + ++ 扩散模型预训练底模文件: `model_0.pt` + + 放在`logs/44k/diffusion`目录下 + +从 svc-develop-team(待定)或任何其他地方获取 Sovits 底模 + +扩散模型引用了 [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) 的 Diffusion Model,底模与 [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) 的扩散模型底模通用,可以去 [Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) 获取扩散模型的底模 + +虽然底模一般不会引起什么版权问题,但还是请注意一下,比如事先询问作者,又或者作者在模型描述中明确写明了可行的用途 + +#### **可选项(根据情况选择)** + +##### NSF-HIFIGAN + +如果使用`NSF-HIFIGAN 增强器`或`浅层扩散`的话,需要下载预训练的 NSF-HIFIGAN 模型,如果不需要可以不下载 + ++ 预训练的 NSF-HIFIGAN 声码器 :[nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip) + + 解压后,将四个文件放在`pretrain/nsf_hifigan`目录下 + +```shell +# nsf_hifigan +wget -P pretrain/ https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip +unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip +# 也可手动下载放在 pretrain/nsf_hifigan 目录 +# 地址:https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1 +``` + +##### RMVPE + +如果使用`rmvpe`F0预测器的话,需要下载预训练的 RMVPE 模型 + ++ 下载模型 [rmvpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt) + + 放在`pretrain`目录下 + +##### FCPE(预览版) + +> 你说的对,但是[FCPE](https://github.com/CNChTu/MelPE)是由svc-develop-team自主研发的一款全新的F0预测器,后面忘了 + +[FCPE(Fast Context-base Pitch Estimator)](https://github.com/CNChTu/MelPE)是一个为实时语音转换所设计的专用F0预测器,他将在未来成为Sovits实时语音转换的首选F0预测器.(论文未来会有的) + +如果使用 `fcpe` F0预测器的话,需要下载预训练的 FCPE 模型 + ++ 下载模型 [fcpe.pt](https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt) + + 放在`pretrain`目录下 + + +## 📊 数据集准备 + +仅需要以以下文件结构将数据集放入 dataset_raw 目录即可。 + +``` +dataset_raw +├───speaker0 +│ ├───xxx1-xxx1.wav +│ ├───... +│ └───Lxx-0xx8.wav +└───speaker1 + ├───xx2-0xxx2.wav + ├───... + └───xxx7-xxx007.wav +``` +对于每一个音频文件的名称并没有格式的限制(`000001.wav`~`999999.wav`之类的命名方式也是合法的),不过文件类型必须是`wav`。 + +可以自定义说话人名称 + +``` +dataset_raw +└───suijiSUI + ├───1.wav + ├───... + └───25788785-20221210-200143-856_01_(Vocals)_0_0.wav +``` + +## 🛠️ 数据预处理 + +### 0. 音频切片 + +将音频切片至`5s - 15s`, 稍微长点也无伤大雅,实在太长可能会导致训练中途甚至预处理就爆显存 + +可以使用 [audio-slicer-GUI](https://github.com/flutydeer/audio-slicer)、[audio-slicer-CLI](https://github.com/openvpi/audio-slicer) + +一般情况下只需调整其中的`Minimum Interval`,普通陈述素材通常保持默认即可,歌唱素材可以调整至`100`甚至`50` + +切完之后手动删除过长过短的音频 + +**如果你使用 Whisper-ppg 声音编码器进行训练,所有的切片长度必须小于 30s** + +### 1. 重采样至 44100Hz 单声道 + +```shell +python resample.py +``` + +#### 注意 + +虽然本项目拥有重采样、转换单声道与响度匹配的脚本 resample.py,但是默认的响度匹配是匹配到 0db。这可能会造成音质的受损。而 python 的响度匹配包 pyloudnorm 无法对电平进行压限,这会导致爆音。所以建议可以考虑使用专业声音处理软件如`adobe audition`等软件做响度匹配处理。若已经使用其他软件做响度匹配,可以在运行上述命令时添加`--skip_loudnorm`跳过响度匹配步骤。如: + +```shell +python resample.py --skip_loudnorm +``` + +### 2. 自动划分训练集、验证集,以及自动生成配置文件 + +```shell +python preprocess_flist_config.py --speech_encoder vec768l12 +``` + +speech_encoder 拥有以下选择 + +``` +vec768l12 +vec256l9 +hubertsoft +whisper-ppg +whisper-ppg-large +cnhubertlarge +dphubert +wavlmbase+ +``` + +如果省略 speech_encoder 参数,默认值为 vec768l12 + +**使用响度嵌入** + +若使用响度嵌入,需要增加`--vol_aug`参数,比如: + +```shell +python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug +``` +使用后训练出的模型将匹配到输入源响度,否则为训练集响度。 + +#### 此时可以在生成的 config.json 与 diffusion.yaml 修改部分参数 + +##### config.json + +* `keep_ckpts`:训练时保留最后几个模型,`0`为保留所有,默认只保留最后`3`个 + +* `all_in_mem`:加载所有数据集到内存中,某些平台的硬盘 IO 过于低下、同时内存容量 **远大于** 数据集体积时可以启用 + +* `batch_size`:单次训练加载到 GPU 的数据量,调整到低于显存容量的大小即可 + +* `vocoder_name` : 选择一种声码器,默认为`nsf-hifigan`. + +##### diffusion.yaml + +* `cache_all_data`:加载所有数据集到内存中,某些平台的硬盘 IO 过于低下、同时内存容量 **远大于** 数据集体积时可以启用 + +* `duration`:训练时音频切片时长,可根据显存大小调整,**注意,该值必须小于训练集内音频的最短时间!** + +* `batch_size`:单次训练加载到 GPU 的数据量,调整到低于显存容量的大小即可 + +* `timesteps` : 扩散模型总步数,默认为 1000. + +* `k_step_max` : 训练时可仅训练`k_step_max`步扩散以节约训练时间,注意,该值必须小于`timesteps`,0 为训练整个扩散模型,**注意,如果不训练整个扩散模型将无法使用仅扩散模型推理!** + +##### **声码器列表** + +``` +nsf-hifigan +nsf-snake-hifigan +``` + +### 3. 生成 hubert 与 f0 + +```shell +python preprocess_hubert_f0.py --f0_predictor dio +``` + +f0_predictor 拥有以下选择 + +``` +crepe +dio +pm +harvest +rmvpe +fcpe +``` + +如果训练集过于嘈杂,请使用 crepe 处理 f0 + +如果省略 f0_predictor 参数,默认值为 rmvpe + +尚若需要浅扩散功能(可选),需要增加--use_diff 参数,比如 + +```shell +python preprocess_hubert_f0.py --f0_predictor dio --use_diff +``` + +**加速预处理** +如若您的数据集比较大,可以尝试添加`--num_processes`参数: +```shell +python preprocess_hubert_f0.py --f0_predictor dio --use_diff --num_processes 8 +``` +所有的Workers会被自动分配到多个线程上 + +执行完以上步骤后 dataset 目录便是预处理完成的数据,可以删除 dataset_raw 文件夹了 + +## 🏋️‍ 训练 + +### 主模型训练 + +```shell +python train.py -c configs/config.json -m 44k +``` + +### 扩散模型(可选) + +尚若需要浅扩散功能,需要训练扩散模型,扩散模型训练方法为: + +```shell +python train_diff.py -c configs/diffusion.yaml +``` + +模型训练结束后,模型文件保存在`logs/44k`目录下,扩散模型在`logs/44k/diffusion`下 + +## 🤖 推理 + +使用 [inference_main.py](inference_main.py) + +```shell +# 例 +python inference_main.py -m "logs/44k/G_30400.pth" -c "configs/config.json" -n "君の知らない物語-src.wav" -t 0 -s "nen" +``` + +必填项部分: ++ `-m` | `--model_path`:模型路径 ++ `-c` | `--config_path`:配置文件路径 ++ `-n` | `--clean_names`:wav 文件名列表,放在 raw 文件夹下 ++ `-t` | `--trans`:音高调整,支持正负(半音) ++ `-s` | `--spk_list`:合成目标说话人名称 ++ `-cl` | `--clip`:音频强制切片,默认 0 为自动切片,单位为秒/s + +可选项部分:部分具体见下一节 ++ `-lg` | `--linear_gradient`:两段音频切片的交叉淡入长度,如果强制切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值 0,单位为秒 ++ `-f0p` | `--f0_predictor`:选择 F0 预测器,可选择 crepe,pm,dio,harvest,rmvpe,fcpe, 默认为 pm(注意:crepe 为原 F0 使用均值滤波器) ++ `-a` | `--auto_predict_f0`:语音转换自动预测音高,转换歌声时不要打开这个会严重跑调 ++ `-cm` | `--cluster_model_path`:聚类模型或特征检索索引路径,留空则自动设为各方案模型的默认路径,如果没有训练聚类或特征检索则随便填 ++ `-cr` | `--cluster_infer_ratio`:聚类方案或特征检索占比,范围 0-1,若没有训练聚类模型或特征检索则默认 0 即可 ++ `-eh` | `--enhance`:是否使用 NSF_HIFIGAN 增强器,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭 ++ `-shd` | `--shallow_diffusion`:是否使用浅层扩散,使用后可解决一部分电音问题,默认关闭,该选项打开时,NSF_HIFIGAN 增强器将会被禁止 ++ `-usm` | `--use_spk_mix`:是否使用角色融合/动态声线融合 ++ `-lea` | `--loudness_envelope_adjustment`:输入源响度包络替换输出响度包络融合比例,越靠近 1 越使用输出响度包络 ++ `-fr` | `--feature_retrieval`:是否使用特征检索,如果使用聚类模型将被禁用,且 cm 与 cr 参数将会变成特征检索的索引路径与混合比例 + +浅扩散设置: ++ `-dm` | `--diffusion_model_path`:扩散模型路径 ++ `-dc` | `--diffusion_config_path`:扩散模型配置文件路径 ++ `-ks` | `--k_step`:扩散步数,越大越接近扩散模型的结果,默认 100 ++ `-od` | `--only_diffusion`:纯扩散模式,该模式不会加载 sovits 模型,以扩散模型推理 ++ `-se` | `--second_encoding`:二次编码,浅扩散前会对原始音频进行二次编码,玄学选项,有时候效果好,有时候效果差 + +### 注意! + +如果使用`whisper-ppg` 声音编码器进行推理,需要将`--clip`设置为 25,`-lg`设置为 1。否则将无法正常推理。 + +## 🤔 可选项 + +如果前面的效果已经满意,或者没看明白下面在讲啥,那后面的内容都可以忽略,不影响模型使用(这些可选项影响比较小,可能在某些特定数据上有点效果,但大部分情况似乎都感知不太明显) + +### 自动 f0 预测 + +4.0 模型训练过程会训练一个 f0 预测器,对于语音转换可以开启自动音高预测,如果效果不好也可以使用手动的,但转换歌声时请不要启用此功能!!!会严重跑调!! ++ 在 inference_main 中设置 auto_predict_f0 为 true 即可 + +### 聚类音色泄漏控制 + +介绍:聚类方案可以减小音色泄漏,使得模型训练出来更像目标的音色(但其实不是特别明显),但是单纯的聚类方案会降低模型的咬字(会口齿不清)(这个很明显),本模型采用了融合的方式,可以线性控制聚类方案与非聚类方案的占比,也就是可以手动在"像目标音色" 和 "咬字清晰" 之间调整比例,找到合适的折中点 + +使用聚类前面的已有步骤不用进行任何的变动,只需要额外训练一个聚类模型,虽然效果比较有限,但训练成本也比较低 + ++ 训练过程: + + 使用 cpu 性能较好的机器训练,据我的经验在腾讯云 6 核 cpu 训练每个 speaker 需要约 4 分钟即可完成训练 + + 执行`python cluster/train_cluster.py`,模型的输出会在`logs/44k/kmeans_10000.pt` + + 聚类模型目前可以使用 gpu 进行训练,执行`python cluster/train_cluster.py --gpu` ++ 推理过程: + + `inference_main.py`中指定`cluster_model_path` 为模型输出文件,留空则默认为`logs/44k/kmeans_10000.pt` + + `inference_main.py`中指定`cluster_infer_ratio`,`0`为完全不使用聚类,`1`为只使用聚类,通常设置`0.5`即可 + +### 特征检索 + +介绍:跟聚类方案一样可以减小音色泄漏,咬字比聚类稍好,但会降低推理速度,采用了融合的方式,可以线性控制特征检索与非特征检索的占比, + ++ 训练过程: + 首先需要在生成 hubert 与 f0 后执行: + +```shell +python train_index.py -c configs/config.json +``` + +模型的输出会在`logs/44k/feature_and_index.pkl` + ++ 推理过程: + + 需要首先指定`--feature_retrieval`,此时聚类方案会自动切换到特征检索方案 + + `inference_main.py`中指定`cluster_model_path` 为模型输出文件,留空则默认为`logs/44k/feature_and_index.pkl` + + `inference_main.py`中指定`cluster_infer_ratio`,`0`为完全不使用特征检索,`1`为只使用特征检索,通常设置`0.5`即可 + + +## 🗜️ 模型压缩 + +生成的模型含有继续训练所需的信息。如果确认不再训练,可以移除模型中此部分信息,得到约 1/3 大小的最终模型。 + +使用 [compress_model.py](compress_model.py) + +```shell +# 例 +python compress_model.py -c="configs/config.json" -i="logs/44k/G_30400.pth" -o="logs/44k/release.pth" +``` + +## 👨‍🔧 声线混合 + +### 静态声线混合 + +**参考`webUI.py`文件中,小工具/实验室特性的静态声线融合。** + +介绍:该功能可以将多个声音模型合成为一个声音模型(多个模型参数的凸组合或线性组合),从而制造出现实中不存在的声线 +**注意:** + +1. 该功能仅支持单说话人的模型 +2. 如果强行使用多说话人模型,需要保证多个模型的说话人数量相同,这样可以混合同一个 SpaekerID 下的声音 +3. 保证所有待混合模型的 config.json 中的 model 字段是相同的 +4. 输出的混合模型可以使用待合成模型的任意一个 config.json,但聚类模型将不能使用 +5. 批量上传模型的时候最好把模型放到一个文件夹选中后一起上传 +6. 混合比例调整建议大小在 0-100 之间,也可以调为其他数字,但在线性组合模式下会出现未知的效果 +7. 混合完毕后,文件将会保存在项目根目录中,文件名为 output.pth +8. 凸组合模式会将混合比例执行 Softmax 使混合比例相加为 1,而线性组合模式不会 + +### 动态声线混合 + +**参考`spkmix.py`文件中关于动态声线混合的介绍** + +角色混合轨道 编写规则: + +角色 ID : \[\[起始时间 1, 终止时间 1, 起始数值 1, 起始数值 1], [起始时间 2, 终止时间 2, 起始数值 2, 起始数值 2]] + +起始时间和前一个的终止时间必须相同,第一个起始时间必须为 0,最后一个终止时间必须为 1 (时间的范围为 0-1) + +全部角色必须填写,不使用的角色填、[\[0., 1., 0., 0.]] 即可 + +融合数值可以随便填,在指定的时间段内从起始数值线性变化为终止数值,内部会自动确保线性组合为 1(凸组合条件),可以放心使用 + +推理的时候使用`--use_spk_mix`参数即可启用动态声线混合 + +## 📤 Onnx 导出 + +使用 [onnx_export.py](onnx_export.py) + ++ 新建文件夹:`checkpoints` 并打开 ++ 在`checkpoints`文件夹中新建一个文件夹作为项目文件夹,文件夹名为你的项目名称,比如`aziplayer` ++ 将你的模型更名为`model.pth`,配置文件更名为`config.json`,并放置到刚才创建的`aziplayer`文件夹下 ++ 将 [onnx_export.py](onnx_export.py) 中`path = "NyaruTaffy"` 的 `"NyaruTaffy"` 修改为你的项目名称,`path = "aziplayer" (onnx_export_speaker_mix,为支持角色混合的 onnx 导出)` ++ 运行 [onnx_export.py](onnx_export.py) ++ 等待执行完毕,在你的项目文件夹下会生成一个`model.onnx`,即为导出的模型 + +注意:Hubert Onnx 模型请使用 MoeSS 提供的模型,目前无法自行导出(fairseq 中 Hubert 有不少 onnx 不支持的算子和涉及到常量的东西,在导出时会报错或者导出的模型输入输出 shape 和结果都有问题) + +## 📎 引用及论文 + +| URL | 名称 | 标题 | 源码 | +| --- | ----------- | ----- | --------------------- | +|[2106.06103](https://arxiv.org/abs/2106.06103) | VITS (Synthesizer)| Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech | [jaywalnut310/vits](https://github.com/jaywalnut310/vits) | +|[2111.02392](https://arxiv.org/abs/2111.02392) | SoftVC (Speech Encoder)| A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion | [bshall/hubert](https://github.com/bshall/hubert) | +|[2204.09224](https://arxiv.org/abs/2204.09224) | ContentVec (Speech Encoder)| ContentVec: An Improved Self-Supervised Speech Representation by Disentangling Speakers | [auspicious3000/contentvec](https://github.com/auspicious3000/contentvec) | +|[2212.04356](https://arxiv.org/abs/2212.04356) | Whisper (Speech Encoder) | Robust Speech Recognition via Large-Scale Weak Supervision | [openai/whisper](https://github.com/openai/whisper) | +|[2110.13900](https://arxiv.org/abs/2110.13900) | WavLM (Speech Encoder) | WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing | [microsoft/unilm/wavlm](https://github.com/microsoft/unilm/tree/master/wavlm) | +|[2305.17651](https://arxiv.org/abs/2305.17651) | DPHubert (Speech Encoder) | DPHuBERT: Joint Distillation and Pruning of Self-Supervised Speech Models | [pyf98/DPHuBERT](https://github.com/pyf98/DPHuBERT) | +|[DOI:10.21437/Interspeech.2017-68](http://dx.doi.org/10.21437/Interspeech.2017-68) | Harvest (F0 Predictor) | Harvest: A high-performance fundamental frequency estimator from speech signals | [mmorise/World/harvest](https://github.com/mmorise/World/blob/master/src/harvest.cpp) | +|[aes35-000039](https://www.aes.org/e-lib/online/browse.cfm?elib=15165) | Dio (F0 Predictor) | Fast and reliable F0 estimation method based on the period extraction of vocal fold vibration of singing voice and speech | [mmorise/World/dio](https://github.com/mmorise/World/blob/master/src/dio.cpp) | +|[8461329](https://ieeexplore.ieee.org/document/8461329) | Crepe (F0 Predictor) | Crepe: A Convolutional Representation for Pitch Estimation | [maxrmorrison/torchcrepe](https://github.com/maxrmorrison/torchcrepe) | +|[DOI:10.1016/j.wocn.2018.07.001](https://doi.org/10.1016/j.wocn.2018.07.001) | Parselmouth (F0 Predictor) | Introducing Parselmouth: A Python interface to Praat | [YannickJadoul/Parselmouth](https://github.com/YannickJadoul/Parselmouth) | +|[2306.15412v2](https://arxiv.org/abs/2306.15412v2) | RMVPE (F0 Predictor) | RMVPE: A Robust Model for Vocal Pitch Estimation in Polyphonic Music | [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE) | +|[2010.05646](https://arxiv.org/abs/2010.05646) | HIFIGAN (Vocoder) | HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | [jik876/hifi-gan](https://github.com/jik876/hifi-gan) | +|[1810.11946](https://arxiv.org/abs/1810.11946.pdf) | NSF (Vocoder) | Neural source-filter-based waveform model for statistical parametric speech synthesis | [openvpi/DiffSinger/modules/nsf_hifigan](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) +|[2006.08195](https://arxiv.org/abs/2006.08195) | Snake (Vocoder) | Neural Networks Fail to Learn Periodic Functions and How to Fix It | [EdwardDixon/snake](https://github.com/EdwardDixon/snake) +|[2105.02446v3](https://arxiv.org/abs/2105.02446v3) | Shallow Diffusion (PostProcessing)| DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism | [CNChTu/Diffusion-SVC](https://github.com/CNChTu/Diffusion-SVC) | +|[K-means](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=01D65490BADCC216F350D06F84D721AD?doi=10.1.1.308.8619&rep=rep1&type=pdf) | Feature K-means Clustering (PreProcessing)| Some methods for classification and analysis of multivariate observations | 本代码库 | +| | Feature TopK Retrieval (PreProcessing)| Retrieval based Voice Conversion | [RVC-Project/Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) | + +## ☀️ 旧贡献者 + +因为某些原因原作者进行了删库处理,本仓库重建之初由于组织成员疏忽直接重新上传了所有文件导致以前的 contributors 全部木大,现在在 README 里重新添加一个旧贡献者列表 + +*某些成员已根据其个人意愿不将其列出* + + + + + + + + + + + +

MistEO


XiaoMiku01


しぐれ


TomoGaSukunai


Plachtaa


zd 小达


凍聲響世

+ +## 📚 一些法律条例参考 + +#### 任何国家,地区,组织和个人使用此项目必须遵守以下法律 + +#### 《民法典》 + +##### 第一千零一十九条 + +任何组织或者个人不得以丑化、污损,或者利用信息技术手段伪造等方式侵害他人的肖像权。未经肖像权人同意,不得制作、使用、公开肖像权人的肖像,但是法律另有规定的除外。未经肖像权人同意,肖像作品权利人不得以发表、复制、发行、出租、展览等方式使用或者公开肖像权人的肖像。对自然人声音的保护,参照适用肖像权保护的有关规定。 + +##### 第一千零二十四条 + +【名誉权】民事主体享有名誉权。任何组织或者个人不得以侮辱、诽谤等方式侵害他人的名誉权。 + +##### 第一千零二十七条 + +【作品侵害名誉权】行为人发表的文学、艺术作品以真人真事或者特定人为描述对象,含有侮辱、诽谤内容,侵害他人名誉权的,受害人有权依法请求该行为人承担民事责任。行为人发表的文学、艺术作品不以特定人为描述对象,仅其中的情节与该特定人的情况相似的,不承担民事责任。 + +#### 《[中华人民共和国宪法](http://www.gov.cn/guoqing/2018-03/22/content_5276318.htm)》 + +#### 《[中华人民共和国刑法](http://gongbao.court.gov.cn/Details/f8e30d0689b23f57bfc782d21035c3.html?sw=中华人民共和国刑法)》 + +#### 《[中华人民共和国民法典](http://gongbao.court.gov.cn/Details/51eb6750b8361f79be8f90d09bc202.html)》 + +#### 《[中华人民共和国合同法](http://www.npc.gov.cn/zgrdw/npc/lfzt/rlyw/2016-07/01/content_1992739.htm)》 + +## 💪 感谢所有的贡献者 + + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..535c78414b9a9b84292068ae1ea59a4335151a83 --- /dev/null +++ b/app.py @@ -0,0 +1,430 @@ +import glob +import json +import logging +import os +import re +import subprocess +import sys +import time +import traceback +from itertools import chain +from pathlib import Path + +# os.system("wget -P cvec/ https://huggingface.co/spaces/innnky/nanami/resolve/main/checkpoint_best_legacy_500.pt") +import gradio as gr +import librosa +import numpy as np +import soundfile +import torch + +from compress_model import removeOptimizer +from edgetts.tts_voices import SUPPORTED_LANGUAGES +from inference.infer_tool import Svc +from utils import mix_model + +logging.getLogger('numba').setLevel(logging.WARNING) +logging.getLogger('markdown_it').setLevel(logging.WARNING) +logging.getLogger('urllib3').setLevel(logging.WARNING) +logging.getLogger('matplotlib').setLevel(logging.WARNING) +logging.getLogger('multipart').setLevel(logging.WARNING) + +model = None +spk = None +debug = False + +local_model_root = './trained' + +cuda = {} +if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + device_name = torch.cuda.get_device_properties(i).name + cuda[f"CUDA:{i} {device_name}"] = f"cuda:{i}" + +def upload_mix_append_file(files,sfiles): + try: + if(sfiles is None): + file_paths = [file.name for file in files] + else: + file_paths = [file.name for file in chain(files,sfiles)] + p = {file:100 for file in file_paths} + return file_paths,mix_model_output1.update(value=json.dumps(p,indent=2)) + except Exception as e: + if debug: + traceback.print_exc() + raise gr.Error(e) + +def mix_submit_click(js,mode): + try: + assert js.lstrip()!="" + modes = {"凸组合":0, "线性组合":1} + mode = modes[mode] + data = json.loads(js) + data = list(data.items()) + model_path,mix_rate = zip(*data) + path = mix_model(model_path,mix_rate,mode) + return f"成功,文件被保存在了{path}" + except Exception as e: + if debug: + traceback.print_exc() + raise gr.Error(e) + +def updata_mix_info(files): + try: + if files is None : + return mix_model_output1.update(value="") + p = {file.name:100 for file in files} + return mix_model_output1.update(value=json.dumps(p,indent=2)) + except Exception as e: + if debug: + traceback.print_exc() + raise gr.Error(e) + +def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix,local_model_enabled,local_model_selection): + global model + try: + device = cuda[device] if "CUDA" in device else device + cluster_filepath = os.path.split(cluster_model_path.name) if cluster_model_path is not None else "no_cluster" + # get model and config path + if (local_model_enabled): + # local path + model_path = glob.glob(os.path.join(local_model_selection, '*.pth'))[0] + config_path = glob.glob(os.path.join(local_model_selection, '*.json'))[0] + else: + # upload from webpage + model_path = model_path.name + config_path = config_path.name + fr = ".pkl" in cluster_filepath[1] + model = Svc(model_path, + config_path, + device=device if device != "Auto" else None, + cluster_model_path = cluster_model_path.name if cluster_model_path is not None else "", + nsf_hifigan_enhance=enhance, + diffusion_model_path = diff_model_path.name if diff_model_path is not None else "", + diffusion_config_path = diff_config_path.name if diff_config_path is not None else "", + shallow_diffusion = True if diff_model_path is not None else False, + only_diffusion = only_diffusion, + spk_mix_enable = use_spk_mix, + feature_retrieval = fr + ) + spks = list(model.spk2id.keys()) + device_name = torch.cuda.get_device_properties(model.dev).name if "cuda" in str(model.dev) else str(model.dev) + msg = f"成功加载模型到设备{device_name}上\n" + if cluster_model_path is None: + msg += "未加载聚类模型或特征检索模型\n" + elif fr: + msg += f"特征检索模型{cluster_filepath[1]}加载成功\n" + else: + msg += f"聚类模型{cluster_filepath[1]}加载成功\n" + if diff_model_path is None: + msg += "未加载扩散模型\n" + else: + msg += f"扩散模型{diff_model_path.name}加载成功\n" + msg += "当前模型的可用音色:\n" + for i in spks: + msg += i + " " + return sid.update(choices = spks,value=spks[0]), msg + except Exception as e: + if debug: + traceback.print_exc() + raise gr.Error(e) + + +def modelUnload(): + global model + if model is None: + return sid.update(choices = [],value=""),"没有模型需要卸载!" + else: + model.unload_model() + model = None + torch.cuda.empty_cache() + return sid.update(choices = [],value=""),"模型卸载完毕!" + +def vc_infer(output_format, sid, audio_path, truncated_basename, vc_transform, auto_f0, cluster_ratio, slice_db, noise_scale, pad_seconds, cl_num, lg_num, lgr_num, f0_predictor, enhancer_adaptive_key, cr_threshold, k_step, use_spk_mix, second_encoding, loudness_envelope_adjustment): + global model + _audio = model.slice_inference( + audio_path, + sid, + vc_transform, + slice_db, + cluster_ratio, + auto_f0, + noise_scale, + pad_seconds, + cl_num, + lg_num, + lgr_num, + f0_predictor, + enhancer_adaptive_key, + cr_threshold, + k_step, + use_spk_mix, + second_encoding, + loudness_envelope_adjustment + ) + model.clear_empty() + #构建保存文件的路径,并保存到results文件夹内 + str(int(time.time())) + if not os.path.exists("results"): + os.makedirs("results") + key = "auto" if auto_f0 else f"{int(vc_transform)}key" + cluster = "_" if cluster_ratio == 0 else f"_{cluster_ratio}_" + isdiffusion = "sovits" + if model.shallow_diffusion: + isdiffusion = "sovdiff" + + if model.only_diffusion: + isdiffusion = "diff" + + output_file_name = 'result_'+truncated_basename+f'_{sid}_{key}{cluster}{isdiffusion}.{output_format}' + output_file = os.path.join("results", output_file_name) + soundfile.write(output_file, _audio, model.target_sample, format=output_format) + return output_file + +def vc_fn(sid, input_audio, output_format, vc_transform, auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment): + global model + try: + if input_audio is None: + return "You need to upload an audio", None + if model is None: + return "You need to upload an model", None + if getattr(model, 'cluster_model', None) is None and model.feature_retrieval is False: + if cluster_ratio != 0: + return "You need to upload an cluster model or feature retrieval model before assigning cluster ratio!", None + #print(input_audio) + audio, sampling_rate = soundfile.read(input_audio) + #print(audio.shape,sampling_rate) + if np.issubdtype(audio.dtype, np.integer): + audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32) + #print(audio.dtype) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio.transpose(1, 0)) + # 未知原因Gradio上传的filepath会有一个奇怪的固定后缀,这里去掉 + truncated_basename = Path(input_audio).stem[:-6] + processed_audio = os.path.join("raw", f"{truncated_basename}.wav") + soundfile.write(processed_audio, audio, sampling_rate, format="wav") + output_file = vc_infer(output_format, sid, processed_audio, truncated_basename, vc_transform, auto_f0, cluster_ratio, slice_db, noise_scale, pad_seconds, cl_num, lg_num, lgr_num, f0_predictor, enhancer_adaptive_key, cr_threshold, k_step, use_spk_mix, second_encoding, loudness_envelope_adjustment) + + return "Success", output_file + except Exception as e: + if debug: + traceback.print_exc() + raise gr.Error(e) + +def text_clear(text): + return re.sub(r"[\n\,\(\) ]", "", text) + +def vc_fn2(_text, _lang, _gender, _rate, _volume, sid, output_format, vc_transform, auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold, k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment): + global model + try: + if model is None: + return "You need to upload an model", None + if getattr(model, 'cluster_model', None) is None and model.feature_retrieval is False: + if cluster_ratio != 0: + return "You need to upload an cluster model or feature retrieval model before assigning cluster ratio!", None + _rate = f"+{int(_rate*100)}%" if _rate >= 0 else f"{int(_rate*100)}%" + _volume = f"+{int(_volume*100)}%" if _volume >= 0 else f"{int(_volume*100)}%" + if _lang == "Auto": + _gender = "Male" if _gender == "男" else "Female" + subprocess.run([sys.executable, "edgetts/tts.py", _text, _lang, _rate, _volume, _gender]) + else: + subprocess.run([sys.executable, "edgetts/tts.py", _text, _lang, _rate, _volume]) + target_sr = 44100 + y, sr = librosa.load("tts.wav") + resampled_y = librosa.resample(y, orig_sr=sr, target_sr=target_sr) + soundfile.write("tts.wav", resampled_y, target_sr, subtype = "PCM_16") + input_audio = "tts.wav" + #audio, _ = soundfile.read(input_audio) + output_file_path = vc_infer(output_format, sid, input_audio, "tts", vc_transform, auto_f0, cluster_ratio, slice_db, noise_scale, pad_seconds, cl_num, lg_num, lgr_num, f0_predictor, enhancer_adaptive_key, cr_threshold, k_step, use_spk_mix, second_encoding, loudness_envelope_adjustment) + os.remove("tts.wav") + return "Success", output_file_path + except Exception as e: + if debug: traceback.print_exc() # noqa: E701 + raise gr.Error(e) + +def model_compression(_model): + if _model == "": + return "请先选择要压缩的模型" + else: + model_path = os.path.split(_model.name) + filename, extension = os.path.splitext(model_path[1]) + output_model_name = f"{filename}_compressed{extension}" + output_path = os.path.join(os.getcwd(), output_model_name) + removeOptimizer(_model.name, output_path) + return f"模型已成功被保存在了{output_path}" + +def scan_local_models(): + res = [] + candidates = glob.glob(os.path.join(local_model_root, '**', '*.json'), recursive=True) + candidates = set([os.path.dirname(c) for c in candidates]) + for candidate in candidates: + jsons = glob.glob(os.path.join(candidate, '*.json')) + pths = glob.glob(os.path.join(candidate, '*.pth')) + if (len(jsons) == 1 and len(pths) == 1): + # must contain exactly one json and one pth file + res.append(candidate) + return res + +def local_model_refresh_fn(): + choices = scan_local_models() + return gr.Dropdown.update(choices=choices) + +def debug_change(): + global debug + debug = debug_button.value + +with gr.Blocks( + theme=gr.themes.Base( + primary_hue = gr.themes.colors.green, + font=["Source Sans Pro", "Arial", "sans-serif"], + font_mono=['JetBrains mono', "Consolas", 'Courier New'] + ), +) as app: + with gr.Tabs(): + with gr.TabItem("推理"): + gr.Markdown(value=""" + So-vits-svc 4.0 推理 webui + """) + with gr.Row(variant="panel"): + with gr.Column(): + gr.Markdown(value=""" + 模型设置 + """) + with gr.Tabs(): + # invisible checkbox that tracks tab status + local_model_enabled = gr.Checkbox(value=False, visible=False) + with gr.TabItem('上传') as local_model_tab_upload: + with gr.Row(): + model_path = gr.File(label="选择模型文件") + config_path = gr.File(label="选择配置文件") + with gr.TabItem('本地') as local_model_tab_local: + gr.Markdown(f'模型应当放置于{local_model_root}文件夹下') + local_model_refresh_btn = gr.Button('刷新本地模型列表') + local_model_selection = gr.Dropdown(label='选择模型文件夹', choices=[], interactive=True) + with gr.Row(): + diff_model_path = gr.File(label="选择扩散模型文件") + diff_config_path = gr.File(label="选择扩散模型配置文件") + cluster_model_path = gr.File(label="选择聚类模型或特征检索文件(没有可以不选)") + device = gr.Dropdown(label="推理设备,默认为自动选择CPU和GPU", choices=["Auto",*cuda.keys(),"cpu"], value="Auto") + enhance = gr.Checkbox(label="是否使用NSF_HIFIGAN增强,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭", value=False) + only_diffusion = gr.Checkbox(label="是否使用全扩散推理,开启后将不使用So-VITS模型,仅使用扩散模型进行完整扩散推理,默认关闭", value=False) + with gr.Column(): + gr.Markdown(value=""" + 左侧文件全部选择完毕后(全部文件模块显示download),点击“加载模型”进行解析: + """) + model_load_button = gr.Button(value="加载模型", variant="primary") + model_unload_button = gr.Button(value="卸载模型", variant="primary") + sid = gr.Dropdown(label="音色(说话人)") + sid_output = gr.Textbox(label="Output Message") + + + with gr.Row(variant="panel"): + with gr.Column(): + gr.Markdown(value=""" + 推理设置 + """) + auto_f0 = gr.Checkbox(label="自动f0预测,配合聚类模型f0预测效果更好,会导致变调功能失效(仅限转换语音,歌声勾选此项会究极跑调)", value=False) + f0_predictor = gr.Dropdown(label="选择F0预测器,可选择crepe,pm,dio,harvest,rmvpe,默认为pm(注意:crepe为原F0使用均值滤波器)", choices=["pm","dio","harvest","crepe","rmvpe"], value="pm") + vc_transform = gr.Number(label="变调(整数,可以正负,半音数量,升高八度就是12)", value=0) + cluster_ratio = gr.Number(label="聚类模型/特征检索混合比例,0-1之间,0即不启用聚类/特征检索。使用聚类/特征检索能提升音色相似度,但会导致咬字下降(如果使用建议0.5左右)", value=0) + slice_db = gr.Number(label="切片阈值", value=-40) + output_format = gr.Radio(label="音频输出格式", choices=["wav", "flac", "mp3"], value = "wav") + noise_scale = gr.Number(label="noise_scale 建议不要动,会影响音质,玄学参数", value=0.4) + k_step = gr.Slider(label="浅扩散步数,只有使用了扩散模型才有效,步数越大越接近扩散模型的结果", value=100, minimum = 1, maximum = 1000) + with gr.Column(): + pad_seconds = gr.Number(label="推理音频pad秒数,由于未知原因开头结尾会有异响,pad一小段静音段后就不会出现", value=0.5) + cl_num = gr.Number(label="音频自动切片,0为不切片,单位为秒(s)", value=0) + lg_num = gr.Number(label="两端音频切片的交叉淡入长度,如果自动切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值0,注意,该设置会影响推理速度,单位为秒/s", value=0) + lgr_num = gr.Number(label="自动音频切片后,需要舍弃每段切片的头尾。该参数设置交叉长度保留的比例,范围0-1,左开右闭", value=0.75) + enhancer_adaptive_key = gr.Number(label="使增强器适应更高的音域(单位为半音数)|默认为0", value=0) + cr_threshold = gr.Number(label="F0过滤阈值,只有启动crepe时有效. 数值范围从0-1. 降低该值可减少跑调概率,但会增加哑音", value=0.05) + loudness_envelope_adjustment = gr.Number(label="输入源响度包络替换输出响度包络融合比例,越靠近1越使用输出响度包络", value = 0) + second_encoding = gr.Checkbox(label = "二次编码,浅扩散前会对原始音频进行二次编码,玄学选项,效果时好时差,默认关闭", value=False) + use_spk_mix = gr.Checkbox(label = "动态声线融合", value = False, interactive = False) + with gr.Tabs(): + with gr.TabItem("音频转音频"): + vc_input3 = gr.Audio(label="选择音频", type="filepath") + vc_submit = gr.Button("音频转换", variant="primary") + with gr.TabItem("文字转音频"): + text2tts=gr.Textbox(label="在此输入要转译的文字。注意,使用该功能建议打开F0预测,不然会很怪") + with gr.Row(): + tts_gender = gr.Radio(label = "说话人性别", choices = ["男","女"], value = "男") + tts_lang = gr.Dropdown(label = "选择语言,Auto为根据输入文字自动识别", choices=SUPPORTED_LANGUAGES, value = "Auto") + tts_rate = gr.Slider(label = "TTS语音变速(倍速相对值)", minimum = -1, maximum = 3, value = 0, step = 0.1) + tts_volume = gr.Slider(label = "TTS语音音量(相对值)", minimum = -1, maximum = 1.5, value = 0, step = 0.1) + vc_submit2 = gr.Button("文字转换", variant="primary") + with gr.Row(): + with gr.Column(): + vc_output1 = gr.Textbox(label="Output Message") + with gr.Column(): + vc_output2 = gr.Audio(label="Output Audio", interactive=False) + + with gr.TabItem("小工具/实验室特性"): + gr.Markdown(value=""" + So-vits-svc 4.0 小工具/实验室特性 + """) + with gr.Tabs(): + with gr.TabItem("静态声线融合"): + gr.Markdown(value=""" + 介绍:该功能可以将多个声音模型合成为一个声音模型(多个模型参数的凸组合或线性组合),从而制造出现实中不存在的声线 + 注意: + 1.该功能仅支持单说话人的模型 + 2.如果强行使用多说话人模型,需要保证多个模型的说话人数量相同,这样可以混合同一个SpaekerID下的声音 + 3.保证所有待混合模型的config.json中的model字段是相同的 + 4.输出的混合模型可以使用待合成模型的任意一个config.json,但聚类模型将不能使用 + 5.批量上传模型的时候最好把模型放到一个文件夹选中后一起上传 + 6.混合比例调整建议大小在0-100之间,也可以调为其他数字,但在线性组合模式下会出现未知的效果 + 7.混合完毕后,文件将会保存在项目根目录中,文件名为output.pth + 8.凸组合模式会将混合比例执行Softmax使混合比例相加为1,而线性组合模式不会 + + """) + mix_model_path = gr.Files(label="选择需要混合模型文件") + mix_model_upload_button = gr.UploadButton("选择/追加需要混合模型文件", file_count="multiple") + mix_model_output1 = gr.Textbox( + label="混合比例调整,单位/%", + interactive = True + ) + mix_mode = gr.Radio(choices=["凸组合", "线性组合"], label="融合模式",value="凸组合",interactive = True) + mix_submit = gr.Button("声线融合启动", variant="primary") + mix_model_output2 = gr.Textbox( + label="Output Message" + ) + mix_model_path.change(updata_mix_info,[mix_model_path],[mix_model_output1]) + mix_model_upload_button.upload(upload_mix_append_file, [mix_model_upload_button,mix_model_path], [mix_model_path,mix_model_output1]) + mix_submit.click(mix_submit_click, [mix_model_output1,mix_mode], [mix_model_output2]) + + with gr.TabItem("模型压缩工具"): + gr.Markdown(value=""" + 该工具可以实现对模型的体积压缩,在**不影响模型推理功能**的情况下,将原本约600M的So-VITS模型压缩至约200M, 大大减少了硬盘的压力。 + **注意:压缩后的模型将无法继续训练,请在确认封炉后再压缩。** + """) + model_to_compress = gr.File(label="模型上传") + compress_model_btn = gr.Button("压缩模型", variant="primary") + compress_model_output = gr.Textbox(label="输出信息", value="") + + compress_model_btn.click(model_compression, [model_to_compress], [compress_model_output]) + + + with gr.Tabs(): + with gr.Row(variant="panel"): + with gr.Column(): + gr.Markdown(value=""" + WebUI设置 + """) + debug_button = gr.Checkbox(label="Debug模式,如果向社区反馈BUG需要打开,打开后控制台可以显示具体错误提示", value=debug) + # refresh local model list + local_model_refresh_btn.click(local_model_refresh_fn, outputs=local_model_selection) + # set local enabled/disabled on tab switch + local_model_tab_upload.select(lambda: False, outputs=local_model_enabled) + local_model_tab_local.select(lambda: True, outputs=local_model_enabled) + + vc_submit.click(vc_fn, [sid, vc_input3, output_format, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment], [vc_output1, vc_output2]) + vc_submit2.click(vc_fn2, [text2tts, tts_lang, tts_gender, tts_rate, tts_volume, sid, output_format, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,f0_predictor,enhancer_adaptive_key,cr_threshold,k_step,use_spk_mix,second_encoding,loudness_envelope_adjustment], [vc_output1, vc_output2]) + + debug_button.change(debug_change,[],[]) + model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix,local_model_enabled,local_model_selection],[sid,sid_output]) + model_unload_button.click(modelUnload,[],[sid,sid_output]) + os.system("start http://127.0.0.1:7860") + app.launch() + + + diff --git a/cluster/__init__.py b/cluster/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae00ea692643e69c8c8c60e392f456ab0adcdd93 --- /dev/null +++ b/cluster/__init__.py @@ -0,0 +1,29 @@ +import torch +from sklearn.cluster import KMeans + + +def get_cluster_model(ckpt_path): + checkpoint = torch.load(ckpt_path) + kmeans_dict = {} + for spk, ckpt in checkpoint.items(): + km = KMeans(ckpt["n_features_in_"]) + km.__dict__["n_features_in_"] = ckpt["n_features_in_"] + km.__dict__["_n_threads"] = ckpt["_n_threads"] + km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"] + kmeans_dict[spk] = km + return kmeans_dict + +def get_cluster_result(model, x, speaker): + """ + x: np.array [t, 256] + return cluster class result + """ + return model[speaker].predict(x) + +def get_cluster_center_result(model, x,speaker): + """x: np.array [t, 256]""" + predict = model[speaker].predict(x) + return model[speaker].cluster_centers_[predict] + +def get_center(model, x,speaker): + return model[speaker].cluster_centers_[x] diff --git a/cluster/kmeans.py b/cluster/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0ac69d9607213f7d592c70962c9de0db27fe6b --- /dev/null +++ b/cluster/kmeans.py @@ -0,0 +1,204 @@ +from time import time + +import numpy as np +import pynvml +import torch +from torch.nn.functional import normalize + + +# device=torch.device("cuda:0") +def _kpp(data: torch.Tensor, k: int, sample_size: int = -1): + """ Picks k points in the data based on the kmeans++ method. + + Parameters + ---------- + data : torch.Tensor + Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D + data, rank 2 multidimensional data, in which case one + row is one observation. + k : int + Number of samples to generate. + sample_size : int + sample data to avoid memory overflow during calculation + + Returns + ------- + init : ndarray + A 'k' by 'N' containing the initial centroids. + + References + ---------- + .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of + careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium + on Discrete Algorithms, 2007. + .. [2] scipy/cluster/vq.py: _kpp + """ + batch_size=data.shape[0] + if batch_size>sample_size: + data = data[torch.randint(0, batch_size,[sample_size], device=data.device)] + dims = data.shape[1] if len(data.shape) > 1 else 1 + init = torch.zeros((k, dims)).to(data.device) + r = torch.distributions.uniform.Uniform(0, 1) + for i in range(k): + if i == 0: + init[i, :] = data[torch.randint(data.shape[0], [1])] + else: + D2 = torch.cdist(init[:i, :][None, :], data[None, :], p=2)[0].amin(dim=0) + probs = D2 / torch.sum(D2) + cumprobs = torch.cumsum(probs, dim=0) + init[i, :] = data[torch.searchsorted(cumprobs, r.sample([1]).to(data.device))] + return init +class KMeansGPU: + ''' + Kmeans clustering algorithm implemented with PyTorch + + Parameters: + n_clusters: int, + Number of clusters + + max_iter: int, default: 100 + Maximum number of iterations + + tol: float, default: 0.0001 + Tolerance + + verbose: int, default: 0 + Verbosity + + mode: {'euclidean', 'cosine'}, default: 'euclidean' + Type of distance measure + + init_method: {'random', 'point', '++'} + Type of initialization + + minibatch: {None, int}, default: None + Batch size of MinibatchKmeans algorithm + if None perform full KMeans algorithm + + Attributes: + centroids: torch.Tensor, shape: [n_clusters, n_features] + cluster centroids + ''' + def __init__(self, n_clusters, max_iter=200, tol=1e-4, verbose=0, mode="euclidean",device=torch.device("cuda:0")): + self.n_clusters = n_clusters + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.mode = mode + self.device=device + pynvml.nvmlInit() + gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) + info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) + self.minibatch=int(33e6/self.n_clusters*info.free/ 1024 / 1024 / 1024) + print("free_mem/GB:",info.free/ 1024 / 1024 / 1024,"minibatch:",self.minibatch) + + @staticmethod + def cos_sim(a, b): + """ + Compute cosine similarity of 2 sets of vectors + + Parameters: + a: torch.Tensor, shape: [m, n_features] + + b: torch.Tensor, shape: [n, n_features] + """ + return normalize(a, dim=-1) @ normalize(b, dim=-1).transpose(-2, -1) + + @staticmethod + def euc_sim(a, b): + """ + Compute euclidean similarity of 2 sets of vectors + Parameters: + a: torch.Tensor, shape: [m, n_features] + b: torch.Tensor, shape: [n, n_features] + """ + return 2 * a @ b.transpose(-2, -1) -(a**2).sum(dim=1)[..., :, None] - (b**2).sum(dim=1)[..., None, :] + + def max_sim(self, a, b): + """ + Compute maximum similarity (or minimum distance) of each vector + in a with all of the vectors in b + Parameters: + a: torch.Tensor, shape: [m, n_features] + b: torch.Tensor, shape: [n, n_features] + """ + if self.mode == 'cosine': + sim_func = self.cos_sim + elif self.mode == 'euclidean': + sim_func = self.euc_sim + sim = sim_func(a, b) + max_sim_v, max_sim_i = sim.max(dim=-1) + return max_sim_v, max_sim_i + + def fit_predict(self, X): + """ + Combination of fit() and predict() methods. + This is faster than calling fit() and predict() seperately. + Parameters: + X: torch.Tensor, shape: [n_samples, n_features] + centroids: {torch.Tensor, None}, default: None + if given, centroids will be initialized with given tensor + if None, centroids will be randomly chosen from X + Return: + labels: torch.Tensor, shape: [n_samples] + + mini_=33kk/k*remain + mini=min(mini_,fea_shape) + offset=log2(k/1000)*1.5 + kpp_all=min(mini_*10/offset,fea_shape) + kpp_sample=min(mini_/12/offset,fea_shape) + """ + assert isinstance(X, torch.Tensor), "input must be torch.Tensor" + assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point" + assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features] " + # print("verbose:%s"%self.verbose) + + offset = np.power(1.5,np.log(self.n_clusters / 1000))/np.log(2) + with torch.no_grad(): + batch_size= X.shape[0] + # print(self.minibatch, int(self.minibatch * 10 / offset), batch_size) + start_time = time() + if (self.minibatch*10//offset< batch_size): + x = X[torch.randint(0, batch_size,[int(self.minibatch*10/offset)])].to(self.device) + else: + x = X.to(self.device) + # print(x.device) + self.centroids = _kpp(x, self.n_clusters, min(int(self.minibatch/12/offset),batch_size)) + del x + torch.cuda.empty_cache() + # self.centroids = self.centroids.to(self.device) + num_points_in_clusters = torch.ones(self.n_clusters, device=self.device, dtype=X.dtype)#全1 + closest = None#[3098036]#int64 + if(self.minibatch>=batch_size//2 and self.minibatch=batch_size): + X=X.to(self.device) + for i in range(self.max_iter): + iter_time = time() + if self.minibatch= 2: + print('iter:', i, 'error:', error.item(), 'time spent:', round(time()-iter_time, 4)) + if error <= self.tol: + break + + if self.verbose >= 1: + print(f'used {i+1} iterations ({round(time()-start_time, 4)}s) to cluster {batch_size} items into {self.n_clusters} clusters') + return closest diff --git a/cluster/train_cluster.py b/cluster/train_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..135f179a389804afc0266873ae31a1cf107ebcf8 --- /dev/null +++ b/cluster/train_cluster.py @@ -0,0 +1,85 @@ +import argparse +import logging +import os +import time +from pathlib import Path + +import numpy as np +import torch +import tqdm +from kmeans import KMeansGPU +from sklearn.cluster import KMeans, MiniBatchKMeans + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):#gpu_minibatch真拉,虽然库支持但是也不考虑 + if str(in_dir).endswith(".ipynb_checkpoints"): + logger.info(f"Ignore {in_dir}") + + logger.info(f"Loading features from {in_dir}") + features = [] + nums = 0 + for path in tqdm.tqdm(in_dir.glob("*.soft.pt")): + # for name in os.listdir(in_dir): + # path="%s/%s"%(in_dir,name) + features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T) + # print(features[-1].shape) + features = np.concatenate(features, axis=0) + print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype) + features = features.astype(np.float32) + logger.info(f"Clustering features of shape: {features.shape}") + t = time.time() + if(use_gpu is False): + if use_minibatch: + kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features) + else: + kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features) + else: + kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)# + features=torch.from_numpy(features)#.to(device) + kmeans.fit_predict(features)# + + print(time.time()-t, "s") + + x = { + "n_features_in_": kmeans.n_features_in_ if use_gpu is False else features.shape[1], + "_n_threads": kmeans._n_threads if use_gpu is False else 4, + "cluster_centers_": kmeans.cluster_centers_ if use_gpu is False else kmeans.centroids.cpu().numpy(), + } + print("end") + + return x + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=Path, default="./dataset/44k", + help='path of training data directory') + parser.add_argument('--output', type=Path, default="logs/44k", + help='path of model output directory') + parser.add_argument('--gpu',action='store_true', default=False , + help='to use GPU') + + + args = parser.parse_args() + + checkpoint_dir = args.output + dataset = args.dataset + use_gpu = args.gpu + n_clusters = 10000 + + ckpt = {} + for spk in os.listdir(dataset): + if os.path.isdir(dataset/spk): + print(f"train kmeans for {spk}...") + in_dir = dataset/spk + x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=use_gpu) + ckpt[spk] = x + + checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt" + checkpoint_path.parent.mkdir(exist_ok=True, parents=True) + torch.save( + ckpt, + checkpoint_path, + ) + diff --git a/compress_model.py b/compress_model.py new file mode 100644 index 0000000000000000000000000000000000000000..46f188a76c03b979a1bbe82dfdd2c19880408f84 --- /dev/null +++ b/compress_model.py @@ -0,0 +1,72 @@ +from collections import OrderedDict + +import torch + +import utils +from models import SynthesizerTrn + + +def copyStateDict(state_dict): + if list(state_dict.keys())[0].startswith('module'): + start_idx = 1 + else: + start_idx = 0 + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = ','.join(k.split('.')[start_idx:]) + new_state_dict[name] = v + return new_state_dict + + +def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str): + hps = utils.get_hparams_from_file(config) + + net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model) + + optim_g = torch.optim.AdamW(net_g.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps) + + state_dict_g = torch.load(input_model, map_location="cpu") + new_dict_g = copyStateDict(state_dict_g) + keys = [] + for k, v in new_dict_g['model'].items(): + if "enc_q" in k: continue # noqa: E701 + keys.append(k) + + new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys} + + torch.save( + { + 'model': new_dict_g, + 'iteration': 0, + 'optimizer': optim_g.state_dict(), + 'learning_rate': 0.0001 + }, output_model) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("-c", + "--config", + type=str, + default='configs/config.json') + parser.add_argument("-i", "--input", type=str) + parser.add_argument("-o", "--output", type=str, default=None) + parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16') + + args = parser.parse_args() + + output = args.output + + if output is None: + import os.path + filename, ext = os.path.splitext(args.input) + half = "_half" if args.half else "" + output = filename + "_release" + half + ext + + removeOptimizer(args.config, args.input, args.half, output) \ No newline at end of file diff --git a/configs/diffusion.yaml b/configs/diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs_template/config_template.json b/configs_template/config_template.json new file mode 100644 index 0000000000000000000000000000000000000000..4b1b32321cb92769f27ec3dc78f50e0d9d600b06 --- /dev/null +++ b/configs_template/config_template.json @@ -0,0 +1,79 @@ +{ + "train": { + "log_interval": 200, + "eval_interval": 800, + "seed": 1234, + "epochs": 10000, + "learning_rate": 0.0001, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 6, + "fp16_run": false, + "half_type": "fp16", + "lr_decay": 0.999875, + "segment_size": 10240, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "use_sr": true, + "max_speclen": 512, + "port": "8001", + "keep_ckpts": 3, + "all_in_mem": false, + "vol_aug":false + }, + "data": { + "training_files": "filelists/train.txt", + "validation_files": "filelists/val.txt", + "max_wav_value": 32768.0, + "sampling_rate": 44100, + "filter_length": 2048, + "hop_length": 512, + "win_length": 2048, + "n_mel_channels": 80, + "mel_fmin": 0.0, + "mel_fmax": 22050, + "unit_interpolate_mode":"nearest" + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + "upsample_rates": [ 8, 8, 2, 2, 2], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [16,16, 4, 4, 4], + "n_layers_q": 3, + "n_layers_trans_flow": 3, + "n_flow_layer": 4, + "use_spectral_norm": false, + "gin_channels": 768, + "ssl_dim": 768, + "n_speakers": 200, + "vocoder_name":"nsf-hifigan", + "speech_encoder":"vec768l12", + "speaker_embedding":false, + "vol_embedding":false, + "use_depthwise_conv":false, + "flow_share_parameter": false, + "use_automatic_f0_prediction": true, + "use_transformer_flow": false + }, + "spk": { + "nyaru": 0, + "huiyu": 1, + "nen": 2, + "paimon": 3, + "yunhao": 4 + } +} \ No newline at end of file diff --git a/configs_template/config_tiny_template.json b/configs_template/config_tiny_template.json new file mode 100644 index 0000000000000000000000000000000000000000..d0a4381e72f7f4001675b1a907d45eff98e3ec12 --- /dev/null +++ b/configs_template/config_tiny_template.json @@ -0,0 +1,79 @@ +{ + "train": { + "log_interval": 200, + "eval_interval": 800, + "seed": 1234, + "epochs": 10000, + "learning_rate": 0.0001, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 6, + "fp16_run": false, + "half_type": "fp16", + "lr_decay": 0.999875, + "segment_size": 10240, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "use_sr": true, + "max_speclen": 512, + "port": "8001", + "keep_ckpts": 3, + "all_in_mem": false, + "vol_aug":false + }, + "data": { + "training_files": "filelists/train.txt", + "validation_files": "filelists/val.txt", + "max_wav_value": 32768.0, + "sampling_rate": 44100, + "filter_length": 2048, + "hop_length": 512, + "win_length": 2048, + "n_mel_channels": 80, + "mel_fmin": 0.0, + "mel_fmax": 22050, + "unit_interpolate_mode":"nearest" + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 512, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + "upsample_rates": [ 8, 8, 2, 2, 2], + "upsample_initial_channel": 400, + "upsample_kernel_sizes": [16,16, 4, 4, 4], + "n_layers_q": 3, + "n_layers_trans_flow": 3, + "n_flow_layer": 4, + "use_spectral_norm": false, + "gin_channels": 768, + "ssl_dim": 768, + "n_speakers": 200, + "vocoder_name":"nsf-hifigan", + "speech_encoder":"vec768l12", + "speaker_embedding":false, + "vol_embedding":false, + "use_depthwise_conv":true, + "flow_share_parameter": true, + "use_automatic_f0_prediction": true, + "use_transformer_flow": false + }, + "spk": { + "nyaru": 0, + "huiyu": 1, + "nen": 2, + "paimon": 3, + "yunhao": 4 + } +} \ No newline at end of file diff --git a/configs_template/diffusion_template.yaml b/configs_template/diffusion_template.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b006027d10491a899e893a44e649927f72e30755 --- /dev/null +++ b/configs_template/diffusion_template.yaml @@ -0,0 +1,51 @@ +data: + sampling_rate: 44100 + block_size: 512 # Equal to hop_length + duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip + encoder: 'vec768l12' # 'hubertsoft', 'vec256l9', 'vec768l12' + cnhubertsoft_gate: 10 + encoder_sample_rate: 16000 + encoder_hop_size: 320 + encoder_out_channels: 768 # 256 if using 'hubertsoft' + training_files: "filelists/train.txt" + validation_files: "filelists/val.txt" + extensions: # List of extension included in the data collection + - wav + unit_interpolate_mode: "nearest" +model: + type: 'Diffusion' + n_layers: 20 + n_chans: 512 + n_hidden: 256 + use_pitch_aug: true + timesteps : 1000 + k_step_max: 0 # must <= timesteps, If it is 0, train all + n_spk: 1 # max number of different speakers +device: cuda +vocoder: + type: 'nsf-hifigan' + ckpt: 'pretrain/nsf_hifigan/model' +infer: + speedup: 10 + method: 'dpm-solver++' # 'pndm' or 'dpm-solver' or 'ddim' or 'unipc' or 'dpm-solver++' +env: + expdir: logs/44k/diffusion + gpu_id: 0 +train: + num_workers: 4 # If your cpu and gpu are both very strong, set to 0 may be faster! + amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu) + batch_size: 48 + cache_all_data: true # Save Internal-Memory or Graphics-Memory if it is false, but may be slow + cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu + cache_fp16: true + epochs: 100000 + interval_log: 10 + interval_val: 2000 + interval_force_save: 5000 + lr: 0.0001 + decay_step: 100000 + gamma: 0.5 + weight_decay: 0 + save_opt: false +spk: + 'nyaru': 0 \ No newline at end of file diff --git a/data_utils.py b/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..025fc3b61c784f4d61e27042c1f668adf73f5d5a --- /dev/null +++ b/data_utils.py @@ -0,0 +1,185 @@ +import os +import random + +import numpy as np +import torch +import torch.utils.data + +import utils +from modules.mel_processing import spectrogram_torch +from utils import load_filepaths_and_text, load_wav_to_torch + +# import h5py + + +"""Multi speaker version""" + + +class TextAudioSpeakerLoader(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, audiopaths, hparams, all_in_mem: bool = False, vol_aug: bool = True): + self.audiopaths = load_filepaths_and_text(audiopaths) + self.hparams = hparams + self.max_wav_value = hparams.data.max_wav_value + self.sampling_rate = hparams.data.sampling_rate + self.filter_length = hparams.data.filter_length + self.hop_length = hparams.data.hop_length + self.win_length = hparams.data.win_length + self.unit_interpolate_mode = hparams.data.unit_interpolate_mode + self.sampling_rate = hparams.data.sampling_rate + self.use_sr = hparams.train.use_sr + self.spec_len = hparams.train.max_speclen + self.spk_map = hparams.spk + self.vol_emb = hparams.model.vol_embedding + self.vol_aug = hparams.train.vol_aug and vol_aug + random.seed(1234) + random.shuffle(self.audiopaths) + + self.all_in_mem = all_in_mem + if self.all_in_mem: + self.cache = [self.get_audio(p[0]) for p in self.audiopaths] + + def get_audio(self, filename): + filename = filename.replace("\\", "/") + audio, sampling_rate = load_wav_to_torch(filename) + if sampling_rate != self.sampling_rate: + raise ValueError( + "Sample Rate not match. Expect {} but got {} from {}".format( + self.sampling_rate, sampling_rate, filename)) + audio_norm = audio / self.max_wav_value + audio_norm = audio_norm.unsqueeze(0) + spec_filename = filename.replace(".wav", ".spec.pt") + + # Ideally, all data generated after Mar 25 should have .spec.pt + if os.path.exists(spec_filename): + spec = torch.load(spec_filename) + else: + spec = spectrogram_torch(audio_norm, self.filter_length, + self.sampling_rate, self.hop_length, self.win_length, + center=False) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename) + + spk = filename.split("/")[-2] + spk = torch.LongTensor([self.spk_map[spk]]) + + f0, uv = np.load(filename + ".f0.npy",allow_pickle=True) + + f0 = torch.FloatTensor(np.array(f0,dtype=float)) + uv = torch.FloatTensor(np.array(uv,dtype=float)) + + c = torch.load(filename+ ".soft.pt") + c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0], mode=self.unit_interpolate_mode) + if self.vol_emb: + volume_path = filename + ".vol.npy" + volume = np.load(volume_path) + volume = torch.from_numpy(volume).float() + else: + volume = None + + lmin = min(c.size(-1), spec.size(-1)) + assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename) + assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length + spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin] + audio_norm = audio_norm[:, :lmin * self.hop_length] + if volume is not None: + volume = volume[:lmin] + return c, f0, spec, audio_norm, spk, uv, volume + + def random_slice(self, c, f0, spec, audio_norm, spk, uv, volume): + # if spec.shape[1] < 30: + # print("skip too short audio:", filename) + # return None + + if random.choice([True, False]) and self.vol_aug and volume is not None: + max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5 + max_shift = min(1, np.log10(1/max_amp)) + log10_vol_shift = random.uniform(-1, max_shift) + audio_norm = audio_norm * (10 ** log10_vol_shift) + volume = volume * (10 ** log10_vol_shift) + spec = spectrogram_torch(audio_norm, + self.hparams.data.filter_length, + self.hparams.data.sampling_rate, + self.hparams.data.hop_length, + self.hparams.data.win_length, + center=False)[0] + + if spec.shape[1] > 800: + start = random.randint(0, spec.shape[1]-800) + end = start + 790 + spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end] + audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length] + if volume is not None: + volume = volume[start:end] + return c, f0, spec, audio_norm, spk, uv,volume + + def __getitem__(self, index): + if self.all_in_mem: + return self.random_slice(*self.cache[index]) + else: + return self.random_slice(*self.get_audio(self.audiopaths[index][0])) + + def __len__(self): + return len(self.audiopaths) + + +class TextAudioCollate: + + def __call__(self, batch): + batch = [b for b in batch if b is not None] + + input_lengths, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[0].shape[1] for x in batch]), + dim=0, descending=True) + + max_c_len = max([x[0].size(1) for x in batch]) + max_wav_len = max([x[3].size(1) for x in batch]) + + lengths = torch.LongTensor(len(batch)) + + c_padded = torch.FloatTensor(len(batch), batch[0][0].shape[0], max_c_len) + f0_padded = torch.FloatTensor(len(batch), max_c_len) + spec_padded = torch.FloatTensor(len(batch), batch[0][2].shape[0], max_c_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + spkids = torch.LongTensor(len(batch), 1) + uv_padded = torch.FloatTensor(len(batch), max_c_len) + volume_padded = torch.FloatTensor(len(batch), max_c_len) + + c_padded.zero_() + spec_padded.zero_() + f0_padded.zero_() + wav_padded.zero_() + uv_padded.zero_() + volume_padded.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + c = row[0] + c_padded[i, :, :c.size(1)] = c + lengths[i] = c.size(1) + + f0 = row[1] + f0_padded[i, :f0.size(0)] = f0 + + spec = row[2] + spec_padded[i, :, :spec.size(1)] = spec + + wav = row[3] + wav_padded[i, :, :wav.size(1)] = wav + + spkids[i, 0] = row[4] + + uv = row[5] + uv_padded[i, :uv.size(0)] = uv + volume = row[6] + if volume is not None: + volume_padded[i, :volume.size(0)] = volume + else : + volume_padded = None + return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded, volume_padded diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusion/data_loaders.py b/diffusion/data_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..9f00b9afd01565e568e5315dfa49b82dd2ec68ed --- /dev/null +++ b/diffusion/data_loaders.py @@ -0,0 +1,288 @@ +import os +import random + +import librosa +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + +from utils import repeat_expand_2d + + +def traverse_dir( + root_dir, + extensions, + amount=None, + str_include=None, + str_exclude=None, + is_pure=False, + is_sort=False, + is_ext=True): + + file_list = [] + cnt = 0 + for root, _, files in os.walk(root_dir): + for file in files: + if any([file.endswith(f".{ext}") for ext in extensions]): + # path + mix_path = os.path.join(root, file) + pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path + + # amount + if (amount is not None) and (cnt == amount): + if is_sort: + file_list.sort() + return file_list + + # check string + if (str_include is not None) and (str_include not in pure_path): + continue + if (str_exclude is not None) and (str_exclude in pure_path): + continue + + if not is_ext: + ext = pure_path.split('.')[-1] + pure_path = pure_path[:-(len(ext)+1)] + file_list.append(pure_path) + cnt += 1 + if is_sort: + file_list.sort() + return file_list + + +def get_data_loaders(args, whole_audio=False): + data_train = AudioDataset( + filelists = args.data.training_files, + waveform_sec=args.data.duration, + hop_size=args.data.block_size, + sample_rate=args.data.sampling_rate, + load_all_data=args.train.cache_all_data, + whole_audio=whole_audio, + extensions=args.data.extensions, + n_spk=args.model.n_spk, + spk=args.spk, + device=args.train.cache_device, + fp16=args.train.cache_fp16, + unit_interpolate_mode = args.data.unit_interpolate_mode, + use_aug=True) + loader_train = torch.utils.data.DataLoader( + data_train , + batch_size=args.train.batch_size if not whole_audio else 1, + shuffle=True, + num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0, + persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False, + pin_memory=True if args.train.cache_device=='cpu' else False + ) + data_valid = AudioDataset( + filelists = args.data.validation_files, + waveform_sec=args.data.duration, + hop_size=args.data.block_size, + sample_rate=args.data.sampling_rate, + load_all_data=args.train.cache_all_data, + whole_audio=True, + spk=args.spk, + extensions=args.data.extensions, + unit_interpolate_mode = args.data.unit_interpolate_mode, + n_spk=args.model.n_spk) + loader_valid = torch.utils.data.DataLoader( + data_valid, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=True + ) + return loader_train, loader_valid + + +class AudioDataset(Dataset): + def __init__( + self, + filelists, + waveform_sec, + hop_size, + sample_rate, + spk, + load_all_data=True, + whole_audio=False, + extensions=['wav'], + n_spk=1, + device='cpu', + fp16=False, + use_aug=False, + unit_interpolate_mode = 'left' + ): + super().__init__() + + self.waveform_sec = waveform_sec + self.sample_rate = sample_rate + self.hop_size = hop_size + self.filelists = filelists + self.whole_audio = whole_audio + self.use_aug = use_aug + self.data_buffer={} + self.pitch_aug_dict = {} + self.unit_interpolate_mode = unit_interpolate_mode + # np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item() + if load_all_data: + print('Load all the data filelists:', filelists) + else: + print('Load the f0, volume data filelists:', filelists) + with open(filelists,"r") as f: + self.paths = f.read().splitlines() + for name_ext in tqdm(self.paths, total=len(self.paths)): + path_audio = name_ext + duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate) + + path_f0 = name_ext + ".f0.npy" + f0,_ = np.load(path_f0,allow_pickle=True) + f0 = torch.from_numpy(np.array(f0,dtype=float)).float().unsqueeze(-1).to(device) + + path_volume = name_ext + ".vol.npy" + volume = np.load(path_volume) + volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) + + path_augvol = name_ext + ".aug_vol.npy" + aug_vol = np.load(path_augvol) + aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) + + if n_spk is not None and n_spk > 1: + spk_name = name_ext.split("/")[-2] + spk_id = spk[spk_name] if spk_name in spk else 0 + if spk_id < 0 or spk_id >= n_spk: + raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 0 to n_spk-1 ') + else: + spk_id = 0 + spk_id = torch.LongTensor(np.array([spk_id])).to(device) + + if load_all_data: + ''' + audio, sr = librosa.load(path_audio, sr=self.sample_rate) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + audio = torch.from_numpy(audio).to(device) + ''' + path_mel = name_ext + ".mel.npy" + mel = np.load(path_mel) + mel = torch.from_numpy(mel).to(device) + + path_augmel = name_ext + ".aug_mel.npy" + aug_mel,keyshift = np.load(path_augmel, allow_pickle=True) + aug_mel = np.array(aug_mel,dtype=float) + aug_mel = torch.from_numpy(aug_mel).to(device) + self.pitch_aug_dict[name_ext] = keyshift + + path_units = name_ext + ".soft.pt" + units = torch.load(path_units).to(device) + units = units[0] + units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1) + + if fp16: + mel = mel.half() + aug_mel = aug_mel.half() + units = units.half() + + self.data_buffer[name_ext] = { + 'duration': duration, + 'mel': mel, + 'aug_mel': aug_mel, + 'units': units, + 'f0': f0, + 'volume': volume, + 'aug_vol': aug_vol, + 'spk_id': spk_id + } + else: + path_augmel = name_ext + ".aug_mel.npy" + aug_mel,keyshift = np.load(path_augmel, allow_pickle=True) + self.pitch_aug_dict[name_ext] = keyshift + self.data_buffer[name_ext] = { + 'duration': duration, + 'f0': f0, + 'volume': volume, + 'aug_vol': aug_vol, + 'spk_id': spk_id + } + + + def __getitem__(self, file_idx): + name_ext = self.paths[file_idx] + data_buffer = self.data_buffer[name_ext] + # check duration. if too short, then skip + if data_buffer['duration'] < (self.waveform_sec + 0.1): + return self.__getitem__( (file_idx + 1) % len(self.paths)) + + # get item + return self.get_data(name_ext, data_buffer) + + def get_data(self, name_ext, data_buffer): + name = os.path.splitext(name_ext)[0] + frame_resolution = self.hop_size / self.sample_rate + duration = data_buffer['duration'] + waveform_sec = duration if self.whole_audio else self.waveform_sec + + # load audio + idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1) + start_frame = int(idx_from / frame_resolution) + units_frame_len = int(waveform_sec / frame_resolution) + aug_flag = random.choice([True, False]) and self.use_aug + ''' + audio = data_buffer.get('audio') + if audio is None: + path_audio = os.path.join(self.path_root, 'audio', name) + '.wav' + audio, sr = librosa.load( + path_audio, + sr = self.sample_rate, + offset = start_frame * frame_resolution, + duration = waveform_sec) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + # clip audio into N seconds + audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size] + audio = torch.from_numpy(audio).float() + else: + audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size] + ''' + # load mel + mel_key = 'aug_mel' if aug_flag else 'mel' + mel = data_buffer.get(mel_key) + if mel is None: + mel = name_ext + ".mel.npy" + mel = np.load(mel) + mel = mel[start_frame : start_frame + units_frame_len] + mel = torch.from_numpy(mel).float() + else: + mel = mel[start_frame : start_frame + units_frame_len] + + # load f0 + f0 = data_buffer.get('f0') + aug_shift = 0 + if aug_flag: + aug_shift = self.pitch_aug_dict[name_ext] + f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len] + + # load units + units = data_buffer.get('units') + if units is None: + path_units = name_ext + ".soft.pt" + units = torch.load(path_units) + units = units[0] + units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1) + + units = units[start_frame : start_frame + units_frame_len] + + # load volume + vol_key = 'aug_vol' if aug_flag else 'volume' + volume = data_buffer.get(vol_key) + volume_frames = volume[start_frame : start_frame + units_frame_len] + + # load spk_id + spk_id = data_buffer.get('spk_id') + + # load shift + aug_shift = torch.from_numpy(np.array([[aug_shift]])).float() + + return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext) + + def __len__(self): + return len(self.paths) \ No newline at end of file diff --git a/diffusion/diffusion.py b/diffusion/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..646234b5f8d3161ba126055c628a89162f16b0cf --- /dev/null +++ b/diffusion/diffusion.py @@ -0,0 +1,396 @@ +from collections import deque +from functools import partial +from inspect import isfunction + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + def noise(): + return torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=0.02): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +class GaussianDiffusion(nn.Module): + def __init__(self, + denoise_fn, + out_dims=128, + timesteps=1000, + k_step=1000, + max_beta=0.02, + spec_min=-12, + spec_max=2): + + super().__init__() + self.denoise_fn = denoise_fn + self.out_dims = out_dims + betas = beta_schedule['linear'](timesteps, max_beta=max_beta) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.k_step = k_step if k_step>0 and k_step 1: + if method == 'dpm-solver' or method == 'dpm-solver++': + from .dpm_solver_pytorch import ( + DPM_Solver, + NoiseScheduleVP, + model_wrapper, + ) + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define dpm-solver and sample by singlestep DPM-Solver. + # (We recommend singlestep DPM-Solver for unconditional sampling) + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + if method == 'dpm-solver': + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + elif method == 'dpm-solver++': + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = dpm_solver.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() + elif method == 'pndm': + self.noise_list = deque(maxlen=4) + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + elif method == 'ddim': + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_ddim( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_ddim( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + elif method == 'unipc': + from .uni_pc import NoiseScheduleVP, UniPC, model_wrapper + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define uni_pc and sample by multistep UniPC. + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + uni_pc = UniPC(model_fn, noise_schedule, variant='bh2') + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = uni_pc.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() + else: + raise NotImplementedError(method) + else: + if use_tqdm: + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + else: + for i in reversed(range(0, t)): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x.squeeze(1).transpose(1, 2) # [B, T, M] + return self.denorm_spec(x) + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min diff --git a/diffusion/diffusion_onnx.py b/diffusion/diffusion_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..f01e463515bd6dccd02fe49b1db1f5af64fc746b --- /dev/null +++ b/diffusion/diffusion_onnx.py @@ -0,0 +1,614 @@ +import math +from collections import deque +from functools import partial +from inspect import isfunction + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Conv1d, Mish +from tqdm import tqdm + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def extract(a, t): + return a[t].reshape((1, 1, 1, 1)) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + def noise(): + return torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=0.02): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +def extract_1(a, t): + return a[t].reshape((1, 1, 1, 1)) + + +def predict_stage0(noise_pred, noise_pred_prev): + return (noise_pred + noise_pred_prev) / 2 + + +def predict_stage1(noise_pred, noise_list): + return (noise_pred * 3 + - noise_list[-1]) / 2 + + +def predict_stage2(noise_pred, noise_list): + return (noise_pred * 23 + - noise_list[-1] * 16 + + noise_list[-2] * 5) / 12 + + +def predict_stage3(noise_pred, noise_list): + return (noise_pred * 55 + - noise_list[-1] * 59 + + noise_list[-2] * 37 + - noise_list[-3] * 9) / 24 + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.half_dim = dim // 2 + self.emb = 9.21034037 / (self.half_dim - 1) + self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0) + self.emb = self.emb.cpu() + + def forward(self, x): + emb = self.emb * x + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.residual_channels = residual_channels + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + y = self.dilated_conv(y) + conditioner + + gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + + y = torch.sigmoid(gate) * torch.tanh(filter_1) + y = self.output_projection(y) + + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + + return (x + residual) / 1.41421356, skip + + +class DiffNet(nn.Module): + def __init__(self, in_dims, n_layers, n_chans, n_hidden): + super().__init__() + self.encoder_hidden = n_hidden + self.residual_layers = n_layers + self.residual_channels = n_chans + self.input_projection = Conv1d(in_dims, self.residual_channels, 1) + self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels) + dim = self.residual_channels + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock(self.encoder_hidden, self.residual_channels, 1) + for i in range(self.residual_layers) + ]) + self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1) + self.output_projection = Conv1d(self.residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + x = spec.squeeze(0) + x = self.input_projection(x) # x [B, residual_channel, T] + x = F.relu(x) + # skip = torch.randn_like(x) + diffusion_step = diffusion_step.float() + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + + x, skip = self.residual_layers[0](x, cond, diffusion_step) + # noinspection PyTypeChecker + for layer in self.residual_layers[1:]: + x, skip_connection = layer.forward(x, cond, diffusion_step) + skip = skip + skip_connection + x = skip / math.sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, 80, T] + return x.unsqueeze(1) + + +class AfterDiffusion(nn.Module): + def __init__(self, spec_max, spec_min, v_type='a'): + super().__init__() + self.spec_max = spec_max + self.spec_min = spec_min + self.type = v_type + + def forward(self, x): + x = x.squeeze(1).permute(0, 2, 1) + mel_out = (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + if self.type == 'nsf-hifigan-log10': + mel_out = mel_out * 0.434294 + return mel_out.transpose(2, 1) + + +class Pred(nn.Module): + def __init__(self, alphas_cumprod): + super().__init__() + self.alphas_cumprod = alphas_cumprod + + def forward(self, x_1, noise_t, t_1, t_prev): + a_t = extract(self.alphas_cumprod, t_1).cpu() + a_prev = extract(self.alphas_cumprod, t_prev).cpu() + a_t_sq, a_prev_sq = a_t.sqrt().cpu(), a_prev.sqrt().cpu() + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x_1 + x_delta.cpu() + + return x_pred + + +class GaussianDiffusion(nn.Module): + def __init__(self, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256, + timesteps=1000, + k_step=1000, + max_beta=0.02, + spec_min=-12, + spec_max=2): + super().__init__() + self.denoise_fn = DiffNet(out_dims, n_layers, n_chans, n_hidden) + self.out_dims = out_dims + self.mel_bins = out_dims + self.n_hidden = n_hidden + betas = beta_schedule['linear'](timesteps, max_beta=max_beta) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.k_step = k_step + + self.noise_list = deque(maxlen=4) + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims]) + self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims]) + self.ad = AfterDiffusion(self.spec_max, self.spec_min) + self.xp = Pred(self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): + """ + Use the PLMS method from + [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). + """ + + def get_x_pred(x, noise_t, t): + a_t = extract(self.alphas_cumprod, t) + a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t))) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / ( + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x + x_delta + + return x_pred + + noise_list = self.noise_list + noise_pred = self.denoise_fn(x, t, cond=cond) + + if len(noise_list) == 0: + x_pred = get_x_pred(x, noise_pred, t) + noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond) + noise_pred_prime = (noise_pred + noise_pred_prev) / 2 + elif len(noise_list) == 1: + noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 + elif len(noise_list) == 2: + noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 + else: + noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 + + x_prev = get_x_pred(x, noise_pred_prime, t) + noise_list.append(noise_pred) + + return x_prev + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if loss_type == 'l1': + loss = (noise - x_recon).abs().mean() + elif loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def org_forward(self, + condition, + init_noise=None, + gt_spec=None, + infer=True, + infer_speedup=100, + method='pndm', + k_step=1000, + use_tqdm=True): + """ + conditioning diffusion, use fastspeech2 encoder output as the condition + """ + cond = condition + b, device = condition.shape[0], condition.device + if not infer: + spec = self.norm_spec(gt_spec) + t = torch.randint(0, self.k_step, (b,), device=device).long() + norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + return self.p_losses(norm_spec, t, cond=cond) + else: + shape = (cond.shape[0], 1, self.out_dims, cond.shape[2]) + + if gt_spec is None: + t = self.k_step + if init_noise is None: + x = torch.randn(shape, device=device) + else: + x = init_noise + else: + t = k_step + norm_spec = self.norm_spec(gt_spec) + norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] + x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long()) + + if method is not None and infer_speedup > 1: + if method == 'dpm-solver': + from .dpm_solver_pytorch import ( + DPM_Solver, + NoiseScheduleVP, + model_wrapper, + ) + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define dpm-solver and sample by singlestep DPM-Solver. + # (We recommend singlestep DPM-Solver for unconditional sampling) + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + dpm_solver = DPM_Solver(model_fn, noise_schedule) + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = dpm_solver.sample( + x, + steps=steps, + order=3, + skip_type="time_uniform", + method="singlestep", + ) + if use_tqdm: + self.bar.close() + elif method == 'pndm': + self.noise_list = deque(maxlen=4) + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_plms( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + raise NotImplementedError(method) + else: + if use_tqdm: + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + else: + for i in reversed(range(0, t)): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x.squeeze(1).transpose(1, 2) # [B, T, M] + return self.denorm_spec(x).transpose(2, 1) + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + + def get_x_pred(self, x_1, noise_t, t_1, t_prev): + a_t = extract(self.alphas_cumprod, t_1) + a_prev = extract(self.alphas_cumprod, t_prev) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x_1 + x_delta + return x_pred + + def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True): + cond = torch.randn([1, self.n_hidden, 10]).cpu() + if init_noise is None: + x = torch.randn((1, 1, self.mel_bins, cond.shape[2]), dtype=torch.float32).cpu() + else: + x = init_noise + pndms = 100 + + org_y_x = self.org_forward(cond, init_noise=x) + + device = cond.device + n_frames = cond.shape[2] + step_range = torch.arange(0, self.k_step, pndms, dtype=torch.long, device=device).flip(0) + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + + ot = step_range[0] + ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) + if export_denoise: + torch.onnx.export( + self.denoise_fn, + (x.cpu(), ot_1.cpu(), cond.cpu()), + f"{project_name}_denoise.onnx", + input_names=["noise", "time", "condition"], + output_names=["noise_pred"], + dynamic_axes={ + "noise": [3], + "condition": [2] + }, + opset_version=16 + ) + + for t in step_range: + t_1 = torch.full((1,), t, device=device, dtype=torch.long) + noise_pred = self.denoise_fn(x, t_1, cond) + t_prev = t_1 - pndms + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + if export_pred: + torch.onnx.export( + self.xp, + (x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()), + f"{project_name}_pred.onnx", + input_names=["noise", "noise_pred", "time", "time_prev"], + output_names=["noise_pred_o"], + dynamic_axes={ + "noise": [3], + "noise_pred": [3] + }, + opset_version=16 + ) + + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) + + elif plms_noise_stage == 1: + noise_pred_prime = predict_stage1(noise_pred, noise_list) + + elif plms_noise_stage == 2: + noise_pred_prime = predict_stage2(noise_pred, noise_list) + + else: + noise_pred_prime = predict_stage3(noise_pred, noise_list) + + noise_pred = noise_pred.unsqueeze(0) + + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) + if export_after: + torch.onnx.export( + self.ad, + x.cpu(), + f"{project_name}_after.onnx", + input_names=["x"], + output_names=["mel_out"], + dynamic_axes={ + "x": [3] + }, + opset_version=16 + ) + x = self.ad(x) + + print((x == org_y_x).all()) + return x + + def forward(self, condition=None, init_noise=None, pndms=None, k_step=None): + cond = condition + x = init_noise + + device = cond.device + n_frames = cond.shape[2] + step_range = torch.arange(0, k_step.item(), pndms.item(), dtype=torch.long, device=device).flip(0) + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + + for t in step_range: + t_1 = torch.full((1,), t, device=device, dtype=torch.long) + noise_pred = self.denoise_fn(x, t_1, cond) + t_prev = t_1 - pndms + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) + + elif plms_noise_stage == 1: + noise_pred_prime = predict_stage1(noise_pred, noise_list) + + elif plms_noise_stage == 2: + noise_pred_prime = predict_stage2(noise_pred, noise_list) + + else: + noise_pred_prime = predict_stage3(noise_pred, noise_list) + + noise_pred = noise_pred.unsqueeze(0) + + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) + x = self.ad(x) + return x diff --git a/diffusion/dpm_solver_pytorch.py b/diffusion/dpm_solver_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..83ed73e22d37cb8ef224425dcfd6bb3dcba74578 --- /dev/null +++ b/diffusion/dpm_solver_pytorch.py @@ -0,0 +1,1307 @@ +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1. + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1. + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/diffusion/how to export onnx.md b/diffusion/how to export onnx.md new file mode 100644 index 0000000000000000000000000000000000000000..5aae72cadb25fea4c592f0fb908daf5b6ee1c5f4 --- /dev/null +++ b/diffusion/how to export onnx.md @@ -0,0 +1,4 @@ +- Open [onnx_export](onnx_export.py) +- project_name = "dddsp" change "project_name" to your project name +- model_path = f'{project_name}/model_500000.pt' change "model_path" to your model path +- Run \ No newline at end of file diff --git a/diffusion/infer_gt_mel.py b/diffusion/infer_gt_mel.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdf1fed45a7df3eef954c24fd0d4b0dcb0fc1b8 --- /dev/null +++ b/diffusion/infer_gt_mel.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F + +from diffusion.unit2mel import load_model_vocoder + + +class DiffGtMel: + def __init__(self, project_path=None, device=None): + self.project_path = project_path + if device is not None: + self.device = device + else: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = None + self.vocoder = None + self.args = None + + def flush_model(self, project_path, ddsp_config=None): + if (self.model is None) or (project_path != self.project_path): + model, vocoder, args = load_model_vocoder(project_path, device=self.device) + if self.check_args(ddsp_config, args): + self.model = model + self.vocoder = vocoder + self.args = args + + def check_args(self, args1, args2): + if args1.data.block_size != args2.data.block_size: + raise ValueError("DDSP与DIFF模型的block_size不一致") + if args1.data.sampling_rate != args2.data.sampling_rate: + raise ValueError("DDSP与DIFF模型的sampling_rate不一致") + if args1.data.encoder != args2.data.encoder: + raise ValueError("DDSP与DIFF模型的encoder不一致") + return True + + def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', + spk_mix_dict=None, start_frame=0): + input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate) + out_mel = self.model( + hubert, + f0, + volume, + spk_id=spk_id, + spk_mix_dict=spk_mix_dict, + gt_spec=input_mel, + infer=True, + infer_speedup=acc, + method=method, + k_step=k_step, + use_tqdm=False) + if start_frame > 0: + out_mel = out_mel[:, start_frame:, :] + f0 = f0[:, start_frame:, :] + output = self.vocoder.infer(out_mel, f0) + if start_frame > 0: + output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0)) + return output + + def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', silence_front=0, + use_silence=False, spk_mix_dict=None): + start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) + if use_silence: + audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:] + f0 = f0[:, start_frame:, :] + hubert = hubert[:, start_frame:, :] + volume = volume[:, start_frame:, :] + _start_frame = 0 + else: + _start_frame = start_frame + audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, + method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame) + if use_silence: + if start_frame > 0: + audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0)) + return audio diff --git a/diffusion/logger/__init__.py b/diffusion/logger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusion/logger/saver.py b/diffusion/logger/saver.py new file mode 100644 index 0000000000000000000000000000000000000000..954ce99b37f6c983999d4f7e1b08dcc5b7d99bc4 --- /dev/null +++ b/diffusion/logger/saver.py @@ -0,0 +1,145 @@ +''' +author: wayn391@mastertones +''' + +import datetime +import os +import time + +import matplotlib.pyplot as plt +import torch +import yaml +from torch.utils.tensorboard import SummaryWriter + + +class Saver(object): + def __init__( + self, + args, + initial_global_step=-1): + + self.expdir = args.env.expdir + self.sample_rate = args.data.sampling_rate + + # cold start + self.global_step = initial_global_step + self.init_time = time.time() + self.last_time = time.time() + + # makedirs + os.makedirs(self.expdir, exist_ok=True) + + # path + self.path_log_info = os.path.join(self.expdir, 'log_info.txt') + + # ckpt + os.makedirs(self.expdir, exist_ok=True) + + # writer + self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) + + # save config + path_config = os.path.join(self.expdir, 'config.yaml') + with open(path_config, "w") as out_config: + yaml.dump(dict(args), out_config) + + + def log_info(self, msg): + '''log method''' + if isinstance(msg, dict): + msg_list = [] + for k, v in msg.items(): + tmp_str = '' + if isinstance(v, int): + tmp_str = '{}: {:,}'.format(k, v) + else: + tmp_str = '{}: {}'.format(k, v) + + msg_list.append(tmp_str) + msg_str = '\n'.join(msg_list) + else: + msg_str = msg + + # dsplay + print(msg_str) + + # save + with open(self.path_log_info, 'a') as fp: + fp.write(msg_str+'\n') + + def log_value(self, dict): + for k, v in dict.items(): + self.writer.add_scalar(k, v, self.global_step) + + def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): + spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) + spec = spec_cat[0] + if isinstance(spec, torch.Tensor): + spec = spec.cpu().numpy() + fig = plt.figure(figsize=(12, 9)) + plt.pcolor(spec.T, vmin=vmin, vmax=vmax) + plt.tight_layout() + self.writer.add_figure(name, fig, self.global_step) + + def log_audio(self, dict): + for k, v in dict.items(): + self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) + + def get_interval_time(self, update=True): + cur_time = time.time() + time_interval = cur_time - self.last_time + if update: + self.last_time = cur_time + return time_interval + + def get_total_time(self, to_str=True): + total_time = time.time() - self.init_time + if to_str: + total_time = str(datetime.timedelta( + seconds=total_time))[:-5] + return total_time + + def save_model( + self, + model, + optimizer, + name='model', + postfix='', + to_json=False): + # path + if postfix: + postfix = '_' + postfix + path_pt = os.path.join( + self.expdir , name+postfix+'.pt') + + # check + print(' [*] model checkpoint saved: {}'.format(path_pt)) + + # save + if optimizer is not None: + torch.save({ + 'global_step': self.global_step, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict()}, path_pt) + else: + torch.save({ + 'global_step': self.global_step, + 'model': model.state_dict()}, path_pt) + + + def delete_model(self, name='model', postfix=''): + # path + if postfix: + postfix = '_' + postfix + path_pt = os.path.join( + self.expdir , name+postfix+'.pt') + + # delete + if os.path.exists(path_pt): + os.remove(path_pt) + print(' [*] model checkpoint deleted: {}'.format(path_pt)) + + def global_step_increment(self): + self.global_step += 1 + + diff --git a/diffusion/logger/utils.py b/diffusion/logger/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a907de7dc4ece50746f87f92ee1985d50d34fcbb --- /dev/null +++ b/diffusion/logger/utils.py @@ -0,0 +1,127 @@ +import json +import os + +import torch +import yaml + + +def traverse_dir( + root_dir, + extensions, + amount=None, + str_include=None, + str_exclude=None, + is_pure=False, + is_sort=False, + is_ext=True): + + file_list = [] + cnt = 0 + for root, _, files in os.walk(root_dir): + for file in files: + if any([file.endswith(f".{ext}") for ext in extensions]): + # path + mix_path = os.path.join(root, file) + pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path + + # amount + if (amount is not None) and (cnt == amount): + if is_sort: + file_list.sort() + return file_list + + # check string + if (str_include is not None) and (str_include not in pure_path): + continue + if (str_exclude is not None) and (str_exclude in pure_path): + continue + + if not is_ext: + ext = pure_path.split('.')[-1] + pure_path = pure_path[:-(len(ext)+1)] + file_list.append(pure_path) + cnt += 1 + if is_sort: + file_list.sort() + return file_list + + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def get_network_paras_amount(model_dict): + info = dict() + for model_name, model in model_dict.items(): + # all_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + info[model_name] = trainable_params + return info + + +def load_config(path_config): + with open(path_config, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + # print(args) + return args + +def save_config(path_config,config): + config = dict(config) + with open(path_config, "w") as f: + yaml.dump(config, f) + +def to_json(path_params, path_json): + params = torch.load(path_params, map_location=torch.device('cpu')) + raw_state_dict = {} + for k, v in params.items(): + val = v.flatten().numpy().tolist() + raw_state_dict[k] = val + + with open(path_json, 'w') as outfile: + json.dump(raw_state_dict, outfile,indent= "\t") + + +def convert_tensor_to_numpy(tensor, is_squeeze=True): + if is_squeeze: + tensor = tensor.squeeze() + if tensor.requires_grad: + tensor = tensor.detach() + if tensor.is_cuda: + tensor = tensor.cpu() + return tensor.numpy() + + +def load_model( + expdir, + model, + optimizer, + name='model', + postfix='', + device='cpu'): + if postfix == '': + postfix = '_' + postfix + path = os.path.join(expdir, name+postfix) + path_pt = traverse_dir(expdir, ['pt'], is_ext=False) + global_step = 0 + if len(path_pt) > 0: + steps = [s[len(path):] for s in path_pt] + maxstep = max([int(s) if s.isdigit() else 0 for s in steps]) + if maxstep >= 0: + path_pt = path+str(maxstep)+'.pt' + else: + path_pt = path+'best.pt' + print(' [*] restoring model from', path_pt) + ckpt = torch.load(path_pt, map_location=torch.device(device)) + global_step = ckpt['global_step'] + model.load_state_dict(ckpt['model'], strict=False) + if ckpt.get("optimizer") is not None: + optimizer.load_state_dict(ckpt['optimizer']) + return global_step, model, optimizer diff --git a/diffusion/onnx_export.py b/diffusion/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4ea2274b60ad1c7bb48596e493df46e18ef539 --- /dev/null +++ b/diffusion/onnx_export.py @@ -0,0 +1,235 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import yaml +from diffusion_onnx import GaussianDiffusion + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model_vocoder( + model_path, + device='cpu'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load model + model = Unit2Mel( + args.data.encoder_out_channels, + args.model.n_spk, + args.model.use_pitch_aug, + 128, + args.model.n_layers, + args.model.n_chans, + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max) + + print(' [Loading] ' + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt['model']) + model.eval() + return model, args + + +class Unit2Mel(nn.Module): + def __init__( + self, + input_channel, + n_spk, + use_pitch_aug=False, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256, + timesteps=1000, + k_step_max=1000): + super().__init__() + + self.unit_embed = nn.Linear(input_channel, n_hidden) + self.f0_embed = nn.Linear(1, n_hidden) + self.volume_embed = nn.Linear(1, n_hidden) + if use_pitch_aug: + self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) + else: + self.aug_shift_embed = None + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, n_hidden) + + self.timesteps = timesteps if timesteps is not None else 1000 + self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max 1: # [N, S] * [S, B, 1, H] + g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + x = x.transpose(1, 2) + g + return x + else: + return x.transpose(1, 2) + + + def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, + gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + spk_embed_mix = torch.zeros((1,1,self.hidden_size)) + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + spk_embeddd = self.spk_embed(spk_id_torch) + self.speaker_map[k] = spk_embeddd + spk_embed_mix = spk_embed_mix + v * spk_embeddd + x = x + spk_embed_mix + else: + x = x + self.spk_embed(spk_id - 1) + self.speaker_map = self.speaker_map.unsqueeze(0) + self.speaker_map = self.speaker_map.detach() + return x.transpose(1, 2) + + def OnnxExport(self, project_name=None, init_noise=None, export_encoder=True, export_denoise=True, export_pred=True, export_after=True): + hubert_hidden_size = 768 + n_frames = 100 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spk_mix = [] + spks = {} + if self.n_spk is not None and self.n_spk > 1: + for i in range(self.n_spk): + spk_mix.append(1.0/float(self.n_spk)) + spks.update({i:1.0/float(self.n_spk)}) + spk_mix = torch.tensor(spk_mix) + spk_mix = spk_mix.repeat(n_frames, 1) + self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + self.forward(hubert, mel2ph, f0, volume, spk_mix) + if export_encoder: + torch.onnx.export( + self, + (hubert, mel2ph, f0, volume, spk_mix), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], + output_names=["mel_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "volume": [1], + "mel2ph": [1], + "spk_mix": [0], + }, + opset_version=16 + ) + + self.decoder.OnnxExport(project_name, init_noise=init_noise, export_denoise=export_denoise, export_pred=export_pred, export_after=export_after) + + def ExportOnnx(self, project_name=None): + hubert_hidden_size = 768 + n_frames = 100 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + mel2ph = torch.arange(end=n_frames).unsqueeze(0).long() + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spk_mix = [] + spks = {} + if self.n_spk is not None and self.n_spk > 1: + for i in range(self.n_spk): + spk_mix.append(1.0/float(self.n_spk)) + spks.update({i:1.0/float(self.n_spk)}) + spk_mix = torch.tensor(spk_mix) + self.orgforward(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + self.forward(hubert, mel2ph, f0, volume, spk_mix) + + torch.onnx.export( + self, + (hubert, mel2ph, f0, volume, spk_mix), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "f0", "volume", "spk_mix"], + output_names=["mel_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "volume": [1], + "mel2ph": [1] + }, + opset_version=16 + ) + + condition = torch.randn(1,self.decoder.n_hidden,n_frames) + noise = torch.randn((1, 1, self.decoder.mel_bins, condition.shape[2]), dtype=torch.float32) + pndm_speedup = torch.LongTensor([100]) + K_steps = torch.LongTensor([1000]) + self.decoder = torch.jit.script(self.decoder) + self.decoder(condition, noise, pndm_speedup, K_steps) + + torch.onnx.export( + self.decoder, + (condition, noise, pndm_speedup, K_steps), + f"{project_name}_diffusion.onnx", + input_names=["condition", "noise", "pndm_speedup", "K_steps"], + output_names=["mel"], + dynamic_axes={ + "condition": [2], + "noise": [3], + }, + opset_version=16 + ) + + +if __name__ == "__main__": + project_name = "dddsp" + model_path = f'{project_name}/model_500000.pt' + + model, _ = load_model_vocoder(model_path) + + # 分开Diffusion导出(需要使用MoeSS/MoeVoiceStudio或者自己编写Pndm/Dpm采样) + model.OnnxExport(project_name, export_encoder=True, export_denoise=True, export_pred=True, export_after=True) + + # 合并Diffusion导出(Encoder和Diffusion分开,直接将Encoder的结果和初始噪声输入Diffusion即可) + # model.ExportOnnx(project_name) + diff --git a/diffusion/solver.py b/diffusion/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..52657cc5ab39216c6f721676c89853a089d29fdd --- /dev/null +++ b/diffusion/solver.py @@ -0,0 +1,200 @@ +import time + +import librosa +import numpy as np +import torch +from torch import autocast +from torch.cuda.amp import GradScaler + +from diffusion.logger import utils +from diffusion.logger.saver import Saver + + +def test(args, model, vocoder, loader_test, saver): + print(' [*] testing...') + model.eval() + + # losses + test_loss = 0. + + # intialization + num_batches = len(loader_test) + rtf_all = [] + + # run + with torch.no_grad(): + for bidx, data in enumerate(loader_test): + fn = data['name'][0].split("/")[-1] + speaker = data['name'][0].split("/")[-2] + print('--------') + print('{}/{} - {}'.format(bidx, num_batches, fn)) + + # unpack data + for k in data.keys(): + if not k.startswith('name'): + data[k] = data[k].to(args.device) + print('>>', data['name'][0]) + + # forward + st_time = time.time() + mel = model( + data['units'], + data['f0'], + data['volume'], + data['spk_id'], + gt_spec=None if model.k_step_max == model.timesteps else data['mel'], + infer=True, + infer_speedup=args.infer.speedup, + method=args.infer.method, + k_step=model.k_step_max + ) + signal = vocoder.infer(mel, data['f0']) + ed_time = time.time() + + # RTF + run_time = ed_time - st_time + song_time = signal.shape[-1] / args.data.sampling_rate + rtf = run_time / song_time + print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) + rtf_all.append(rtf) + + # loss + for i in range(args.train.batch_size): + loss = model( + data['units'], + data['f0'], + data['volume'], + data['spk_id'], + gt_spec=data['mel'], + infer=False, + k_step=model.k_step_max) + test_loss += loss.item() + + # log mel + saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel) + + # log audi + path_audio = data['name_ext'][0] + audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio) + audio = torch.from_numpy(audio).unsqueeze(0).to(signal) + saver.log_audio({f"{speaker}_{fn}_gt.wav": audio,f"{speaker}_{fn}_pred.wav": signal}) + # report + test_loss /= args.train.batch_size + test_loss /= num_batches + + # check + print(' [test_loss] test_loss:', test_loss) + print(' Real Time Factor', np.mean(rtf_all)) + return test_loss + + +def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): + # saver + saver = Saver(args, initial_global_step=initial_global_step) + + # model size + params_count = utils.get_network_paras_amount({'model': model}) + saver.log_info('--- model size ---') + saver.log_info(params_count) + + # run + num_batches = len(loader_train) + model.train() + saver.log_info('======= start training =======') + scaler = GradScaler() + if args.train.amp_dtype == 'fp32': + dtype = torch.float32 + elif args.train.amp_dtype == 'fp16': + dtype = torch.float16 + elif args.train.amp_dtype == 'bf16': + dtype = torch.bfloat16 + else: + raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) + saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step") + for epoch in range(args.train.epochs): + for batch_idx, data in enumerate(loader_train): + saver.global_step_increment() + optimizer.zero_grad() + + # unpack data + for k in data.keys(): + if not k.startswith('name'): + data[k] = data[k].to(args.device) + + # forward + if dtype == torch.float32: + loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], + aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max) + else: + with autocast(device_type=args.device, dtype=dtype): + loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], + aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max) + + # handle nan loss + if torch.isnan(loss): + raise ValueError(' [x] nan loss ') + else: + # backpropagate + if dtype == torch.float32: + loss.backward() + optimizer.step() + else: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + scheduler.step() + + # log loss + if saver.global_step % args.train.interval_log == 0: + current_lr = optimizer.param_groups[0]['lr'] + saver.log_info( + 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( + epoch, + batch_idx, + num_batches, + args.env.expdir, + args.train.interval_log/saver.get_interval_time(), + current_lr, + loss.item(), + saver.get_total_time(), + saver.global_step + ) + ) + + saver.log_value({ + 'train/loss': loss.item() + }) + + saver.log_value({ + 'train/lr': current_lr + }) + + # validation + if saver.global_step % args.train.interval_val == 0: + optimizer_save = optimizer if args.train.save_opt else None + + # save latest + saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') + last_val_step = saver.global_step - args.train.interval_val + if last_val_step % args.train.interval_force_save != 0: + saver.delete_model(postfix=f'{last_val_step}') + + # run testing set + test_loss = test(args, model, vocoder, loader_test, saver) + + # log loss + saver.log_info( + ' --- --- \nloss: {:.3f}. '.format( + test_loss, + ) + ) + + saver.log_value({ + 'validation/loss': test_loss + }) + + model.train() + + diff --git a/diffusion/uni_pc.py b/diffusion/uni_pc.py new file mode 100644 index 0000000000000000000000000000000000000000..72d8f518f2eb2cc639bd72bcaf29b8a09075d87e --- /dev/null +++ b/diffusion/uni_pc.py @@ -0,0 +1,733 @@ +import math + +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + def t_fn(log_alpha_t): + return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2.0 * (1.0 + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="data_prediction", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + variant='bh1' + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["data_prediction", "noise_prediction"] + + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + self.variant = variant + self.predict_x0 = algorithm_type == "data_prediction" + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + #print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.cat(b) + + # now predictor + use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if x_t is None: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if use_corrector: + #print('using corrector') + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = ( + torch.exp(log_alpha_t - log_alpha_prev_0) * x + - sigma_t * h_phi_1 * model_prev_0 + ) + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + return x_t, model_t + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # Init the first `order` values by lower order multistep UniPC. + for step in range(1, order): + t = timesteps[step] + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(model_x) + + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + if step == steps: + #print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, t) + model_prev_list[-1] = model_x + else: + raise ValueError("Got wrong method {}".format(method)) + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/diffusion/unit2mel.py b/diffusion/unit2mel.py new file mode 100644 index 0000000000000000000000000000000000000000..5087f2a512aba1d265d82c644ac6c9859a34d422 --- /dev/null +++ b/diffusion/unit2mel.py @@ -0,0 +1,167 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +import yaml + +from .diffusion import GaussianDiffusion +from .vocoder import Vocoder +from .wavenet import WaveNet + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def load_model_vocoder( + model_path, + device='cpu', + config_path = None + ): + if config_path is None: + config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') + else: + config_file = config_path + + with open(config_file, "r") as config: + args = yaml.safe_load(config) + args = DotDict(args) + + # load vocoder + vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device) + + # load model + model = Unit2Mel( + args.data.encoder_out_channels, + args.model.n_spk, + args.model.use_pitch_aug, + vocoder.dimension, + args.model.n_layers, + args.model.n_chans, + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max + ) + + print(' [Loading] ' + model_path) + ckpt = torch.load(model_path, map_location=torch.device(device)) + model.to(device) + model.load_state_dict(ckpt['model']) + model.eval() + print(f'Loaded diffusion model, sampler is {args.infer.method}, speedup: {args.infer.speedup} ') + return model, vocoder, args + + +class Unit2Mel(nn.Module): + def __init__( + self, + input_channel, + n_spk, + use_pitch_aug=False, + out_dims=128, + n_layers=20, + n_chans=384, + n_hidden=256, + timesteps=1000, + k_step_max=1000 + ): + super().__init__() + self.unit_embed = nn.Linear(input_channel, n_hidden) + self.f0_embed = nn.Linear(1, n_hidden) + self.volume_embed = nn.Linear(1, n_hidden) + if use_pitch_aug: + self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False) + else: + self.aug_shift_embed = None + self.n_spk = n_spk + if n_spk is not None and n_spk > 1: + self.spk_embed = nn.Embedding(n_spk, n_hidden) + + self.timesteps = timesteps if timesteps is not None else 1000 + self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max 1: + if spk_mix_dict is not None: + spk_embed_mix = torch.zeros((1,1,self.hidden_size)) + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + spk_embeddd = self.spk_embed(spk_id_torch) + self.speaker_map[k] = spk_embeddd + spk_embed_mix = spk_embed_mix + v * spk_embeddd + x = x + spk_embed_mix + else: + x = x + self.spk_embed(spk_id - 1) + self.speaker_map = self.speaker_map.unsqueeze(0) + self.speaker_map = self.speaker_map.detach() + return x.transpose(1, 2) + + def init_spkmix(self, n_spk): + self.speaker_map = torch.zeros((n_spk,1,1,self.n_hidden)) + hubert_hidden_size = self.input_channel + n_frames = 10 + hubert = torch.randn((1, n_frames, hubert_hidden_size)) + f0 = torch.randn((1, n_frames)) + volume = torch.randn((1, n_frames)) + spks = {} + for i in range(n_spk): + spks.update({i:1.0/float(self.n_spk)}) + self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks) + + def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None, + gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True): + + ''' + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + ''' + + if not self.training and gt_spec is not None and k_step>self.k_step_max: + raise Exception("The shallow diffusion k_step is greater than the maximum diffusion k_step(k_step_max)!") + + if not self.training and gt_spec is None and self.k_step_max!=self.timesteps: + raise Exception("This model can only be used for shallow diffusion and can not infer alone!") + + x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) + if self.n_spk is not None and self.n_spk > 1: + if spk_mix_dict is not None: + for k, v in spk_mix_dict.items(): + spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) + x = x + v * self.spk_embed(spk_id_torch) + else: + if spk_id.shape[1] > 1: + g = spk_id.reshape((spk_id.shape[0], spk_id.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + x = x + g + else: + x = x + self.spk_embed(spk_id) + if self.aug_shift_embed is not None and aug_shift is not None: + x = x + self.aug_shift_embed(aug_shift / 5) + x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm) + + return x + diff --git a/diffusion/vocoder.py b/diffusion/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ec9c80e65d160cd9364a7e2106fd0a94a814e3c9 --- /dev/null +++ b/diffusion/vocoder.py @@ -0,0 +1,95 @@ +import torch +from torchaudio.transforms import Resample + +from vdecoder.nsf_hifigan.models import load_config, load_model +from vdecoder.nsf_hifigan.nvSTFT import STFT + + +class Vocoder: + def __init__(self, vocoder_type, vocoder_ckpt, device = None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + + if vocoder_type == 'nsf-hifigan': + self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device) + elif vocoder_type == 'nsf-hifigan-log10': + self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device) + else: + raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") + + self.resample_kernel = {} + self.vocoder_sample_rate = self.vocoder.sample_rate() + self.vocoder_hop_size = self.vocoder.hop_size() + self.dimension = self.vocoder.dimension() + + def extract(self, audio, sample_rate, keyshift=0): + + # resample + if sample_rate == self.vocoder_sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + # extract + mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins + return mel + + def infer(self, mel, f0): + f0 = f0[:,:mel.size(1),0] # B, n_frames + audio = self.vocoder(mel, f0) + return audio + + +class NsfHifiGAN(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.model_path = model_path + self.model = None + self.h = load_config(model_path) + self.stft = STFT( + self.h.sampling_rate, + self.h.num_mels, + self.h.n_fft, + self.h.win_size, + self.h.hop_size, + self.h.fmin, + self.h.fmax) + + def sample_rate(self): + return self.h.sampling_rate + + def hop_size(self): + return self.h.hop_size + + def dimension(self): + return self.h.num_mels + + def extract(self, audio, keyshift=0): + mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins + return mel + + def forward(self, mel, f0): + if self.model is None: + print('| Load HifiGAN: ', self.model_path) + self.model, self.h = load_model(self.model_path, device=self.device) + with torch.no_grad(): + c = mel.transpose(1, 2) + audio = self.model(c, f0) + return audio + +class NsfHifiGANLog10(NsfHifiGAN): + def forward(self, mel, f0): + if self.model is None: + print('| Load HifiGAN: ', self.model_path) + self.model, self.h = load_model(self.model_path, device=self.device) + with torch.no_grad(): + c = 0.434294 * mel.transpose(1, 2) + audio = self.model(c, f0) + return audio \ No newline at end of file diff --git a/diffusion/wavenet.py b/diffusion/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..3d48c7eaaa0e8191b27a5d1890eb657cbcc0d143 --- /dev/null +++ b/diffusion/wavenet.py @@ -0,0 +1,108 @@ +import math +from math import sqrt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Mish + + +class Conv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.residual_channels = residual_channels + self.dilated_conv = nn.Conv1d( + residual_channels, + 2 * residual_channels, + kernel_size=3, + padding=dilation, + dilation=dilation + ) + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) + self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + # Using torch.split instead of torch.chunk to avoid using onnx::Slice + gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + + # Using torch.split instead of torch.chunk to avoid using onnx::Slice + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) + return (x + residual) / math.sqrt(2.0), skip + + +class WaveNet(nn.Module): + def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): + super().__init__() + self.input_projection = Conv1d(in_dims, n_chans, 1) + self.diffusion_embedding = SinusoidalPosEmb(n_chans) + self.mlp = nn.Sequential( + nn.Linear(n_chans, n_chans * 4), + Mish(), + nn.Linear(n_chans * 4, n_chans) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock( + encoder_hidden=n_hidden, + residual_channels=n_chans, + dilation=1 + ) + for i in range(n_layers) + ]) + self.skip_projection = Conv1d(n_chans, n_chans, 1) + self.output_projection = Conv1d(n_chans, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + :param spec: [B, 1, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, M, T] + :return: + """ + x = spec.squeeze(1) + x = self.input_projection(x) # [B, residual_channel, T] + + x = F.relu(x) + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + skip = [] + for layer in self.residual_layers: + x, skip_connection = layer(x, cond, diffusion_step) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, mel_bins, T] + return x[:, None, :, :] diff --git a/edgetts/tts.py b/edgetts/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..c850cbc16eed0703fad5b1b07b4704602a0028e7 --- /dev/null +++ b/edgetts/tts.py @@ -0,0 +1,47 @@ +import asyncio +import random +import sys + +import edge_tts +from edge_tts import VoicesManager +from langdetect import DetectorFactory, detect + +DetectorFactory.seed = 0 + +TEXT = sys.argv[1] +LANG = detect(TEXT) if sys.argv[2] == "Auto" else sys.argv[2] +RATE = sys.argv[3] +VOLUME = sys.argv[4] +GENDER = sys.argv[5] if len(sys.argv) == 6 else None +OUTPUT_FILE = "tts.wav" + +print("Running TTS...") +print(f"Text: {TEXT}, Language: {LANG}, Gender: {GENDER}, Rate: {RATE}, Volume: {VOLUME}") + +async def _main() -> None: + voices = await VoicesManager.create() + if GENDER is not None: + # From "zh-cn" to "zh-CN" etc. + if LANG == "zh-cn" or LANG == "zh-tw": + LOCALE = LANG[:-2] + LANG[-2:].upper() + voice = voices.find(Gender=GENDER, Locale=LOCALE) + else: + voice = voices.find(Gender=GENDER, Language=LANG) + VOICE = random.choice(voice)["Name"] + print(f"Using random {LANG} voice: {VOICE}") + else: + VOICE = LANG + + communicate = edge_tts.Communicate(text = TEXT, voice = VOICE, rate = RATE, volume = VOLUME) + await communicate.save(OUTPUT_FILE) + +if __name__ == "__main__": + if sys.platform.startswith("win"): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.run(_main()) + else: + loop = asyncio.get_event_loop_policy().get_event_loop() + try: + loop.run_until_complete(_main()) + finally: + loop.close() diff --git a/edgetts/tts_voices.py b/edgetts/tts_voices.py new file mode 100644 index 0000000000000000000000000000000000000000..b59620a361b1ac9ac264a8b7d4bf44549c39ecf1 --- /dev/null +++ b/edgetts/tts_voices.py @@ -0,0 +1,306 @@ +#List of Supported Voices for edge_TTS +SUPPORTED_VOICES = { + 'zh-CN-XiaoxiaoNeural': 'zh-CN', + 'zh-CN-XiaoyiNeural': 'zh-CN', + 'zh-CN-YunjianNeural': 'zh-CN', + 'zh-CN-YunxiNeural': 'zh-CN', + 'zh-CN-YunxiaNeural': 'zh-CN', + 'zh-CN-YunyangNeural': 'zh-CN', + 'zh-HK-HiuGaaiNeural': 'zh-HK', + 'zh-HK-HiuMaanNeural': 'zh-HK', + 'zh-HK-WanLungNeural': 'zh-HK', + 'zh-TW-HsiaoChenNeural': 'zh-TW', + 'zh-TW-YunJheNeural': 'zh-TW', + 'zh-TW-HsiaoYuNeural': 'zh-TW', + 'af-ZA-AdriNeural': 'af-ZA', + 'af-ZA-WillemNeural': 'af-ZA', + 'am-ET-AmehaNeural': 'am-ET', + 'am-ET-MekdesNeural': 'am-ET', + 'ar-AE-FatimaNeural': 'ar-AE', + 'ar-AE-HamdanNeural': 'ar-AE', + 'ar-BH-AliNeural': 'ar-BH', + 'ar-BH-LailaNeural': 'ar-BH', + 'ar-DZ-AminaNeural': 'ar-DZ', + 'ar-DZ-IsmaelNeural': 'ar-DZ', + 'ar-EG-SalmaNeural': 'ar-EG', + 'ar-EG-ShakirNeural': 'ar-EG', + 'ar-IQ-BasselNeural': 'ar-IQ', + 'ar-IQ-RanaNeural': 'ar-IQ', + 'ar-JO-SanaNeural': 'ar-JO', + 'ar-JO-TaimNeural': 'ar-JO', + 'ar-KW-FahedNeural': 'ar-KW', + 'ar-KW-NouraNeural': 'ar-KW', + 'ar-LB-LaylaNeural': 'ar-LB', + 'ar-LB-RamiNeural': 'ar-LB', + 'ar-LY-ImanNeural': 'ar-LY', + 'ar-LY-OmarNeural': 'ar-LY', + 'ar-MA-JamalNeural': 'ar-MA', + 'ar-MA-MounaNeural': 'ar-MA', + 'ar-OM-AbdullahNeural': 'ar-OM', + 'ar-OM-AyshaNeural': 'ar-OM', + 'ar-QA-AmalNeural': 'ar-QA', + 'ar-QA-MoazNeural': 'ar-QA', + 'ar-SA-HamedNeural': 'ar-SA', + 'ar-SA-ZariyahNeural': 'ar-SA', + 'ar-SY-AmanyNeural': 'ar-SY', + 'ar-SY-LaithNeural': 'ar-SY', + 'ar-TN-HediNeural': 'ar-TN', + 'ar-TN-ReemNeural': 'ar-TN', + 'ar-YE-MaryamNeural': 'ar-YE', + 'ar-YE-SalehNeural': 'ar-YE', + 'az-AZ-BabekNeural': 'az-AZ', + 'az-AZ-BanuNeural': 'az-AZ', + 'bg-BG-BorislavNeural': 'bg-BG', + 'bg-BG-KalinaNeural': 'bg-BG', + 'bn-BD-NabanitaNeural': 'bn-BD', + 'bn-BD-PradeepNeural': 'bn-BD', + 'bn-IN-BashkarNeural': 'bn-IN', + 'bn-IN-TanishaaNeural': 'bn-IN', + 'bs-BA-GoranNeural': 'bs-BA', + 'bs-BA-VesnaNeural': 'bs-BA', + 'ca-ES-EnricNeural': 'ca-ES', + 'ca-ES-JoanaNeural': 'ca-ES', + 'cs-CZ-AntoninNeural': 'cs-CZ', + 'cs-CZ-VlastaNeural': 'cs-CZ', + 'cy-GB-AledNeural': 'cy-GB', + 'cy-GB-NiaNeural': 'cy-GB', + 'da-DK-ChristelNeural': 'da-DK', + 'da-DK-JeppeNeural': 'da-DK', + 'de-AT-IngridNeural': 'de-AT', + 'de-AT-JonasNeural': 'de-AT', + 'de-CH-JanNeural': 'de-CH', + 'de-CH-LeniNeural': 'de-CH', + 'de-DE-AmalaNeural': 'de-DE', + 'de-DE-ConradNeural': 'de-DE', + 'de-DE-KatjaNeural': 'de-DE', + 'de-DE-KillianNeural': 'de-DE', + 'el-GR-AthinaNeural': 'el-GR', + 'el-GR-NestorasNeural': 'el-GR', + 'en-AU-NatashaNeural': 'en-AU', + 'en-AU-WilliamNeural': 'en-AU', + 'en-CA-ClaraNeural': 'en-CA', + 'en-CA-LiamNeural': 'en-CA', + 'en-GB-LibbyNeural': 'en-GB', + 'en-GB-MaisieNeural': 'en-GB', + 'en-GB-RyanNeural': 'en-GB', + 'en-GB-SoniaNeural': 'en-GB', + 'en-GB-ThomasNeural': 'en-GB', + 'en-HK-SamNeural': 'en-HK', + 'en-HK-YanNeural': 'en-HK', + 'en-IE-ConnorNeural': 'en-IE', + 'en-IE-EmilyNeural': 'en-IE', + 'en-IN-NeerjaNeural': 'en-IN', + 'en-IN-PrabhatNeural': 'en-IN', + 'en-KE-AsiliaNeural': 'en-KE', + 'en-KE-ChilembaNeural': 'en-KE', + 'en-NG-AbeoNeural': 'en-NG', + 'en-NG-EzinneNeural': 'en-NG', + 'en-NZ-MitchellNeural': 'en-NZ', + 'en-NZ-MollyNeural': 'en-NZ', + 'en-PH-JamesNeural': 'en-PH', + 'en-PH-RosaNeural': 'en-PH', + 'en-SG-LunaNeural': 'en-SG', + 'en-SG-WayneNeural': 'en-SG', + 'en-TZ-ElimuNeural': 'en-TZ', + 'en-TZ-ImaniNeural': 'en-TZ', + 'en-US-AnaNeural': 'en-US', + 'en-US-AriaNeural': 'en-US', + 'en-US-ChristopherNeural': 'en-US', + 'en-US-EricNeural': 'en-US', + 'en-US-GuyNeural': 'en-US', + 'en-US-JennyNeural': 'en-US', + 'en-US-MichelleNeural': 'en-US', + 'en-ZA-LeahNeural': 'en-ZA', + 'en-ZA-LukeNeural': 'en-ZA', + 'es-AR-ElenaNeural': 'es-AR', + 'es-AR-TomasNeural': 'es-AR', + 'es-BO-MarceloNeural': 'es-BO', + 'es-BO-SofiaNeural': 'es-BO', + 'es-CL-CatalinaNeural': 'es-CL', + 'es-CL-LorenzoNeural': 'es-CL', + 'es-CO-GonzaloNeural': 'es-CO', + 'es-CO-SalomeNeural': 'es-CO', + 'es-CR-JuanNeural': 'es-CR', + 'es-CR-MariaNeural': 'es-CR', + 'es-CU-BelkysNeural': 'es-CU', + 'es-CU-ManuelNeural': 'es-CU', + 'es-DO-EmilioNeural': 'es-DO', + 'es-DO-RamonaNeural': 'es-DO', + 'es-EC-AndreaNeural': 'es-EC', + 'es-EC-LuisNeural': 'es-EC', + 'es-ES-AlvaroNeural': 'es-ES', + 'es-ES-ElviraNeural': 'es-ES', + 'es-ES-ManuelEsCUNeural': 'es-ES', + 'es-GQ-JavierNeural': 'es-GQ', + 'es-GQ-TeresaNeural': 'es-GQ', + 'es-GT-AndresNeural': 'es-GT', + 'es-GT-MartaNeural': 'es-GT', + 'es-HN-CarlosNeural': 'es-HN', + 'es-HN-KarlaNeural': 'es-HN', + 'es-MX-DaliaNeural': 'es-MX', + 'es-MX-JorgeNeural': 'es-MX', + 'es-MX-LorenzoEsCLNeural': 'es-MX', + 'es-NI-FedericoNeural': 'es-NI', + 'es-NI-YolandaNeural': 'es-NI', + 'es-PA-MargaritaNeural': 'es-PA', + 'es-PA-RobertoNeural': 'es-PA', + 'es-PE-AlexNeural': 'es-PE', + 'es-PE-CamilaNeural': 'es-PE', + 'es-PR-KarinaNeural': 'es-PR', + 'es-PR-VictorNeural': 'es-PR', + 'es-PY-MarioNeural': 'es-PY', + 'es-PY-TaniaNeural': 'es-PY', + 'es-SV-LorenaNeural': 'es-SV', + 'es-SV-RodrigoNeural': 'es-SV', + 'es-US-AlonsoNeural': 'es-US', + 'es-US-PalomaNeural': 'es-US', + 'es-UY-MateoNeural': 'es-UY', + 'es-UY-ValentinaNeural': 'es-UY', + 'es-VE-PaolaNeural': 'es-VE', + 'es-VE-SebastianNeural': 'es-VE', + 'et-EE-AnuNeural': 'et-EE', + 'et-EE-KertNeural': 'et-EE', + 'fa-IR-DilaraNeural': 'fa-IR', + 'fa-IR-FaridNeural': 'fa-IR', + 'fi-FI-HarriNeural': 'fi-FI', + 'fi-FI-NooraNeural': 'fi-FI', + 'fil-PH-AngeloNeural': 'fil-PH', + 'fil-PH-BlessicaNeural': 'fil-PH', + 'fr-BE-CharlineNeural': 'fr-BE', + 'fr-BE-GerardNeural': 'fr-BE', + 'fr-CA-AntoineNeural': 'fr-CA', + 'fr-CA-JeanNeural': 'fr-CA', + 'fr-CA-SylvieNeural': 'fr-CA', + 'fr-CH-ArianeNeural': 'fr-CH', + 'fr-CH-FabriceNeural': 'fr-CH', + 'fr-FR-DeniseNeural': 'fr-FR', + 'fr-FR-EloiseNeural': 'fr-FR', + 'fr-FR-HenriNeural': 'fr-FR', + 'ga-IE-ColmNeural': 'ga-IE', + 'ga-IE-OrlaNeural': 'ga-IE', + 'gl-ES-RoiNeural': 'gl-ES', + 'gl-ES-SabelaNeural': 'gl-ES', + 'gu-IN-DhwaniNeural': 'gu-IN', + 'gu-IN-NiranjanNeural': 'gu-IN', + 'he-IL-AvriNeural': 'he-IL', + 'he-IL-HilaNeural': 'he-IL', + 'hi-IN-MadhurNeural': 'hi-IN', + 'hi-IN-SwaraNeural': 'hi-IN', + 'hr-HR-GabrijelaNeural': 'hr-HR', + 'hr-HR-SreckoNeural': 'hr-HR', + 'hu-HU-NoemiNeural': 'hu-HU', + 'hu-HU-TamasNeural': 'hu-HU', + 'id-ID-ArdiNeural': 'id-ID', + 'id-ID-GadisNeural': 'id-ID', + 'is-IS-GudrunNeural': 'is-IS', + 'is-IS-GunnarNeural': 'is-IS', + 'it-IT-DiegoNeural': 'it-IT', + 'it-IT-ElsaNeural': 'it-IT', + 'it-IT-IsabellaNeural': 'it-IT', + 'ja-JP-KeitaNeural': 'ja-JP', + 'ja-JP-NanamiNeural': 'ja-JP', + 'jv-ID-DimasNeural': 'jv-ID', + 'jv-ID-SitiNeural': 'jv-ID', + 'ka-GE-EkaNeural': 'ka-GE', + 'ka-GE-GiorgiNeural': 'ka-GE', + 'kk-KZ-AigulNeural': 'kk-KZ', + 'kk-KZ-DauletNeural': 'kk-KZ', + 'km-KH-PisethNeural': 'km-KH', + 'km-KH-SreymomNeural': 'km-KH', + 'kn-IN-GaganNeural': 'kn-IN', + 'kn-IN-SapnaNeural': 'kn-IN', + 'ko-KR-InJoonNeural': 'ko-KR', + 'ko-KR-SunHiNeural': 'ko-KR', + 'lo-LA-ChanthavongNeural': 'lo-LA', + 'lo-LA-KeomanyNeural': 'lo-LA', + 'lt-LT-LeonasNeural': 'lt-LT', + 'lt-LT-OnaNeural': 'lt-LT', + 'lv-LV-EveritaNeural': 'lv-LV', + 'lv-LV-NilsNeural': 'lv-LV', + 'mk-MK-AleksandarNeural': 'mk-MK', + 'mk-MK-MarijaNeural': 'mk-MK', + 'ml-IN-MidhunNeural': 'ml-IN', + 'ml-IN-SobhanaNeural': 'ml-IN', + 'mn-MN-BataaNeural': 'mn-MN', + 'mn-MN-YesuiNeural': 'mn-MN', + 'mr-IN-AarohiNeural': 'mr-IN', + 'mr-IN-ManoharNeural': 'mr-IN', + 'ms-MY-OsmanNeural': 'ms-MY', + 'ms-MY-YasminNeural': 'ms-MY', + 'mt-MT-GraceNeural': 'mt-MT', + 'mt-MT-JosephNeural': 'mt-MT', + 'my-MM-NilarNeural': 'my-MM', + 'my-MM-ThihaNeural': 'my-MM', + 'nb-NO-FinnNeural': 'nb-NO', + 'nb-NO-PernilleNeural': 'nb-NO', + 'ne-NP-HemkalaNeural': 'ne-NP', + 'ne-NP-SagarNeural': 'ne-NP', + 'nl-BE-ArnaudNeural': 'nl-BE', + 'nl-BE-DenaNeural': 'nl-BE', + 'nl-NL-ColetteNeural': 'nl-NL', + 'nl-NL-FennaNeural': 'nl-NL', + 'nl-NL-MaartenNeural': 'nl-NL', + 'pl-PL-MarekNeural': 'pl-PL', + 'pl-PL-ZofiaNeural': 'pl-PL', + 'ps-AF-GulNawazNeural': 'ps-AF', + 'ps-AF-LatifaNeural': 'ps-AF', + 'pt-BR-AntonioNeural': 'pt-BR', + 'pt-BR-FranciscaNeural': 'pt-BR', + 'pt-PT-DuarteNeural': 'pt-PT', + 'pt-PT-RaquelNeural': 'pt-PT', + 'ro-RO-AlinaNeural': 'ro-RO', + 'ro-RO-EmilNeural': 'ro-RO', + 'ru-RU-DmitryNeural': 'ru-RU', + 'ru-RU-SvetlanaNeural': 'ru-RU', + 'si-LK-SameeraNeural': 'si-LK', + 'si-LK-ThiliniNeural': 'si-LK', + 'sk-SK-LukasNeural': 'sk-SK', + 'sk-SK-ViktoriaNeural': 'sk-SK', + 'sl-SI-PetraNeural': 'sl-SI', + 'sl-SI-RokNeural': 'sl-SI', + 'so-SO-MuuseNeural': 'so-SO', + 'so-SO-UbaxNeural': 'so-SO', + 'sq-AL-AnilaNeural': 'sq-AL', + 'sq-AL-IlirNeural': 'sq-AL', + 'sr-RS-NicholasNeural': 'sr-RS', + 'sr-RS-SophieNeural': 'sr-RS', + 'su-ID-JajangNeural': 'su-ID', + 'su-ID-TutiNeural': 'su-ID', + 'sv-SE-MattiasNeural': 'sv-SE', + 'sv-SE-SofieNeural': 'sv-SE', + 'sw-KE-RafikiNeural': 'sw-KE', + 'sw-KE-ZuriNeural': 'sw-KE', + 'sw-TZ-DaudiNeural': 'sw-TZ', + 'sw-TZ-RehemaNeural': 'sw-TZ', + 'ta-IN-PallaviNeural': 'ta-IN', + 'ta-IN-ValluvarNeural': 'ta-IN', + 'ta-LK-KumarNeural': 'ta-LK', + 'ta-LK-SaranyaNeural': 'ta-LK', + 'ta-MY-KaniNeural': 'ta-MY', + 'ta-MY-SuryaNeural': 'ta-MY', + 'ta-SG-AnbuNeural': 'ta-SG', + 'ta-SG-VenbaNeural': 'ta-SG', + 'te-IN-MohanNeural': 'te-IN', + 'te-IN-ShrutiNeural': 'te-IN', + 'th-TH-NiwatNeural': 'th-TH', + 'th-TH-PremwadeeNeural': 'th-TH', + 'tr-TR-AhmetNeural': 'tr-TR', + 'tr-TR-EmelNeural': 'tr-TR', + 'uk-UA-OstapNeural': 'uk-UA', + 'uk-UA-PolinaNeural': 'uk-UA', + 'ur-IN-GulNeural': 'ur-IN', + 'ur-IN-SalmanNeural': 'ur-IN', + 'ur-PK-AsadNeural': 'ur-PK', + 'ur-PK-UzmaNeural': 'ur-PK', + 'uz-UZ-MadinaNeural': 'uz-UZ', + 'uz-UZ-SardorNeural': 'uz-UZ', + 'vi-VN-HoaiMyNeural': 'vi-VN', + 'vi-VN-NamMinhNeural': 'vi-VN', + 'zu-ZA-ThandoNeural': 'zu-ZA', + 'zu-ZA-ThembaNeural': 'zu-ZA', +} + +SUPPORTED_LANGUAGES = [ + "Auto", + *SUPPORTED_VOICES.keys() +] \ No newline at end of file diff --git a/export_index_for_onnx.py b/export_index_for_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..067a5cc58f68ef0737c80d7f4ec8c8ee2f090602 --- /dev/null +++ b/export_index_for_onnx.py @@ -0,0 +1,20 @@ +import os +import pickle + +import faiss + +path = "crs" +indexs_file_path = f"checkpoints/{path}/feature_and_index.pkl" +indexs_out_dir = f"checkpoints/{path}/" + +with open("feature_and_index.pkl",mode="rb") as f: + indexs = pickle.load(f) + +for k in indexs: + print(f"Save {k} index") + faiss.write_index( + indexs[k], + os.path.join(indexs_out_dir,f"Index-{k}.index") + ) + +print("Saved all index") \ No newline at end of file diff --git a/flask_api.py b/flask_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5547e0071f8465f9e92d66c96554c8535c7fe89b --- /dev/null +++ b/flask_api.py @@ -0,0 +1,60 @@ +import io +import logging + +import soundfile +import torch +import torchaudio +from flask import Flask, request, send_file +from flask_cors import CORS + +from inference.infer_tool import RealTimeVC, Svc + +app = Flask(__name__) + +CORS(app) + +logging.getLogger('numba').setLevel(logging.WARNING) + + +@app.route("/voiceChangeModel", methods=["POST"]) +def voice_change_model(): + request_form = request.form + wave_file = request.files.get("sample", None) + # 变调信息 + f_pitch_change = float(request_form.get("fPitchChange", 0)) + # DAW所需的采样率 + daw_sample = int(float(request_form.get("sampleRate", 0))) + speaker_id = int(float(request_form.get("sSpeakId", 0))) + # http获得wav文件并转换 + input_wav_path = io.BytesIO(wave_file.read()) + + # 模型推理 + if raw_infer: + # out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path) + out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0, + auto_predict_f0=False, noice_scale=0.4, f0_filter=False) + tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample) + else: + out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0, + auto_predict_f0=False, noice_scale=0.4, f0_filter=False) + tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample) + # 返回音频 + out_wav_path = io.BytesIO() + soundfile.write(out_wav_path, tar_audio.cpu().numpy(), daw_sample, format="wav") + out_wav_path.seek(0) + return send_file(out_wav_path, download_name="temp.wav", as_attachment=True) + + +if __name__ == '__main__': + # 启用则为直接切片合成,False为交叉淡化方式 + # vst插件调整0.3-0.5s切片时间可以降低延迟,直接切片方法会有连接处爆音、交叉淡化会有轻微重叠声音 + # 自行选择能接受的方法,或将vst最大切片时间调整为1s,此处设为Ture,延迟大音质稳定一些 + raw_infer = True + # 每个模型和config是唯一对应的 + model_name = "logs/32k/G_174000-Copy1.pth" + config_name = "configs/config.json" + cluster_model_path = "logs/44k/kmeans_10000.pt" + svc_model = Svc(model_name, config_name, cluster_model_path=cluster_model_path) + svc = RealTimeVC() + # 此处与vst插件对应,不建议更改 + app.run(port=6842, host="0.0.0.0", debug=False, threaded=False) diff --git a/flask_api_full_song.py b/flask_api_full_song.py new file mode 100644 index 0000000000000000000000000000000000000000..29fbd720de71cff73903f71b102fb1aa2848bad6 --- /dev/null +++ b/flask_api_full_song.py @@ -0,0 +1,55 @@ +import io + +import numpy as np +import soundfile +from flask import Flask, request, send_file + +from inference import infer_tool, slicer + +app = Flask(__name__) + + +@app.route("/wav2wav", methods=["POST"]) +def wav2wav(): + request_form = request.form + audio_path = request_form.get("audio_path", None) # wav文件地址 + tran = int(float(request_form.get("tran", 0))) # 音调 + spk = request_form.get("spk", 0) # 说话人(id或者name都可以,具体看你的config) + wav_format = request_form.get("wav_format", 'wav') # 范围文件格式 + infer_tool.format_wav(audio_path) + chunks = slicer.cut(audio_path, db_thresh=-40) + audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks) + + audio = [] + for (slice_tag, data) in audio_data: + print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======') + + length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample)) + if slice_tag: + print('jump empty segment') + _audio = np.zeros(length) + else: + # padd + pad_len = int(audio_sr * 0.5) + data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])]) + raw_path = io.BytesIO() + soundfile.write(raw_path, data, audio_sr, format="wav") + raw_path.seek(0) + out_audio, out_sr = svc_model.infer(spk, tran, raw_path) + svc_model.clear_empty() + _audio = out_audio.cpu().numpy() + pad_len = int(svc_model.target_sample * 0.5) + _audio = _audio[pad_len:-pad_len] + + audio.extend(list(infer_tool.pad_array(_audio, length))) + out_wav_path = io.BytesIO() + soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format) + out_wav_path.seek(0) + return send_file(out_wav_path, download_name=f"temp.{wav_format}", as_attachment=True) + + +if __name__ == '__main__': + model_name = "logs/44k/G_60000.pth" # 模型地址 + config_name = "configs/config.json" # config地址 + svc_model = infer_tool.Svc(model_name, config_name) + app.run(port=1145, host="0.0.0.0", debug=False, threaded=False) diff --git a/inference/__init__.py b/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/infer_tool.py b/inference/infer_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..06ebcaead5aec969c1c12e49aaa80683626b1e1c --- /dev/null +++ b/inference/infer_tool.py @@ -0,0 +1,546 @@ +import gc +import hashlib +import io +import json +import logging +import os +import pickle +import time +from pathlib import Path + +import librosa +import numpy as np + +# import onnxruntime +import soundfile +import torch +import torchaudio + +import cluster +import utils +from diffusion.unit2mel import load_model_vocoder +from inference import slicer +from models import SynthesizerTrn + +logging.getLogger('matplotlib').setLevel(logging.WARNING) + + +def read_temp(file_name): + if not os.path.exists(file_name): + with open(file_name, "w") as f: + f.write(json.dumps({"info": "temp_dict"})) + return {} + else: + try: + with open(file_name, "r") as f: + data = f.read() + data_dict = json.loads(data) + if os.path.getsize(file_name) > 50 * 1024 * 1024: + f_name = file_name.replace("\\", "/").split("/")[-1] + print(f"clean {f_name}") + for wav_hash in list(data_dict.keys()): + if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600: + del data_dict[wav_hash] + except Exception as e: + print(e) + print(f"{file_name} error,auto rebuild file") + data_dict = {"info": "temp_dict"} + return data_dict + + +def write_temp(file_name, data): + with open(file_name, "w") as f: + f.write(json.dumps(data)) + + +def timeit(func): + def run(*args, **kwargs): + t = time.time() + res = func(*args, **kwargs) + print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t)) + return res + + return run + + +def format_wav(audio_path): + if Path(audio_path).suffix == '.wav': + return + raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None) + soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate) + + +def get_end_file(dir_path, end): + file_lists = [] + for root, dirs, files in os.walk(dir_path): + files = [f for f in files if f[0] != '.'] + dirs[:] = [d for d in dirs if d[0] != '.'] + for f_file in files: + if f_file.endswith(end): + file_lists.append(os.path.join(root, f_file).replace("\\", "/")) + return file_lists + + +def get_md5(content): + return hashlib.new("md5", content).hexdigest() + +def fill_a_to_b(a, b): + if len(a) < len(b): + for _ in range(0, len(b) - len(a)): + a.append(a[0]) + +def mkdir(paths: list): + for path in paths: + if not os.path.exists(path): + os.mkdir(path) + +def pad_array(arr, target_length): + current_length = arr.shape[0] + if current_length >= target_length: + return arr + else: + pad_width = target_length - current_length + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0)) + return padded_arr + +def split_list_by_n(list_collection, n, pre=0): + for i in range(0, len(list_collection), n): + yield list_collection[i-pre if i-pre>=0 else i: i + n] + + +class F0FilterException(Exception): + pass + +class Svc(object): + def __init__(self, net_g_path, config_path, + device=None, + cluster_model_path="logs/44k/kmeans_10000.pt", + nsf_hifigan_enhance = False, + diffusion_model_path="logs/44k/diffusion/model_0.pt", + diffusion_config_path="configs/diffusion.yaml", + shallow_diffusion = False, + only_diffusion = False, + spk_mix_enable = False, + feature_retrieval = False + ): + self.net_g_path = net_g_path + self.only_diffusion = only_diffusion + self.shallow_diffusion = shallow_diffusion + self.feature_retrieval = feature_retrieval + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.net_g_ms = None + if not self.only_diffusion: + self.hps_ms = utils.get_hparams_from_file(config_path,True) + self.target_sample = self.hps_ms.data.sampling_rate + self.hop_size = self.hps_ms.data.hop_length + self.spk2id = self.hps_ms.spk + self.unit_interpolate_mode = self.hps_ms.data.unit_interpolate_mode if self.hps_ms.data.unit_interpolate_mode is not None else 'left' + self.vol_embedding = self.hps_ms.model.vol_embedding if self.hps_ms.model.vol_embedding is not None else False + self.speech_encoder = self.hps_ms.model.speech_encoder if self.hps_ms.model.speech_encoder is not None else 'vec768l12' + + self.nsf_hifigan_enhance = nsf_hifigan_enhance + if self.shallow_diffusion or self.only_diffusion: + if os.path.exists(diffusion_model_path) and os.path.exists(diffusion_model_path): + self.diffusion_model,self.vocoder,self.diffusion_args = load_model_vocoder(diffusion_model_path,self.dev,config_path=diffusion_config_path) + if self.only_diffusion: + self.target_sample = self.diffusion_args.data.sampling_rate + self.hop_size = self.diffusion_args.data.block_size + self.spk2id = self.diffusion_args.spk + self.dtype = torch.float32 + self.speech_encoder = self.diffusion_args.data.encoder + self.unit_interpolate_mode = self.diffusion_args.data.unit_interpolate_mode if self.diffusion_args.data.unit_interpolate_mode is not None else 'left' + if spk_mix_enable: + self.diffusion_model.init_spkmix(len(self.spk2id)) + else: + print("No diffusion model or config found. Shallow diffusion mode will False") + self.shallow_diffusion = self.only_diffusion = False + + # load hubert and model + if not self.only_diffusion: + self.load_model(spk_mix_enable) + self.hubert_model = utils.get_speech_encoder(self.speech_encoder,device=self.dev) + self.volume_extractor = utils.Volume_Extractor(self.hop_size) + else: + self.hubert_model = utils.get_speech_encoder(self.diffusion_args.data.encoder,device=self.dev) + self.volume_extractor = utils.Volume_Extractor(self.diffusion_args.data.block_size) + + if os.path.exists(cluster_model_path): + if self.feature_retrieval: + with open(cluster_model_path,"rb") as f: + self.cluster_model = pickle.load(f) + self.big_npy = None + self.now_spk_id = -1 + else: + self.cluster_model = cluster.get_cluster_model(cluster_model_path) + else: + self.feature_retrieval=False + + if self.shallow_diffusion : + self.nsf_hifigan_enhance = False + if self.nsf_hifigan_enhance: + from modules.enhancer import Enhancer + self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model',device=self.dev) + + def load_model(self, spk_mix_enable=False): + # get model configuration + self.net_g_ms = SynthesizerTrn( + self.hps_ms.data.filter_length // 2 + 1, + self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, + **self.hps_ms.model) + _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None) + self.dtype = list(self.net_g_ms.parameters())[0].dtype + if "half" in self.net_g_path and torch.cuda.is_available(): + _ = self.net_g_ms.half().eval().to(self.dev) + else: + _ = self.net_g_ms.eval().to(self.dev) + if spk_mix_enable: + self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev) + + def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05): + + if not hasattr(self,"f0_predictor_object") or self.f0_predictor_object is None or f0_predictor != self.f0_predictor_object.name: + self.f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold) + f0, uv = self.f0_predictor_object.compute_f0_uv(wav) + + if f0_filter and sum(f0) == 0: + raise F0FilterException("No voice detected") + f0 = torch.FloatTensor(f0).to(self.dev) + uv = torch.FloatTensor(uv).to(self.dev) + + f0 = f0 * 2 ** (tran / 12) + f0 = f0.unsqueeze(0) + uv = uv.unsqueeze(0) + + wav = torch.from_numpy(wav).to(self.dev) + if not hasattr(self,"audio16k_resample_transform"): + self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to(self.dev) + wav16k = self.audio16k_resample_transform(wav[None,:])[0] + + c = self.hubert_model.encoder(wav16k) + c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode) + + if cluster_infer_ratio !=0: + if self.feature_retrieval: + speaker_id = self.spk2id.get(speaker) + if not speaker_id and type(speaker) is int: + if len(self.spk2id.__dict__) >= speaker: + speaker_id = speaker + if speaker_id is None: + raise RuntimeError("The name you entered is not in the speaker list!") + feature_index = self.cluster_model[speaker_id] + feat_np = np.ascontiguousarray(c.transpose(0,1).cpu().numpy()) + if self.big_npy is None or self.now_spk_id != speaker_id: + self.big_npy = feature_index.reconstruct_n(0, feature_index.ntotal) + self.now_spk_id = speaker_id + print("starting feature retrieval...") + score, ix = feature_index.search(feat_np, k=8) + weight = np.square(1 / score) + weight /= weight.sum(axis=1, keepdims=True) + npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + c = cluster_infer_ratio * npy + (1 - cluster_infer_ratio) * feat_np + c = torch.FloatTensor(c).to(self.dev).transpose(0,1) + print("end feature retrieval...") + else: + cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, speaker).T + cluster_c = torch.FloatTensor(cluster_c).to(self.dev) + c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c + + c = c.unsqueeze(0) + return c, f0, uv + + def infer(self, speaker, tran, raw_path, + cluster_infer_ratio=0, + auto_predict_f0=False, + noice_scale=0.4, + f0_filter=False, + f0_predictor='pm', + enhancer_adaptive_key = 0, + cr_threshold = 0.05, + k_step = 100, + frame = 0, + spk_mix = False, + second_encoding = False, + loudness_envelope_adjustment = 1 + ): + torchaudio.set_audio_backend("soundfile") + wav, sr = torchaudio.load(raw_path) + if not hasattr(self,"audio_resample_transform") or self.audio16k_resample_transform.orig_freq != sr: + self.audio_resample_transform = torchaudio.transforms.Resample(sr,self.target_sample) + wav = self.audio_resample_transform(wav).numpy()[0] + if spk_mix: + c, f0, uv = self.get_unit_f0(wav, tran, 0, None, f0_filter,f0_predictor,cr_threshold=cr_threshold) + n_frames = f0.size(1) + sid = speaker[:, frame:frame+n_frames].transpose(0,1) + else: + speaker_id = self.spk2id.get(speaker) + if not speaker_id and type(speaker) is int: + if len(self.spk2id.__dict__) >= speaker: + speaker_id = speaker + if speaker_id is None: + raise RuntimeError("The name you entered is not in the speaker list!") + sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0) + c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter,f0_predictor,cr_threshold=cr_threshold) + n_frames = f0.size(1) + c = c.to(self.dtype) + f0 = f0.to(self.dtype) + uv = uv.to(self.dtype) + with torch.no_grad(): + start = time.time() + vol = None + if not self.only_diffusion: + vol = self.volume_extractor.extract(torch.FloatTensor(wav).to(self.dev)[None,:])[None,:].to(self.dev) if self.vol_embedding else None + audio,f0 = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale,vol=vol) + audio = audio[0,0].data.float() + audio_mel = self.vocoder.extract(audio[None,:],self.target_sample) if self.shallow_diffusion else None + else: + audio = torch.FloatTensor(wav).to(self.dev) + audio_mel = None + if self.dtype != torch.float32: + c = c.to(torch.float32) + f0 = f0.to(torch.float32) + uv = uv.to(torch.float32) + if self.only_diffusion or self.shallow_diffusion: + vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None] + if self.shallow_diffusion and second_encoding: + if not hasattr(self,"audio16k_resample_transform"): + self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to(self.dev) + audio16k = self.audio16k_resample_transform(audio[None,:])[0] + c = self.hubert_model.encoder(audio16k) + c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode) + f0 = f0[:,:,None] + c = c.transpose(-1,-2) + audio_mel = self.diffusion_model( + c, + f0, + vol, + spk_id = sid, + spk_mix_dict = None, + gt_spec=audio_mel, + infer=True, + infer_speedup=self.diffusion_args.infer.speedup, + method=self.diffusion_args.infer.method, + k_step=k_step) + audio = self.vocoder.infer(audio_mel, f0).squeeze() + if self.nsf_hifigan_enhance: + audio, _ = self.enhancer.enhance( + audio[None,:], + self.target_sample, + f0[:,:,None], + self.hps_ms.data.hop_length, + adaptive_key = enhancer_adaptive_key) + if loudness_envelope_adjustment != 1: + audio = utils.change_rms(wav,self.target_sample,audio,self.target_sample,loudness_envelope_adjustment) + use_time = time.time() - start + print("vits use time:{}".format(use_time)) + return audio, audio.shape[-1], n_frames + + def clear_empty(self): + # clean up vram + torch.cuda.empty_cache() + + def unload_model(self): + # unload model + self.net_g_ms = self.net_g_ms.to("cpu") + del self.net_g_ms + if hasattr(self,"enhancer"): + self.enhancer.enhancer = self.enhancer.enhancer.to("cpu") + del self.enhancer.enhancer + del self.enhancer + gc.collect() + + def slice_inference(self, + raw_audio_path, + spk, + tran, + slice_db, + cluster_infer_ratio, + auto_predict_f0, + noice_scale, + pad_seconds=0.5, + clip_seconds=0, + lg_num=0, + lgr_num =0.75, + f0_predictor='pm', + enhancer_adaptive_key = 0, + cr_threshold = 0.05, + k_step = 100, + use_spk_mix = False, + second_encoding = False, + loudness_envelope_adjustment = 1 + ): + if use_spk_mix: + if len(self.spk2id) == 1: + spk = self.spk2id.keys()[0] + use_spk_mix = False + wav_path = Path(raw_audio_path).with_suffix('.wav') + chunks = slicer.cut(wav_path, db_thresh=slice_db) + audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks) + per_size = int(clip_seconds*audio_sr) + lg_size = int(lg_num*audio_sr) + lg_size_r = int(lg_size*lgr_num) + lg_size_c_l = (lg_size-lg_size_r)//2 + lg_size_c_r = lg_size-lg_size_r-lg_size_c_l + lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0 + + if use_spk_mix: + assert len(self.spk2id) == len(spk) + audio_length = 0 + for (slice_tag, data) in audio_data: + aud_length = int(np.ceil(len(data) / audio_sr * self.target_sample)) + if slice_tag: + audio_length += aud_length // self.hop_size + continue + if per_size != 0: + datas = split_list_by_n(data, per_size,lg_size) + else: + datas = [data] + for k,dat in enumerate(datas): + pad_len = int(audio_sr * pad_seconds) + per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) + a_length = per_length + 2 * pad_len + audio_length += a_length // self.hop_size + audio_length += len(audio_data) + spk_mix_tensor = torch.zeros(size=(len(spk), audio_length)).to(self.dev) + for i in range(len(spk)): + last_end = None + for mix in spk[i]: + if mix[3]<0. or mix[2]<0.: + raise RuntimeError("mix value must higer Than zero!") + begin = int(audio_length * mix[0]) + end = int(audio_length * mix[1]) + length = end - begin + if length<=0: + raise RuntimeError("begin Must lower Than end!") + step = (mix[3] - mix[2])/length + if last_end is not None: + if last_end != begin: + raise RuntimeError("[i]EndTime Must Equal [i+1]BeginTime!") + last_end = end + if step == 0.: + spk_mix_data = torch.zeros(length).to(self.dev) + mix[2] + else: + spk_mix_data = torch.arange(mix[2],mix[3],step).to(self.dev) + if(len(spk_mix_data)0 or p_len - len(f0) - pad_size>0): + f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') + + f0 *= pow(2, f0_up_key / 12) + f0_mel = 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1 + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > 255] = 255 + f0_coarse = np.rint(f0_mel).astype(np.int) + return f0_coarse, f0 + +def clean_pitch(input_pitch): + num_nan = np.sum(input_pitch == 1) + if num_nan / len(input_pitch) > 0.9: + input_pitch[input_pitch != 1] = 1 + return input_pitch + + +def plt_pitch(input_pitch): + input_pitch = input_pitch.astype(float) + input_pitch[input_pitch == 1] = np.nan + return input_pitch + + +def f0_to_pitch(ff): + f0_pitch = 69 + 12 * np.log2(ff / 440) + return f0_pitch + + +def fill_a_to_b(a, b): + if len(a) < len(b): + for _ in range(0, len(b) - len(a)): + a.append(a[0]) + + +def mkdir(paths: list): + for path in paths: + if not os.path.exists(path): + os.mkdir(path) + + +class VitsSvc(object): + def __init__(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.SVCVITS = None + self.hps = None + self.speakers = None + self.hubert_soft = utils.get_hubert_model() + + def set_device(self, device): + self.device = torch.device(device) + self.hubert_soft.to(self.device) + if self.SVCVITS is not None: + self.SVCVITS.to(self.device) + + def loadCheckpoint(self, path): + self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") + self.SVCVITS = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + **self.hps.model) + _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", self.SVCVITS, None) + _ = self.SVCVITS.eval().to(self.device) + self.speakers = self.hps.spk + + def get_units(self, source, sr): + source = source.unsqueeze(0).to(self.device) + with torch.inference_mode(): + units = self.hubert_soft.units(source) + return units + + + def get_unit_pitch(self, in_path, tran): + source, sr = torchaudio.load(in_path) + source = torchaudio.functional.resample(source, sr, 16000) + if len(source.shape) == 2 and source.shape[1] >= 2: + source = torch.mean(source, dim=0).unsqueeze(0) + soft = self.get_units(source, sr).squeeze(0).cpu().numpy() + f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0]*2, tran) + return soft, f0 + + def infer(self, speaker_id, tran, raw_path): + speaker_id = self.speakers[speaker_id] + sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0) + soft, pitch = self.get_unit_pitch(raw_path, tran) + f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.device) + stn_tst = torch.FloatTensor(soft) + with torch.no_grad(): + x_tst = stn_tst.unsqueeze(0).to(self.device) + x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2) + audio,_ = self.SVCVITS.infer(x_tst, f0=f0, g=sid)[0,0].data.float() + return audio, audio.shape[-1] + + def inference(self,srcaudio,chara,tran,slice_db): + sampling_rate, audio = srcaudio + audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio.transpose(1, 0)) + if sampling_rate != 16000: + audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000) + soundfile.write("tmpwav.wav", audio, 16000, format="wav") + chunks = slicer.cut("tmpwav.wav", db_thresh=slice_db) + audio_data, audio_sr = slicer.chunks2audio("tmpwav.wav", chunks) + audio = [] + for (slice_tag, data) in audio_data: + length = int(np.ceil(len(data) / audio_sr * self.hps.data.sampling_rate)) + raw_path = io.BytesIO() + soundfile.write(raw_path, data, audio_sr, format="wav") + raw_path.seek(0) + if slice_tag: + _audio = np.zeros(length) + else: + out_audio, out_sr = self.infer(chara, tran, raw_path) + _audio = out_audio.cpu().numpy() + audio.extend(list(_audio)) + audio = (np.array(audio) * 32768.0).astype('int16') + return (self.hps.data.sampling_rate,audio) diff --git a/inference/slicer.py b/inference/slicer.py new file mode 100644 index 0000000000000000000000000000000000000000..b05840bcf6bdced0b6e2adbecb1a1dd5b3dee462 --- /dev/null +++ b/inference/slicer.py @@ -0,0 +1,142 @@ +import librosa +import torch +import torchaudio + + +class Slicer: + def __init__(self, + sr: int, + threshold: float = -40., + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 20, + max_sil_kept: int = 5000): + if not min_length >= min_interval >= hop_size: + raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') + if not max_sil_kept >= hop_size: + raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] + else: + return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] + + # @timeit + def slice(self, waveform): + if len(waveform.shape) > 1: + samples = librosa.to_mono(waveform) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} + rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. + if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start: i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() + pos += i - self.max_sil_kept + pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start + pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start + pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if silence_start is not None and total_frames - silence_start >= self.min_interval: + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} + else: + chunks = [] + # 第一段静音并非从头开始,补上有声片段 + if sil_tags[0][0]: + chunks.append( + {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"}) + for i in range(0, len(sil_tags)): + # 标识有声片段(跳过第一段) + if i: + chunks.append({"slice": False, + "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"}) + # 标识所有静音片段 + chunks.append({"slice": True, + "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"}) + # 最后一段静音并非结尾,补上结尾片段 + if sil_tags[-1][1] * self.hop_size < len(waveform): + chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"}) + chunk_dict = {} + for i in range(len(chunks)): + chunk_dict[str(i)] = chunks[i] + return chunk_dict + + +def cut(audio_path, db_thresh=-30, min_len=5000): + audio, sr = librosa.load(audio_path, sr=None) + slicer = Slicer( + sr=sr, + threshold=db_thresh, + min_length=min_len + ) + chunks = slicer.slice(audio) + return chunks + + +def chunks2audio(audio_path, chunks): + chunks = dict(chunks) + audio, sr = torchaudio.load(audio_path) + if len(audio.shape) == 2 and audio.shape[1] >= 2: + audio = torch.mean(audio, dim=0).unsqueeze(0) + audio = audio.cpu().numpy()[0] + result = [] + for k, v in chunks.items(): + tag = v["split_time"].split(",") + if tag[0] != tag[1]: + result.append((v["slice"], audio[int(tag[0]):int(tag[1])])) + return result, sr diff --git a/inference_main.py b/inference_main.py new file mode 100644 index 0000000000000000000000000000000000000000..a99f6ec4916848ee96f575ceb7a1d71bbe2ca464 --- /dev/null +++ b/inference_main.py @@ -0,0 +1,155 @@ +import logging + +import soundfile + +from inference import infer_tool +from inference.infer_tool import Svc +from spkmix import spk_mix_map + +logging.getLogger('numba').setLevel(logging.WARNING) +chunks_dict = infer_tool.read_temp("inference/chunks_temp.json") + + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='sovits4 inference') + + # 一定要设置的部分 + parser.add_argument('-m', '--model_path', type=str, default="logs/44k/G_37600.pth", help='模型路径') + parser.add_argument('-c', '--config_path', type=str, default="logs/44k/config.json", help='配置文件路径') + parser.add_argument('-cl', '--clip', type=float, default=0, help='音频强制切片,默认0为自动切片,单位为秒/s') + parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["君の知らない物語-src.wav"], help='wav文件名列表,放在raw文件夹下') + parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0], help='音高调整,支持正负(半音)') + parser.add_argument('-s', '--spk_list', type=str, nargs='+', default=['buyizi'], help='合成目标说话人名称') + + # 可选项部分 + parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=False, help='语音转换自动预测音高,转换歌声时不要打开这个会严重跑调') + parser.add_argument('-cm', '--cluster_model_path', type=str, default="", help='聚类模型或特征检索索引路径,留空则自动设为各方案模型的默认路径,如果没有训练聚类或特征检索则随便填') + parser.add_argument('-cr', '--cluster_infer_ratio', type=float, default=0, help='聚类方案或特征检索占比,范围0-1,若没有训练聚类模型或特征检索则默认0即可') + parser.add_argument('-lg', '--linear_gradient', type=float, default=0, help='两段音频切片的交叉淡入长度,如果强制切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值0,单位为秒') + parser.add_argument('-f0p', '--f0_predictor', type=str, default="pm", help='选择F0预测器,可选择crepe,pm,dio,harvest,rmvpe,fcpe默认为pm(注意:crepe为原F0使用均值滤波器)') + parser.add_argument('-eh', '--enhance', action='store_true', default=False, help='是否使用NSF_HIFIGAN增强器,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭') + parser.add_argument('-shd', '--shallow_diffusion', action='store_true', default=False, help='是否使用浅层扩散,使用后可解决一部分电音问题,默认关闭,该选项打开时,NSF_HIFIGAN增强器将会被禁止') + parser.add_argument('-usm', '--use_spk_mix', action='store_true', default=False, help='是否使用角色融合') + parser.add_argument('-lea', '--loudness_envelope_adjustment', type=float, default=1, help='输入源响度包络替换输出响度包络融合比例,越靠近1越使用输出响度包络') + parser.add_argument('-fr', '--feature_retrieval', action='store_true', default=False, help='是否使用特征检索,如果使用聚类模型将被禁用,且cm与cr参数将会变成特征检索的索引路径与混合比例') + + # 浅扩散设置 + parser.add_argument('-dm', '--diffusion_model_path', type=str, default="logs/44k/diffusion/model_0.pt", help='扩散模型路径') + parser.add_argument('-dc', '--diffusion_config_path', type=str, default="logs/44k/diffusion/config.yaml", help='扩散模型配置文件路径') + parser.add_argument('-ks', '--k_step', type=int, default=100, help='扩散步数,越大越接近扩散模型的结果,默认100') + parser.add_argument('-se', '--second_encoding', action='store_true', default=False, help='二次编码,浅扩散前会对原始音频进行二次编码,玄学选项,有时候效果好,有时候效果差') + parser.add_argument('-od', '--only_diffusion', action='store_true', default=False, help='纯扩散模式,该模式不会加载sovits模型,以扩散模型推理') + + + # 不用动的部分 + parser.add_argument('-sd', '--slice_db', type=int, default=-40, help='默认-40,嘈杂的音频可以-30,干声保留呼吸可以-50') + parser.add_argument('-d', '--device', type=str, default=None, help='推理设备,None则为自动选择cpu和gpu') + parser.add_argument('-ns', '--noice_scale', type=float, default=0.4, help='噪音级别,会影响咬字和音质,较为玄学') + parser.add_argument('-p', '--pad_seconds', type=float, default=0.5, help='推理音频pad秒数,由于未知原因开头结尾会有异响,pad一小段静音段后就不会出现') + parser.add_argument('-wf', '--wav_format', type=str, default='flac', help='音频输出格式') + parser.add_argument('-lgr', '--linear_gradient_retain', type=float, default=0.75, help='自动音频切片后,需要舍弃每段切片的头尾。该参数设置交叉长度保留的比例,范围0-1,左开右闭') + parser.add_argument('-eak', '--enhancer_adaptive_key', type=int, default=0, help='使增强器适应更高的音域(单位为半音数)|默认为0') + parser.add_argument('-ft', '--f0_filter_threshold', type=float, default=0.05,help='F0过滤阈值,只有使用crepe时有效. 数值范围从0-1. 降低该值可减少跑调概率,但会增加哑音') + + + args = parser.parse_args() + + clean_names = args.clean_names + trans = args.trans + spk_list = args.spk_list + slice_db = args.slice_db + wav_format = args.wav_format + auto_predict_f0 = args.auto_predict_f0 + cluster_infer_ratio = args.cluster_infer_ratio + noice_scale = args.noice_scale + pad_seconds = args.pad_seconds + clip = args.clip + lg = args.linear_gradient + lgr = args.linear_gradient_retain + f0p = args.f0_predictor + enhance = args.enhance + enhancer_adaptive_key = args.enhancer_adaptive_key + cr_threshold = args.f0_filter_threshold + diffusion_model_path = args.diffusion_model_path + diffusion_config_path = args.diffusion_config_path + k_step = args.k_step + only_diffusion = args.only_diffusion + shallow_diffusion = args.shallow_diffusion + use_spk_mix = args.use_spk_mix + second_encoding = args.second_encoding + loudness_envelope_adjustment = args.loudness_envelope_adjustment + + if cluster_infer_ratio != 0: + if args.cluster_model_path == "": + if args.feature_retrieval: # 若指定了占比但没有指定模型路径,则按是否使用特征检索分配默认的模型路径 + args.cluster_model_path = "logs/44k/feature_and_index.pkl" + else: + args.cluster_model_path = "logs/44k/kmeans_10000.pt" + else: # 若未指定占比,则无论是否指定模型路径,都将其置空以避免之后的模型加载 + args.cluster_model_path = "" + + svc_model = Svc(args.model_path, + args.config_path, + args.device, + args.cluster_model_path, + enhance, + diffusion_model_path, + diffusion_config_path, + shallow_diffusion, + only_diffusion, + use_spk_mix, + args.feature_retrieval) + + infer_tool.mkdir(["raw", "results"]) + + if len(spk_mix_map)<=1: + use_spk_mix = False + if use_spk_mix: + spk_list = [spk_mix_map] + + infer_tool.fill_a_to_b(trans, clean_names) + for clean_name, tran in zip(clean_names, trans): + raw_audio_path = f"raw/{clean_name}" + if "." not in raw_audio_path: + raw_audio_path += ".wav" + infer_tool.format_wav(raw_audio_path) + for spk in spk_list: + kwarg = { + "raw_audio_path" : raw_audio_path, + "spk" : spk, + "tran" : tran, + "slice_db" : slice_db, + "cluster_infer_ratio" : cluster_infer_ratio, + "auto_predict_f0" : auto_predict_f0, + "noice_scale" : noice_scale, + "pad_seconds" : pad_seconds, + "clip_seconds" : clip, + "lg_num": lg, + "lgr_num" : lgr, + "f0_predictor" : f0p, + "enhancer_adaptive_key" : enhancer_adaptive_key, + "cr_threshold" : cr_threshold, + "k_step":k_step, + "use_spk_mix":use_spk_mix, + "second_encoding":second_encoding, + "loudness_envelope_adjustment":loudness_envelope_adjustment + } + audio = svc_model.slice_inference(**kwarg) + key = "auto" if auto_predict_f0 else f"{tran}key" + cluster_name = "" if cluster_infer_ratio == 0 else f"_{cluster_infer_ratio}" + isdiffusion = "sovits" + if shallow_diffusion : + isdiffusion = "sovdiff" + if only_diffusion : + isdiffusion = "diff" + if use_spk_mix: + spk = "spk_mix" + res_path = f'results/{clean_name}_{key}_{spk}{cluster_name}_{isdiffusion}_{f0p}.{wav_format}' + soundfile.write(res_path, audio, svc_model.target_sample, format=wav_format) + svc_model.clear_empty() + +if __name__ == '__main__': + main() diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..24338fa2c1f6c15e60f5f341c7e3df2301f74eb8 --- /dev/null +++ b/models.py @@ -0,0 +1,533 @@ +import torch +from torch import nn +from torch.nn import Conv1d, Conv2d +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +import modules.attentions as attentions +import modules.commons as commons +import modules.modules as modules +import utils +from modules.commons import get_padding +from utils import f0_to_coarse + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + share_parameter=False + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=gin_channels) if share_parameter else None + + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, + gin_channels=gin_channels, mean_only=True, wn_sharing_parameter=self.wn)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + +class TransformerCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + share_parameter=False + ): + + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = self.gin_channels) if share_parameter else None + + for i in range(n_flows): + self.flows.append( + modules.TransformerCouplingLayer(channels, hidden_channels, kernel_size, n_layers, n_heads, p_dropout, filter_channels, mean_only=True, wn_sharing_parameter=self.wn, gin_channels = self.gin_channels)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + # print(x.shape,x_lengths.shape) + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + kernel_size, + n_layers, + gin_channels=0, + filter_channels=None, + n_heads=None, + p_dropout=None): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.gin_channels = gin_channels + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + self.f0_emb = nn.Embedding(256, hidden_channels) + + self.enc_ = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + + def forward(self, x, x_mask, f0=None, noice_scale=1): + x = x + self.f0_emb(f0).transpose(1, 2) + x = self.enc_(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask + + return z, m, logs, x_mask + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SpeakerEncoder(torch.nn.Module): + def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): + super(SpeakerEncoder, self).__init__() + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): + mel_slices = [] + for i in range(0, total_frames - partial_frames, partial_hop): + mel_range = torch.arange(i, i + partial_frames) + mel_slices.append(mel_range) + + return mel_slices + + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): + mel_len = mel.size(1) + last_mel = mel[:, -partial_frames:] + + if mel_len > partial_frames: + mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) + mels = list(mel[:, s] for s in mel_slices) + mels.append(last_mel) + mels = torch.stack(tuple(mels), 0).squeeze(1) + + with torch.no_grad(): + partial_embeds = self(mels) + embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) + # embed = embed / torch.linalg.norm(embed, 2) + else: + with torch.no_grad(): + embed = self(last_mel) + + return embed + +class F0Decoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=0): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.spk_channels = spk_channels + + self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1) + self.decoder = attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1) + self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) + + def forward(self, x, norm_f0, x_mask, spk_emb=None): + x = torch.detach(x) + if (spk_emb is not None): + x = x + self.cond(spk_emb) + x += self.f0_prenet(norm_f0) + x = self.prenet(x) * x_mask + x = self.decoder(x * x_mask, x_mask) + x = self.proj(x) * x_mask + return x + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + n_speakers, + sampling_rate=44100, + vol_embedding=False, + vocoder_name = "nsf-hifigan", + use_depthwise_conv = False, + use_automatic_f0_prediction = True, + flow_share_parameter = False, + n_flow_layer = 4, + n_layers_trans_flow = 3, + use_transformer_flow = False, + **kwargs): + + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + self.ssl_dim = ssl_dim + self.vol_embedding = vol_embedding + self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.use_depthwise_conv = use_depthwise_conv + self.use_automatic_f0_prediction = use_automatic_f0_prediction + self.n_layers_trans_flow = n_layers_trans_flow + if vol_embedding: + self.emb_vol = nn.Linear(1, hidden_channels) + + self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) + + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels=filter_channels, + n_heads=n_heads, + n_layers=n_layers, + kernel_size=kernel_size, + p_dropout=p_dropout + ) + hps = { + "sampling_rate": sampling_rate, + "inter_channels": inter_channels, + "resblock": resblock, + "resblock_kernel_sizes": resblock_kernel_sizes, + "resblock_dilation_sizes": resblock_dilation_sizes, + "upsample_rates": upsample_rates, + "upsample_initial_channel": upsample_initial_channel, + "upsample_kernel_sizes": upsample_kernel_sizes, + "gin_channels": gin_channels, + "use_depthwise_conv":use_depthwise_conv + } + + modules.set_Conv1dModel(self.use_depthwise_conv) + + if vocoder_name == "nsf-hifigan": + from vdecoder.hifigan.models import Generator + self.dec = Generator(h=hps) + elif vocoder_name == "nsf-snake-hifigan": + from vdecoder.hifiganwithsnake.models import Generator + self.dec = Generator(h=hps) + else: + print("[?] Unkown vocoder: use default(nsf-hifigan)") + from vdecoder.hifigan.models import Generator + self.dec = Generator(h=hps) + + self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + if use_transformer_flow: + self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) + else: + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter) + if self.use_automatic_f0_prediction: + self.f0_decoder = F0Decoder( + 1, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=gin_channels + ) + self.emb_uv = nn.Embedding(2, hidden_channels) + self.character_mix = False + + def EnableCharacterMix(self, n_speakers_map, device): + self.speaker_map = torch.zeros((n_speakers_map, 1, 1, self.gin_channels)).to(device) + for i in range(n_speakers_map): + self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]).to(device)) + self.speaker_map = self.speaker_map.unsqueeze(0).to(device) + self.character_mix = True + + def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None, vol = None): + g = self.emb_g(g).transpose(1,2) + + # vol proj + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 + + # ssl prenet + x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) + x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol + + # f0 predict + if self.use_automatic_f0_prediction: + lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 + norm_lf0 = utils.normalize_f0(lf0, x_mask, uv) + pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) + else: + lf0 = 0 + norm_lf0 = 0 + pred_lf0 = 0 + # encoder + z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0)) + z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) + + # flow + z_p = self.flow(z, spec_mask, g=g) + z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(z, f0, spec_lengths, self.segment_size) + + # nsf decoder + o = self.dec(z_slice, g=g, f0=pitch_slice) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 + + @torch.no_grad() + def infer(self, c, f0, uv, g=None, noice_scale=0.35, seed=52468, predict_f0=False, vol = None): + + if c.device == torch.device("cuda"): + torch.cuda.manual_seed_all(seed) + else: + torch.manual_seed(seed) + + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + + if self.character_mix and len(g) > 1: # [N, S] * [S, B, 1, H] + g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + else: + if g.dim() == 1: + g = g.unsqueeze(0) + g = self.emb_g(g).transpose(1, 2) + + x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) + # vol proj + + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 + + x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol + + + if self.use_automatic_f0_prediction and predict_f0: + lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 + norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) + pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) + f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) + + z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g, f0=f0) + return o,f0 + diff --git a/modules/DSConv.py b/modules/DSConv.py new file mode 100644 index 0000000000000000000000000000000000000000..44c2bf60e9cd2b837ca95fb6436768782057014a --- /dev/null +++ b/modules/DSConv.py @@ -0,0 +1,76 @@ +import torch.nn as nn +from torch.nn.utils import remove_weight_norm, weight_norm + + +class Depthwise_Separable_Conv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + dilation = 1, + bias = True, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight') + self.point_conv = remove_weight_norm(self.point_conv, name = 'weight') + +class Depthwise_Separable_TransposeConv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + output_padding = 0, + bias = True, + dilation = 1, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + remove_weight_norm(self.depth_conv, name = 'weight') + remove_weight_norm(self.point_conv, name = 'weight') + + +def weight_norm_modules(module, name = 'weight', dim = 0): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.weight_norm() + return module + else: + return weight_norm(module,name,dim) + +def remove_weight_norm_modules(module, name = 'weight'): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.remove_weight_norm() + else: + remove_weight_norm(module,name) \ No newline at end of file diff --git a/modules/F0Predictor/CrepeF0Predictor.py b/modules/F0Predictor/CrepeF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..c0854b64ed3bff96ed3381a7ef666c784aefd995 --- /dev/null +++ b/modules/F0Predictor/CrepeF0Predictor.py @@ -0,0 +1,34 @@ +import torch + +from modules.F0Predictor.crepe import CrepePitchExtractor +from modules.F0Predictor.F0Predictor import F0Predictor + + +class CrepeF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"): + self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + self.device = device + self.threshold = threshold + self.sampling_rate = sampling_rate + self.name = "crepe" + + def compute_f0(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len) + return f0 + + def compute_f0_uv(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len) + return f0,uv \ No newline at end of file diff --git a/modules/F0Predictor/DioF0Predictor.py b/modules/F0Predictor/DioF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..178dd2e8a02b79e5af113300f00d6a4dc2fb2a07 --- /dev/null +++ b/modules/F0Predictor/DioF0Predictor.py @@ -0,0 +1,74 @@ +import numpy as np +import pyworld + +from modules.F0Predictor.F0Predictor import F0Predictor + + +class DioF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + self.sampling_rate = sampling_rate + self.name = "dio" + + def interpolate_f0(self,f0): + ''' + 对F0进行插值处理 + ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector + + def resize_f0(self,x, target_len): + source = np.array(x) + source[source<0.001] = np.nan + target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) + res = np.nan_to_num(target) + return res + + def compute_f0(self,wav,p_len=None): + if p_len is None: + p_len = wav.shape[0]//self.hop_length + f0, t = pyworld.dio( + wav.astype(np.double), + fs=self.sampling_rate, + f0_floor=self.f0_min, + f0_ceil=self.f0_max, + frame_period=1000 * self.hop_length / self.sampling_rate, + ) + f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) + for index, pitch in enumerate(f0): + f0[index] = round(pitch, 1) + return self.interpolate_f0(self.resize_f0(f0, p_len))[0] + + def compute_f0_uv(self,wav,p_len=None): + if p_len is None: + p_len = wav.shape[0]//self.hop_length + f0, t = pyworld.dio( + wav.astype(np.double), + fs=self.sampling_rate, + f0_floor=self.f0_min, + f0_ceil=self.f0_max, + frame_period=1000 * self.hop_length / self.sampling_rate, + ) + f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) + for index, pitch in enumerate(f0): + f0[index] = round(pitch, 1) + return self.interpolate_f0(self.resize_f0(f0, p_len)) diff --git a/modules/F0Predictor/F0Predictor.py b/modules/F0Predictor/F0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..69d8a9bd28729e33d092a5af8e2ce544c1330c3b --- /dev/null +++ b/modules/F0Predictor/F0Predictor.py @@ -0,0 +1,16 @@ +class F0Predictor(object): + def compute_f0(self,wav,p_len): + ''' + input: wav:[signal_length] + p_len:int + output: f0:[signal_length//hop_length] + ''' + pass + + def compute_f0_uv(self,wav,p_len): + ''' + input: wav:[signal_length] + p_len:int + output: f0:[signal_length//hop_length],uv:[signal_length//hop_length] + ''' + pass \ No newline at end of file diff --git a/modules/F0Predictor/FCPEF0Predictor.py b/modules/F0Predictor/FCPEF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..91913c75a7b33d77a154c48cb9482ddd43393a6a --- /dev/null +++ b/modules/F0Predictor/FCPEF0Predictor.py @@ -0,0 +1,109 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn.functional as F + +from modules.F0Predictor.F0Predictor import F0Predictor + +from .fcpe.model import FCPEInfer + + +class FCPEF0Predictor(F0Predictor): + def __init__(self, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sampling_rate=44100, + threshold=0.05): + self.fcpe = FCPEInfer(model_path="pretrain/fcpe.pt", device=device, dtype=dtype) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.threshold = threshold + self.sampling_rate = sampling_rate + self.dtype = dtype + self.name = "fcpe" + + def repeat_expand( + self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" + ): + ndim = content.ndim + + if content.ndim == 1: + content = content[None, None] + elif content.ndim == 2: + content = content[None] + + assert content.ndim == 3 + + is_np = isinstance(content, np.ndarray) + if is_np: + content = torch.from_numpy(content) + + results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) + + if is_np: + results = results.numpy() + + if ndim == 1: + return results[0, 0] + elif ndim == 2: + return results[0] + + def post_process(self, x, sampling_rate, f0, pad_to): + if isinstance(f0, np.ndarray): + f0 = torch.from_numpy(f0).float().to(x.device) + + if pad_to is None: + return f0 + + f0 = self.repeat_expand(f0, pad_to) + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + # 去掉0频率, 并线性插值 + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(), vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[ + 0]).cpu().numpy(), vuv_vector.cpu().numpy() + + # 大概可以用 torch 重写? + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + # vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + + return f0, vuv_vector.cpu().numpy() + + def compute_f0(self, wav, p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0] // self.hop_length + else: + assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" + f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn, rtn + return self.post_process(x, self.sampling_rate, f0, p_len)[0] + + def compute_f0_uv(self, wav, p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0] // self.hop_length + else: + assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" + f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn, rtn + return self.post_process(x, self.sampling_rate, f0, p_len) \ No newline at end of file diff --git a/modules/F0Predictor/HarvestF0Predictor.py b/modules/F0Predictor/HarvestF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..f36b332f7b42802918ce3e232a6609413394acf9 --- /dev/null +++ b/modules/F0Predictor/HarvestF0Predictor.py @@ -0,0 +1,69 @@ +import numpy as np +import pyworld + +from modules.F0Predictor.F0Predictor import F0Predictor + + +class HarvestF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + self.sampling_rate = sampling_rate + self.name = "harvest" + + def interpolate_f0(self,f0): + ''' + 对F0进行插值处理 + ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector + def resize_f0(self,x, target_len): + source = np.array(x) + source[source<0.001] = np.nan + target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) + res = np.nan_to_num(target) + return res + + def compute_f0(self,wav,p_len=None): + if p_len is None: + p_len = wav.shape[0]//self.hop_length + f0, t = pyworld.harvest( + wav.astype(np.double), + fs=self.hop_length, + f0_ceil=self.f0_max, + f0_floor=self.f0_min, + frame_period=1000 * self.hop_length / self.sampling_rate, + ) + f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.fs) + return self.interpolate_f0(self.resize_f0(f0, p_len))[0] + + def compute_f0_uv(self,wav,p_len=None): + if p_len is None: + p_len = wav.shape[0]//self.hop_length + f0, t = pyworld.harvest( + wav.astype(np.double), + fs=self.sampling_rate, + f0_floor=self.f0_min, + f0_ceil=self.f0_max, + frame_period=1000 * self.hop_length / self.sampling_rate, + ) + f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) + return self.interpolate_f0(self.resize_f0(f0, p_len)) diff --git a/modules/F0Predictor/PMF0Predictor.py b/modules/F0Predictor/PMF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..2af3f6e7ee7c5c4e10899f9988e1d9b92aa52157 --- /dev/null +++ b/modules/F0Predictor/PMF0Predictor.py @@ -0,0 +1,72 @@ +import numpy as np +import parselmouth + +from modules.F0Predictor.F0Predictor import F0Predictor + + +class PMF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + self.sampling_rate = sampling_rate + self.name = "pm" + + def interpolate_f0(self,f0): + ''' + 对F0进行插值处理 + ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector + + + def compute_f0(self,wav,p_len=None): + x = wav + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + time_step = self.hop_length / self.sampling_rate * 1000 + f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac( + time_step=time_step / 1000, voicing_threshold=0.6, + pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency'] + + pad_size=(p_len - len(f0) + 1) // 2 + if(pad_size>0 or p_len - len(f0) - pad_size>0): + f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') + f0,uv = self.interpolate_f0(f0) + return f0 + + def compute_f0_uv(self,wav,p_len=None): + x = wav + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + time_step = self.hop_length / self.sampling_rate * 1000 + f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac( + time_step=time_step / 1000, voicing_threshold=0.6, + pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency'] + + pad_size=(p_len - len(f0) + 1) // 2 + if(pad_size>0 or p_len - len(f0) - pad_size>0): + f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') + f0,uv = self.interpolate_f0(f0) + return f0,uv diff --git a/modules/F0Predictor/RMVPEF0Predictor.py b/modules/F0Predictor/RMVPEF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..9313887be084e99059e6c76fffba323de1f3c835 --- /dev/null +++ b/modules/F0Predictor/RMVPEF0Predictor.py @@ -0,0 +1,107 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn.functional as F + +from modules.F0Predictor.F0Predictor import F0Predictor + +from .rmvpe import RMVPE + + +class RMVPEF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05): + self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.threshold = threshold + self.sampling_rate = sampling_rate + self.dtype = dtype + self.name = "rmvpe" + + def repeat_expand( + self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" + ): + ndim = content.ndim + + if content.ndim == 1: + content = content[None, None] + elif content.ndim == 2: + content = content[None] + + assert content.ndim == 3 + + is_np = isinstance(content, np.ndarray) + if is_np: + content = torch.from_numpy(content) + + results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) + + if is_np: + results = results.numpy() + + if ndim == 1: + return results[0, 0] + elif ndim == 2: + return results[0] + + def post_process(self, x, sampling_rate, f0, pad_to): + if isinstance(f0, np.ndarray): + f0 = torch.from_numpy(f0).float().to(x.device) + + if pad_to is None: + return f0 + + f0 = self.repeat_expand(f0, pad_to) + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + # 去掉0频率, 并线性插值 + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(),vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0]).cpu().numpy() ,vuv_vector.cpu().numpy() + + # 大概可以用 torch 重写? + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + + return f0,vuv_vector.cpu().numpy() + + def compute_f0(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn,rtn + return self.post_process(x,self.sampling_rate,f0,p_len)[0] + + def compute_f0_uv(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn,rtn + return self.post_process(x,self.sampling_rate,f0,p_len) \ No newline at end of file diff --git a/modules/F0Predictor/__init__.py b/modules/F0Predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/F0Predictor/crepe.py b/modules/F0Predictor/crepe.py new file mode 100644 index 0000000000000000000000000000000000000000..e68f19cb39eb79931926ffd312fb61e30bf39d72 --- /dev/null +++ b/modules/F0Predictor/crepe.py @@ -0,0 +1,340 @@ +from typing import Optional, Union + +try: + from typing import Literal +except Exception: + from typing_extensions import Literal +import numpy as np +import torch +import torchcrepe +from torch import nn +from torch.nn import functional as F + +#from:https://github.com/fishaudio/fish-diffusion + +def repeat_expand( + content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" +): + """Repeat content to target length. + This is a wrapper of torch.nn.functional.interpolate. + + Args: + content (torch.Tensor): tensor + target_len (int): target length + mode (str, optional): interpolation mode. Defaults to "nearest". + + Returns: + torch.Tensor: tensor + """ + + ndim = content.ndim + + if content.ndim == 1: + content = content[None, None] + elif content.ndim == 2: + content = content[None] + + assert content.ndim == 3 + + is_np = isinstance(content, np.ndarray) + if is_np: + content = torch.from_numpy(content) + + results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) + + if is_np: + results = results.numpy() + + if ndim == 1: + return results[0, 0] + elif ndim == 2: + return results[0] + + +class BasePitchExtractor: + def __init__( + self, + hop_length: int = 512, + f0_min: float = 50.0, + f0_max: float = 1100.0, + keep_zeros: bool = True, + ): + """Base pitch extractor. + + Args: + hop_length (int, optional): Hop length. Defaults to 512. + f0_min (float, optional): Minimum f0. Defaults to 50.0. + f0_max (float, optional): Maximum f0. Defaults to 1100.0. + keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True. + """ + + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + self.keep_zeros = keep_zeros + + def __call__(self, x, sampling_rate=44100, pad_to=None): + raise NotImplementedError("BasePitchExtractor is not callable.") + + def post_process(self, x, sampling_rate, f0, pad_to): + if isinstance(f0, np.ndarray): + f0 = torch.from_numpy(f0).float().to(x.device) + + if pad_to is None: + return f0 + + f0 = repeat_expand(f0, pad_to) + + if self.keep_zeros: + return f0 + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + # 去掉0频率, 并线性插值 + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy() + + # 大概可以用 torch 重写? + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + + return f0,vuv_vector.cpu().numpy() + + +class MaskedAvgPool1d(nn.Module): + def __init__( + self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 + ): + """An implementation of mean pooling that supports masked values. + + Args: + kernel_size (int): The size of the median pooling window. + stride (int, optional): The stride of the median pooling window. Defaults to None. + padding (int, optional): The padding of the median pooling window. Defaults to 0. + """ + + super(MaskedAvgPool1d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride or kernel_size + self.padding = padding + + def forward(self, x, mask=None): + ndim = x.dim() + if ndim == 2: + x = x.unsqueeze(1) + + assert ( + x.dim() == 3 + ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" + + # Apply the mask by setting masked elements to zero, or make NaNs zero + if mask is None: + mask = ~torch.isnan(x) + + # Ensure mask has the same shape as the input tensor + assert x.shape == mask.shape, "Input tensor and mask must have the same shape" + + masked_x = torch.where(mask, x, torch.zeros_like(x)) + # Create a ones kernel with the same number of channels as the input tensor + ones_kernel = torch.ones(x.size(1), 1, self.kernel_size, device=x.device) + + # Perform sum pooling + sum_pooled = nn.functional.conv1d( + masked_x, + ones_kernel, + stride=self.stride, + padding=self.padding, + groups=x.size(1), + ) + + # Count the non-masked (valid) elements in each pooling window + valid_count = nn.functional.conv1d( + mask.float(), + ones_kernel, + stride=self.stride, + padding=self.padding, + groups=x.size(1), + ) + valid_count = valid_count.clamp(min=1) # Avoid division by zero + + # Perform masked average pooling + avg_pooled = sum_pooled / valid_count + + # Fill zero values with NaNs + avg_pooled[avg_pooled == 0] = float("nan") + + if ndim == 2: + return avg_pooled.squeeze(1) + + return avg_pooled + + +class MaskedMedianPool1d(nn.Module): + def __init__( + self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 + ): + """An implementation of median pooling that supports masked values. + + This implementation is inspired by the median pooling implementation in + https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598 + + Args: + kernel_size (int): The size of the median pooling window. + stride (int, optional): The stride of the median pooling window. Defaults to None. + padding (int, optional): The padding of the median pooling window. Defaults to 0. + """ + + super(MaskedMedianPool1d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride or kernel_size + self.padding = padding + + def forward(self, x, mask=None): + ndim = x.dim() + if ndim == 2: + x = x.unsqueeze(1) + + assert ( + x.dim() == 3 + ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" + + if mask is None: + mask = ~torch.isnan(x) + + assert x.shape == mask.shape, "Input tensor and mask must have the same shape" + + masked_x = torch.where(mask, x, torch.zeros_like(x)) + + x = F.pad(masked_x, (self.padding, self.padding), mode="reflect") + mask = F.pad( + mask.float(), (self.padding, self.padding), mode="constant", value=0 + ) + + x = x.unfold(2, self.kernel_size, self.stride) + mask = mask.unfold(2, self.kernel_size, self.stride) + + x = x.contiguous().view(x.size()[:3] + (-1,)) + mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device) + + # Combine the mask with the input tensor + #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf"))) + x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device)) + + # Sort the masked tensor along the last dimension + x_sorted, _ = torch.sort(x_masked, dim=-1) + + # Compute the count of non-masked (valid) values + valid_count = mask.sum(dim=-1) + + # Calculate the index of the median value for each pooling window + median_idx = (torch.div((valid_count - 1), 2, rounding_mode='trunc')).clamp(min=0) + + # Gather the median values using the calculated indices + median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) + + # Fill infinite values with NaNs + median_pooled[torch.isinf(median_pooled)] = float("nan") + + if ndim == 2: + return median_pooled.squeeze(1) + + return median_pooled + + +class CrepePitchExtractor(BasePitchExtractor): + def __init__( + self, + hop_length: int = 512, + f0_min: float = 50.0, + f0_max: float = 1100.0, + threshold: float = 0.05, + keep_zeros: bool = False, + device = None, + model: Literal["full", "tiny"] = "full", + use_fast_filters: bool = True, + decoder="viterbi" + ): + super().__init__(hop_length, f0_min, f0_max, keep_zeros) + if decoder == "viterbi": + self.decoder = torchcrepe.decode.viterbi + elif decoder == "argmax": + self.decoder = torchcrepe.decode.argmax + elif decoder == "weighted_argmax": + self.decoder = torchcrepe.decode.weighted_argmax + else: + raise "Unknown decoder" + self.threshold = threshold + self.model = model + self.use_fast_filters = use_fast_filters + self.hop_length = hop_length + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + if self.use_fast_filters: + self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device) + self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device) + + def __call__(self, x, sampling_rate=44100, pad_to=None): + """Extract pitch using crepe. + + + Args: + x (torch.Tensor): Audio signal, shape (1, T). + sampling_rate (int, optional): Sampling rate. Defaults to 44100. + pad_to (int, optional): Pad to length. Defaults to None. + + Returns: + torch.Tensor: Pitch, shape (T // hop_length,). + """ + + assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor." + assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels." + + x = x.to(self.dev) + f0, pd = torchcrepe.predict( + x, + sampling_rate, + self.hop_length, + self.f0_min, + self.f0_max, + pad=True, + model=self.model, + batch_size=1024, + device=x.device, + return_periodicity=True, + decoder=self.decoder + ) + + # Filter, remove silence, set uv threshold, refer to the original warehouse readme + if self.use_fast_filters: + pd = self.median_filter(pd) + else: + pd = torchcrepe.filter.median(pd, 3) + + pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, self.hop_length) + f0 = torchcrepe.threshold.At(self.threshold)(f0, pd) + + if self.use_fast_filters: + f0 = self.mean_filter(f0) + else: + f0 = torchcrepe.filter.mean(f0, 3) + + f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0] + + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if pad_to is None else np.zeros(pad_to) + return rtn,rtn + + return self.post_process(x, sampling_rate, f0, pad_to) diff --git a/modules/F0Predictor/fcpe/__init__.py b/modules/F0Predictor/fcpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d580715a8e22d38a3e291b58c31c3d8f634ac83 --- /dev/null +++ b/modules/F0Predictor/fcpe/__init__.py @@ -0,0 +1,3 @@ +from .model import FCPEInfer # noqa: F401 +from .nvSTFT import STFT # noqa: F401 +from .pcmer import PCmer # noqa: F401 diff --git a/modules/F0Predictor/fcpe/model.py b/modules/F0Predictor/fcpe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..200c2fa546744bdd1ca9bfa2feaa550633c5038a --- /dev/null +++ b/modules/F0Predictor/fcpe/model.py @@ -0,0 +1,262 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm +from torchaudio.transforms import Resample + +from .nvSTFT import STFT +from .pcmer import PCmer + + +def l2_regularization(model, l2_alpha): + l2_loss = [] + for module in model.modules(): + if type(module) is nn.Conv2d: + l2_loss.append((module.weight ** 2).sum() / 2.0) + return l2_alpha * sum(l2_loss) + + +class FCPE(nn.Module): + def __init__( + self, + input_channel=128, + out_dims=360, + n_layers=12, + n_chans=512, + use_siren=False, + use_full=False, + loss_mse_scale=10, + loss_l2_regularization=False, + loss_l2_regularization_scale=1, + loss_grad1_mse=False, + loss_grad1_mse_scale=1, + f0_max=1975.5, + f0_min=32.70, + confidence=False, + threshold=0.05, + use_input_conv=True + ): + super().__init__() + if use_siren is True: + raise ValueError("Siren is not supported yet.") + if use_full is True: + raise ValueError("Full model is not supported yet.") + + self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10 + self.loss_l2_regularization = loss_l2_regularization if (loss_l2_regularization is not None) else False + self.loss_l2_regularization_scale = loss_l2_regularization_scale if (loss_l2_regularization_scale + is not None) else 1 + self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False + self.loss_grad1_mse_scale = loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1 + self.f0_max = f0_max if (f0_max is not None) else 1975.5 + self.f0_min = f0_min if (f0_min is not None) else 32.70 + self.confidence = confidence if (confidence is not None) else False + self.threshold = threshold if (threshold is not None) else 0.05 + self.use_input_conv = use_input_conv if (use_input_conv is not None) else True + + self.cent_table_b = torch.Tensor( + np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], + out_dims)) + self.register_buffer("cent_table", self.cent_table_b) + + # conv in stack + _leaky = nn.LeakyReLU() + self.stack = nn.Sequential( + nn.Conv1d(input_channel, n_chans, 3, 1, 1), + nn.GroupNorm(4, n_chans), + _leaky, + nn.Conv1d(n_chans, n_chans, 3, 1, 1)) + + # transformer + self.decoder = PCmer( + num_layers=n_layers, + num_heads=8, + dim_model=n_chans, + dim_keys=n_chans, + dim_values=n_chans, + residual_dropout=0.1, + attention_dropout=0.1) + self.norm = nn.LayerNorm(n_chans) + + # out + self.n_out = out_dims + self.dense_out = weight_norm( + nn.Linear(n_chans, self.n_out)) + + def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder = "local_argmax"): + """ + input: + B x n_frames x n_unit + return: + dict of B x n_frames x feat + """ + if cdecoder == "argmax": + self.cdecoder = self.cents_decoder + elif cdecoder == "local_argmax": + self.cdecoder = self.cents_local_decoder + if self.use_input_conv: + x = self.stack(mel.transpose(1, 2)).transpose(1, 2) + else: + x = mel + x = self.decoder(x) + x = self.norm(x) + x = self.dense_out(x) # [B,N,D] + x = torch.sigmoid(x) + if not infer: + gt_cent_f0 = self.f0_to_cent(gt_f0) # mel f0 #[B,N,1] + gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0) # #[B,N,out_dim] + loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0) # bce loss + # l2 regularization + if self.loss_l2_regularization: + loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale) + x = loss_all + if infer: + x = self.cdecoder(x) + x = self.cent_to_f0(x) + if not return_hz_f0: + x = (1 + x / 700).log() + return x + + def cents_decoder(self, y, mask=True): + B, N, _ = y.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) # cents: [B,N,1] + if mask: + confident = torch.max(y, dim=-1, keepdim=True)[0] + confident_mask = torch.ones_like(confident) + confident_mask[confident <= self.threshold] = float("-INF") + rtn = rtn * confident_mask + if self.confidence: + return rtn, confident + else: + return rtn + + def cents_local_decoder(self, y, mask=True): + B, N, _ = y.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + confident, max_index = torch.max(y, dim=-1, keepdim=True) + local_argmax_index = torch.arange(0,9).to(max_index.device) + (max_index - 4) + local_argmax_index[local_argmax_index<0] = 0 + local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1 + ci_l = torch.gather(ci,-1,local_argmax_index) + y_l = torch.gather(y,-1,local_argmax_index) + rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True) # cents: [B,N,1] + if mask: + confident_mask = torch.ones_like(confident) + confident_mask[confident <= self.threshold] = float("-INF") + rtn = rtn * confident_mask + if self.confidence: + return rtn, confident + else: + return rtn + + def cent_to_f0(self, cent): + return 10. * 2 ** (cent / 1200.) + + def f0_to_cent(self, f0): + return 1200. * torch.log2(f0 / 10.) + + def gaussian_blurred_cent(self, cents): # cents: [B,N,1] + mask = (cents > 0.1) & (cents < (1200. * np.log2(self.f0_max / 10.))) + B, N, _ = cents.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + return torch.exp(-torch.square(ci - cents) / 1250) * mask.float() + + +class FCPEInfer: + def __init__(self, model_path, device=None, dtype=torch.float32): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + ckpt = torch.load(model_path, map_location=torch.device(self.device)) + self.args = DotDict(ckpt["config"]) + self.dtype = dtype + model = FCPE( + input_channel=self.args.model.input_channel, + out_dims=self.args.model.out_dims, + n_layers=self.args.model.n_layers, + n_chans=self.args.model.n_chans, + use_siren=self.args.model.use_siren, + use_full=self.args.model.use_full, + loss_mse_scale=self.args.loss.loss_mse_scale, + loss_l2_regularization=self.args.loss.loss_l2_regularization, + loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, + loss_grad1_mse=self.args.loss.loss_grad1_mse, + loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, + f0_max=self.args.model.f0_max, + f0_min=self.args.model.f0_min, + confidence=self.args.model.confidence, + ) + model.to(self.device).to(self.dtype) + model.load_state_dict(ckpt['model']) + model.eval() + self.model = model + self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device) + + @torch.no_grad() + def __call__(self, audio, sr, threshold=0.05): + self.model.threshold = threshold + audio = audio[None,:] + mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype) + f0 = self.model(mel=mel, infer=True, return_hz_f0=True) + return f0 + + +class Wav2Mel: + + def __init__(self, args, device=None, dtype=torch.float32): + # self.args = args + self.sampling_rate = args.mel.sampling_rate + self.hop_size = args.mel.hop_size + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.dtype = dtype + self.stft = STFT( + args.mel.sampling_rate, + args.mel.num_mels, + args.mel.n_fft, + args.mel.win_size, + args.mel.hop_size, + args.mel.fmin, + args.mel.fmax + ) + self.resample_kernel = {} + + def extract_nvstft(self, audio, keyshift=0, train=False): + mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2) # B, n_frames, bins + return mel + + def extract_mel(self, audio, sample_rate, keyshift=0, train=False): + audio = audio.to(self.dtype).to(self.device) + # resample + if sample_rate == self.sampling_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + # extract + mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train) # B, n_frames, bins + n_frames = int(audio.shape[1] // self.hop_size) + 1 + if n_frames > int(mel.shape[1]): + mel = torch.cat((mel, mel[:, -1:, :]), 1) + if n_frames < int(mel.shape[1]): + mel = mel[:, :n_frames, :] + return mel + + def __call__(self, audio, sample_rate, keyshift=0, train=False): + return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train) + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ diff --git a/modules/F0Predictor/fcpe/nvSTFT.py b/modules/F0Predictor/fcpe/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..b97435f8977d659f594b41fa3f8993ee85f02ee9 --- /dev/null +++ b/modules/F0Predictor/fcpe/nvSTFT.py @@ -0,0 +1,133 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 48000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 48000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, keyshift=0, speed=1, center=False, train=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(n_fft * factor)) + win_size_new = int(np.round(win_size * factor)) + hop_length_new = int(np.round(hop_length * speed)) + if not train: + mel_basis = self.mel_basis + hann_window = self.hann_window + else: + mel_basis = {} + hann_window = {} + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + mel_basis_key = str(fmax)+'_'+str(y.device) + if mel_basis_key not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) + + keyshift_key = str(keyshift)+'_'+str(y.device) + if keyshift_key not in hann_window: + hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) + + pad_left = (win_size_new - hop_length_new) //2 + pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) + if pad_right < y.size(-1): + mode = 'reflect' + else: + mode = 'constant' + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) + y = y.squeeze(1) + + spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9)) + if keyshift != 0: + size = n_fft // 2 + 1 + resize = spec.size(1) + if resize < size: + spec = F.pad(spec, (0, 0, 0, size-resize)) + spec = spec[:, :size, :] * win_size / win_size_new + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/modules/F0Predictor/fcpe/pcmer.py b/modules/F0Predictor/fcpe/pcmer.py new file mode 100644 index 0000000000000000000000000000000000000000..5c12678007ad62e1d370533fe37307226dc48492 --- /dev/null +++ b/modules/F0Predictor/fcpe/pcmer.py @@ -0,0 +1,369 @@ +import math +from functools import partial + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from local_attention import LocalAttention +from torch import nn + +#import fast_transformers.causal_product.causal_product_cuda + +def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): + b, h, *_ = data.shape + # (batch size, head, length, model_dim) + + # normalize model dim + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. + + # what is ration?, projection_matrix.shape[0] --> 266 + + ratio = (projection_matrix.shape[0] ** -0.5) + + projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = projection.type_as(data) + + #data_dash = w^T x + data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + + + # diag_data = D**2 + diag_data = data ** 2 + diag_data = torch.sum(diag_data, dim=-1) + diag_data = (diag_data / 2.0) * (data_normalizer ** 2) + diag_data = diag_data.unsqueeze(dim=-1) + + #print () + if is_query: + data_dash = ratio * ( + torch.exp(data_dash - diag_data - + torch.max(data_dash, dim=-1, keepdim=True).values) + eps) + else: + data_dash = ratio * ( + torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps) + + return data_dash.type_as(data) + +def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None): + unstructured_block = torch.randn((cols, cols), device = device) + q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced') + q, r = map(lambda t: t.to(device), (q, r)) + + # proposed by @Parskatt + # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf + if qr_uniform_q: + d = torch.diag(r, 0) + q *= d.sign() + return q.t() +def exists(val): + return val is not None + +def empty(tensor): + return tensor.numel() == 0 + +def default(val, d): + return val if exists(val) else d + +def cast_tuple(val): + return (val,) if not isinstance(val, tuple) else val + +class PCmer(nn.Module): + """The encoder that is used in the Transformer model.""" + + def __init__(self, + num_layers, + num_heads, + dim_model, + dim_keys, + dim_values, + residual_dropout, + attention_dropout): + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + self.dim_model = dim_model + self.dim_values = dim_values + self.dim_keys = dim_keys + self.residual_dropout = residual_dropout + self.attention_dropout = attention_dropout + + self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)]) + + # METHODS ######################################################################################################## + + def forward(self, phone, mask=None): + + # apply all layers to the input + for (i, layer) in enumerate(self._layers): + phone = layer(phone, mask) + # provide the final sequence + return phone + + +# ==================================================================================================================== # +# CLASS _ E N C O D E R L A Y E R # +# ==================================================================================================================== # + + +class _EncoderLayer(nn.Module): + """One layer of the encoder. + + Attributes: + attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence. + feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism. + """ + + def __init__(self, parent: PCmer): + """Creates a new instance of ``_EncoderLayer``. + + Args: + parent (Encoder): The encoder that the layers is created for. + """ + super().__init__() + + + self.conformer = ConformerConvModule(parent.dim_model) + self.norm = nn.LayerNorm(parent.dim_model) + self.dropout = nn.Dropout(parent.residual_dropout) + + # selfatt -> fastatt: performer! + self.attn = SelfAttention(dim = parent.dim_model, + heads = parent.num_heads, + causal = False) + + # METHODS ######################################################################################################## + + def forward(self, phone, mask=None): + + # compute attention sub-layer + phone = phone + (self.attn(self.norm(phone), mask=mask)) + + phone = phone + (self.conformer(phone)) + + return phone + +def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + +# helper classes + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + +class Transpose(nn.Module): + def __init__(self, dims): + super().__init__() + assert len(dims) == 2, 'dims must be a tuple of two dimensions' + self.dims = dims + + def forward(self, x): + return x.transpose(*self.dims) + +class GLU(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + out, gate = x.chunk(2, dim=self.dim) + return out * gate.sigmoid() + +class DepthWiseConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size, padding): + super().__init__() + self.padding = padding + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in) + + def forward(self, x): + x = F.pad(x, self.padding) + return self.conv(x) + +class ConformerConvModule(nn.Module): + def __init__( + self, + dim, + causal = False, + expansion_factor = 2, + kernel_size = 31, + dropout = 0.): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Transpose((1, 2)), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding), + #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Transpose((1, 2)), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +def linear_attention(q, k, v): + if v is None: + #print (k.size(), q.size()) + out = torch.einsum('...ed,...nd->...ne', k, q) + return out + + else: + k_cumsum = k.sum(dim = -2) + #k_cumsum = k.sum(dim = -2) + D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8) + + context = torch.einsum('...nd,...ne->...de', k, v) + #print ("TRUEEE: ", context.size(), q.size(), D_inv.size()) + out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) + return out + +def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None): + nb_full_blocks = int(nb_rows / nb_columns) + #print (nb_full_blocks) + block_list = [] + + for _ in range(nb_full_blocks): + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) + block_list.append(q) + # block_list[n] is a orthogonal matrix ... (model_dim * model_dim) + #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1))) + #print (nb_rows, nb_full_blocks, nb_columns) + remaining_rows = nb_rows - nb_full_blocks * nb_columns + #print (remaining_rows) + if remaining_rows > 0: + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) + #print (q[:remaining_rows].size()) + block_list.append(q[:remaining_rows]) + + final_matrix = torch.cat(block_list) + + if scaling == 0: + multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) + elif scaling == 1: + multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) + else: + raise ValueError(f'Invalid scaling {scaling}') + + return torch.diag(multiplier) @ final_matrix + +class FastAttention(nn.Module): + def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False): + super().__init__() + nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + + self.dim_heads = dim_heads + self.nb_features = nb_features + self.ortho_scaling = ortho_scaling + + self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q) + projection_matrix = self.create_projection() + self.register_buffer('projection_matrix', projection_matrix) + + self.generalized_attention = generalized_attention + self.kernel_fn = kernel_fn + + # if this is turned on, no projection will be used + # queries and keys will be softmax-ed as in the original efficient attention paper + self.no_projection = no_projection + + self.causal = causal + + @torch.no_grad() + def redraw_projection_matrix(self): + projections = self.create_projection() + self.projection_matrix.copy_(projections) + del projections + + def forward(self, q, k, v): + device = q.device + + if self.no_projection: + q = q.softmax(dim = -1) + k = torch.exp(k) if self.causal else k.softmax(dim = -2) + else: + create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) + + q = create_kernel(q, is_query = True) + k = create_kernel(k, is_query = False) + + attn_fn = linear_attention if not self.causal else self.causal_linear_fn + if v is None: + out = attn_fn(q, k, None) + return out + else: + out = attn_fn(q, k, v) + return out +class SelfAttention(nn.Module): + def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False): + super().__init__() + assert dim % heads == 0, 'dimension must be divisible by number of heads' + dim_head = default(dim_head, dim // heads) + inner_dim = dim_head * heads + self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection) + + self.heads = heads + self.global_heads = heads - local_heads + self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None + + #print (heads, nb_features, dim_head) + #name_embedding = torch.zeros(110, heads, dim_head, dim_head) + #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True) + + + self.to_q = nn.Linear(dim, inner_dim) + self.to_k = nn.Linear(dim, inner_dim) + self.to_v = nn.Linear(dim, inner_dim) + self.to_out = nn.Linear(inner_dim, dim) + self.dropout = nn.Dropout(dropout) + + @torch.no_grad() + def redraw_projection_matrix(self): + self.fast_attention.redraw_projection_matrix() + #torch.nn.init.zeros_(self.name_embedding) + #print (torch.sum(self.name_embedding)) + def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs): + _, _, _, h, gh = *x.shape, self.heads, self.global_heads + + cross_attend = exists(context) + + context = default(context, x) + context_mask = default(context_mask, mask) if not cross_attend else context_mask + #print (torch.sum(self.name_embedding)) + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) + + attn_outs = [] + #print (name) + #print (self.name_embedding[name].size()) + if not empty(q): + if exists(context_mask): + global_mask = context_mask[:, None, :, None] + v.masked_fill_(~global_mask, 0.) + if cross_attend: + pass + #print (torch.sum(self.name_embedding)) + #out = self.fast_attention(q,self.name_embedding[name],None) + #print (torch.sum(self.name_embedding[...,-1:])) + else: + out = self.fast_attention(q, k, v) + attn_outs.append(out) + + if not empty(lq): + assert not cross_attend, 'local attention is not compatible with cross attention' + out = self.local_attn(lq, lk, lv, input_mask = mask) + attn_outs.append(out) + + out = torch.cat(attn_outs, dim = 1) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return self.dropout(out) \ No newline at end of file diff --git a/modules/F0Predictor/rmvpe/__init__.py b/modules/F0Predictor/rmvpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2dcf9e971ac4fcea29fe2e312d591fd0447f95d --- /dev/null +++ b/modules/F0Predictor/rmvpe/__init__.py @@ -0,0 +1,10 @@ +from .constants import * # noqa: F403 +from .inference import RMVPE # noqa: F401 +from .model import E2E, E2E0 # noqa: F401 +from .spec import MelSpectrogram # noqa: F401 +from .utils import ( # noqa: F401 + cycle, + summary, + to_local_average_cents, + to_viterbi_cents, +) diff --git a/modules/F0Predictor/rmvpe/constants.py b/modules/F0Predictor/rmvpe/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f52efc9b40f49bb746dae6807a817bffce4375 --- /dev/null +++ b/modules/F0Predictor/rmvpe/constants.py @@ -0,0 +1,9 @@ +SAMPLE_RATE = 16000 + +N_CLASS = 360 + +N_MELS = 128 +MEL_FMIN = 30 +MEL_FMAX = SAMPLE_RATE // 2 +WINDOW_LENGTH = 1024 +CONST = 1997.3794084376191 diff --git a/modules/F0Predictor/rmvpe/deepunet.py b/modules/F0Predictor/rmvpe/deepunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b0171d562ac58526c7693a15124e181c78ad0a18 --- /dev/null +++ b/modules/F0Predictor/rmvpe/deepunet.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn + +from .constants import N_MELS + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + + nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + if self.is_shortcut: + return self.conv(x) + self.shortcut(x) + else: + return self.conv(x) + x + + +class ResEncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): + super(ResEncoderBlock, self).__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for i in range(self.n_blocks): + x = self.conv[i](x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks-1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i in range(self.n_blocks): + x = self.conv2[i](x) + return x + + +class Encoder(nn.Module): + def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): + super(Encoder, self).__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for i in range(self.n_encoders): + _, x = self.layers[i](x) + concat_tensors.append(_) + return x, concat_tensors + + +class Intermediate(nn.Module): + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) + for i in range(self.n_inters-1): + self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) + + def forward(self, x): + for i in range(self.n_inters): + x = self.layers[i](x) + return x + + +class Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) + in_channels = out_channels + + def forward(self, x, concat_tensors): + for i in range(self.n_decoders): + x = self.layers[i](x, concat_tensors[-1-i]) + return x + + +class TimbreFilter(nn.Module): + def __init__(self, latent_rep_channels): + super(TimbreFilter, self).__init__() + self.layers = nn.ModuleList() + for latent_rep in latent_rep_channels: + self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) + + def forward(self, x_tensors): + out_tensors = [] + for i, layer in enumerate(self.layers): + out_tensors.append(layer(x_tensors[i])) + return out_tensors + + +class DeepUnet(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + concat_tensors = self.tf(concat_tensors) + x = self.decoder(x, concat_tensors) + return x + + +class DeepUnet0(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet0, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x diff --git a/modules/F0Predictor/rmvpe/inference.py b/modules/F0Predictor/rmvpe/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..02d21881e5ccbf969759f4ef8030abce3083ce8c --- /dev/null +++ b/modules/F0Predictor/rmvpe/inference.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from torchaudio.transforms import Resample + +from .constants import * # noqa: F403 +from .model import E2E0 +from .spec import MelSpectrogram +from .utils import to_local_average_cents, to_viterbi_cents + + +class RMVPE: + def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160): + self.resample_kernel = {} + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + model = E2E0(4, 1, (2, 2)) + ckpt = torch.load(model_path, map_location=torch.device(self.device)) + model.load_state_dict(ckpt['model']) + model = model.to(dtype).to(self.device) + model.eval() + self.model = model + self.dtype = dtype + self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 + self.resample_kernel = {} + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03, use_viterbi=False): + if use_viterbi: + cents_pred = to_viterbi_cents(hidden, thred=thred) + else: + cents_pred = to_local_average_cents(hidden, thred=thred) + f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device) + return f0 + + def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False): + audio = audio.unsqueeze(0).to(self.dtype).to(self.device) + if sample_rate == 16000: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + mel_extractor = self.mel_extractor.to(self.device) + mel = mel_extractor(audio_res, center=True).to(self.dtype) + hidden = self.mel2hidden(mel) + f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi) + return f0 diff --git a/modules/F0Predictor/rmvpe/model.py b/modules/F0Predictor/rmvpe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b6b643b113a0eee9a9142016c15444273002c5 --- /dev/null +++ b/modules/F0Predictor/rmvpe/model.py @@ -0,0 +1,67 @@ +from torch import nn + +from .constants import * # noqa: F403 +from .deepunet import DeepUnet, DeepUnet0 +from .seq import BiGRU +from .spec import MelSpectrogram + + +class E2E(nn.Module): + def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E, self).__init__() + self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 + self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 + nn.Linear(512, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, x): + mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + # x = self.fc(x) + hidden_vec = 0 + if len(self.fc) == 4: + for i in range(len(self.fc)): + x = self.fc[i](x) + if i == 0: + hidden_vec = x + return hidden_vec, x + + +class E2E0(nn.Module): + def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E0, self).__init__() + self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 + nn.Linear(512, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + return x diff --git a/modules/F0Predictor/rmvpe/seq.py b/modules/F0Predictor/rmvpe/seq.py new file mode 100644 index 0000000000000000000000000000000000000000..0d48e49d72e14d34f048ca0b5824ea1f335e9a0d --- /dev/null +++ b/modules/F0Predictor/rmvpe/seq.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + + +class BiLSTM(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiLSTM, self).__init__() + self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.lstm(x)[0] + diff --git a/modules/F0Predictor/rmvpe/spec.py b/modules/F0Predictor/rmvpe/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..349d05e4541ccad31cbbb24372a89cda7c0aacc0 --- /dev/null +++ b/modules/F0Predictor/rmvpe/spec.py @@ -0,0 +1,67 @@ +import numpy as np +import torch +import torch.nn.functional as F +from librosa.filters import mel + + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp = 1e-5 + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = win_length if n_fft is None else n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + + keyshift_key = str(keyshift)+'_'+str(audio.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) + + fft = torch.stft( + audio, + n_fft=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window=self.hann_window[keyshift_key], + center=center, + return_complex=True) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec \ No newline at end of file diff --git a/modules/F0Predictor/rmvpe/utils.py b/modules/F0Predictor/rmvpe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4395255f8608da2bce0b1f15d6bd2b2bd02c1fe7 --- /dev/null +++ b/modules/F0Predictor/rmvpe/utils.py @@ -0,0 +1,107 @@ +import sys +from functools import reduce + +import librosa +import numpy as np +import torch +from torch.nn.modules.module import _addindent + +from .constants import * # noqa: F403 + + +def cycle(iterable): + while True: + for item in iterable: + yield item + + +def summary(model, file=sys.stdout): + def repr(model): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for name, p in model._parameters.items(): + if hasattr(p, 'shape'): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + if file is sys.stdout: + main_str += ', \033[92m{:,}\033[0m params'.format(total_params) + else: + main_str += ', {:,} params'.format(total_params) + return main_str, total_params + + string, count = repr(model) + if file is not None: + if isinstance(file, str): + file = open(file, 'w') + print(string, file=file) + file.flush() + + return count + + +def to_local_average_cents(salience, center=None, thred=0.05): + """ + find the weighted average cents near the argmax bin + """ + + if not hasattr(to_local_average_cents, 'cents_mapping'): + # the bin number-to-cents mapping + to_local_average_cents.cents_mapping = ( + 20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405 + + if salience.ndim == 1: + if center is None: + center = int(torch.argmax(salience)) + start = max(0, center - 4) + end = min(len(salience), center + 5) + salience = salience[start:end] + product_sum = torch.sum( + salience * to_local_average_cents.cents_mapping[start:end]) + weight_sum = torch.sum(salience) + return product_sum / weight_sum if torch.max(salience) > thred else 0 + if salience.ndim == 2: + return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in + range(salience.shape[0])]).to(salience.device) + + raise Exception("label should be either 1d or 2d ndarray") + +def to_viterbi_cents(salience, thred=0.05): + # Create viterbi transition matrix + if not hasattr(to_viterbi_cents, 'transition'): + xx, yy = torch.meshgrid(range(N_CLASS), range(N_CLASS)) # noqa: F405 + transition = torch.maximum(30 - abs(xx - yy), 0) + transition = transition / transition.sum(axis=1, keepdims=True) + to_viterbi_cents.transition = transition + + # Convert to probability + prob = salience.T + prob = prob / prob.sum(axis=0) + + # Perform viterbi decoding + path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64) + + return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in + range(len(path))]).to(salience.device) + \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/attentions.py b/modules/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d75bc65e45f8e27460c18e0d267605a752f013 --- /dev/null +++ b/modules/attentions.py @@ -0,0 +1,363 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +import modules.commons as commons +from modules.DSConv import weight_norm_modules +from modules.modules import LayerNorm + + +class FFT(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., + proximal_bias=False, proximal_init=True, isflow = False, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + if isflow: + cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name='weight') + self.gin_channels = kwargs["gin_channels"] + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, + proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g = None): + """ + x: decoder input + h: encoder output + """ + if g is not None: + g = self.cond_layer(g) + + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + x = x * x_mask + for i in range(self.n_layers): + if g is not None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + x = commons.fused_add_tanh_sigmoid_multiply( + x, + g_l, + torch.IntTensor([self.hidden_channels])) + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + x = x * x_mask + return x + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) + x_flat = x.view([batch, heads, length**2 + length*(length -1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/modules/commons.py b/modules/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..761379da55793b7f2eca1c9ba511ec767ac1d90e --- /dev/null +++ b/modules/commons.py @@ -0,0 +1,183 @@ +import math + +import torch +from torch.nn import functional as F + + +def slice_pitch_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + +def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) + return ret, ret_pitch, ids_str + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if "Depthwise_Separable" in classname: + m.depth_conv.weight.data.normal_(mean, std) + m.point_conv.weight.data.normal_(mean, std) + elif classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def rand_spec_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2,3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/modules/enhancer.py b/modules/enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f0dd0460ff6d6153f9277dfa90763bc03861db --- /dev/null +++ b/modules/enhancer.py @@ -0,0 +1,107 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torchaudio.transforms import Resample + +from vdecoder.nsf_hifigan.models import load_model +from vdecoder.nsf_hifigan.nvSTFT import STFT + + +class Enhancer: + def __init__(self, enhancer_type, enhancer_ckpt, device=None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + + if enhancer_type == 'nsf-hifigan': + self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device) + else: + raise ValueError(f" [x] Unknown enhancer: {enhancer_type}") + + self.resample_kernel = {} + self.enhancer_sample_rate = self.enhancer.sample_rate() + self.enhancer_hop_size = self.enhancer.hop_size() + + def enhance(self, + audio, # 1, T + sample_rate, + f0, # 1, n_frames, 1 + hop_size, + adaptive_key = 0, + silence_front = 0 + ): + # enhancer start time + start_frame = int(silence_front * sample_rate / hop_size) + real_silence_front = start_frame * hop_size / sample_rate + audio = audio[:, int(np.round(real_silence_front * sample_rate)) : ] + f0 = f0[: , start_frame :, :] + + # adaptive parameters + adaptive_factor = 2 ** ( -adaptive_key / 12) + adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100)) + real_factor = self.enhancer_sample_rate / adaptive_sample_rate + + # resample the ddsp output + if sample_rate == adaptive_sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + str(adaptive_sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width = 128).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + + n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1) + + # resample f0 + f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy() + f0_np *= real_factor + time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor + time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames) + f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1]) + f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames + + # enhance + enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res) + + # resample the enhanced output + if adaptive_factor != 0: + key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width = 128).to(self.device) + enhanced_audio = self.resample_kernel[key_str](enhanced_audio) + + # pad the silence frames + if start_frame > 0: + enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0)) + + return enhanced_audio, enhancer_sample_rate + + +class NsfHifiGAN(torch.nn.Module): + def __init__(self, model_path, device=None): + super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + print('| Load HifiGAN: ', model_path) + self.model, self.h = load_model(model_path, device=self.device) + + def sample_rate(self): + return self.h.sampling_rate + + def hop_size(self): + return self.h.hop_size + + def forward(self, audio, f0): + stft = STFT( + self.h.sampling_rate, + self.h.num_mels, + self.h.n_fft, + self.h.win_size, + self.h.hop_size, + self.h.fmin, + self.h.fmax) + with torch.no_grad(): + mel = stft.get_mel(audio) + enhanced_audio = self.model(mel, f0[:,:mel.size(-1)]).view(-1) + return enhanced_audio, self.h.sampling_rate \ No newline at end of file diff --git a/modules/losses.py b/modules/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..494e979a60ba069114cac609bf6454a99c1019e3 --- /dev/null +++ b/modules/losses.py @@ -0,0 +1,58 @@ +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + #print(logs_p) + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/modules/mel_processing.py b/modules/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..c21e4bffb6d9f5fd7b45a84176b3e6206f7d29db --- /dev/null +++ b/modules/mel_processing.py @@ -0,0 +1,83 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + y_dtype = y.dtype + if y.dtype == torch.bfloat16: + y = y.to(torch.float32) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec).to(y_dtype) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + '_' + str(spec.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) + spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) + + return spec diff --git a/modules/modules.py b/modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a622d4f264a8d89a62a1b549efa71f4c37eb7ca1 --- /dev/null +++ b/modules/modules.py @@ -0,0 +1,356 @@ +import torch +from torch import nn +from torch.nn import functional as F + +import modules.attentions as attentions +import modules.commons as commons +from modules.commons import get_padding, init_weights +from modules.DSConv import ( + Depthwise_Separable_Conv1D, + remove_weight_norm_modules, + weight_norm_modules, +) + +LRELU_SLOPE = 0.1 + +Conv1dModel = nn.Conv1d + +def set_Conv1dModel(use_depthwise_conv): + global Conv1dModel + Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers-1): + self.conv_layers.append(Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + self.hidden_channels =hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) + self.cond_layer = weight_norm_modules(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = Conv1dModel(hidden_channels, 2*hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = weight_norm_modules(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm_modules(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:,:self.hidden_channels,:] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:,self.hidden_channels:,:] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + remove_weight_norm_modules(self.cond_layer) + for l in self.in_layers: + remove_weight_norm_modules(l) + for l in self.res_skip_layers: + remove_weight_norm_modules(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm_modules(l) + for l in self.convs2: + remove_weight_norm_modules(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList([ + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm_modules(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels,1)) + self.logs = nn.Parameter(torch.zeros(channels,1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1,2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + wn_sharing_parameter=None + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + +class TransformerCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels = 0 + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/onnx_export.py b/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d18f4c537acf2cad8e69f964384153cce53475 --- /dev/null +++ b/onnx_export.py @@ -0,0 +1,144 @@ +import argparse +import json + +import torch + +import utils +from onnxexport.model_onnx_speaker_mix import SynthesizerTrn + +parser = argparse.ArgumentParser(description='SoVitsSvc OnnxExport') + +def OnnxExport(path=None): + device = torch.device("cpu") + hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") + SVCVITS = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model) + _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None) + _ = SVCVITS.eval().to(device) + for i in SVCVITS.parameters(): + i.requires_grad = False + + num_frames = 200 + + test_hidden_unit = torch.rand(1, num_frames, SVCVITS.gin_channels) + test_pitch = torch.rand(1, num_frames) + test_vol = torch.rand(1, num_frames) + test_mel2ph = torch.LongTensor(torch.arange(0, num_frames)).unsqueeze(0) + test_uv = torch.ones(1, num_frames, dtype=torch.float32) + test_noise = torch.randn(1, 192, num_frames) + test_sid = torch.LongTensor([0]) + export_mix = True + if len(hps.spk) < 2: + export_mix = False + + if export_mix: + spk_mix = [] + n_spk = len(hps.spk) + for i in range(n_spk): + spk_mix.append(1.0/float(n_spk)) + test_sid = torch.tensor(spk_mix) + SVCVITS.export_chara_mix(hps.spk) + test_sid = test_sid.unsqueeze(0) + test_sid = test_sid.repeat(num_frames, 1) + + SVCVITS.eval() + + if export_mix: + daxes = { + "c": [0, 1], + "f0": [1], + "mel2ph": [1], + "uv": [1], + "noise": [2], + "sid":[0] + } + else: + daxes = { + "c": [0, 1], + "f0": [1], + "mel2ph": [1], + "uv": [1], + "noise": [2] + } + + input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"] + output_names = ["audio", ] + + if SVCVITS.vol_embedding: + input_names.append("vol") + vol_dadict = {"vol" : [1]} + daxes.update(vol_dadict) + test_inputs = ( + test_hidden_unit.to(device), + test_pitch.to(device), + test_mel2ph.to(device), + test_uv.to(device), + test_noise.to(device), + test_sid.to(device), + test_vol.to(device) + ) + else: + test_inputs = ( + test_hidden_unit.to(device), + test_pitch.to(device), + test_mel2ph.to(device), + test_uv.to(device), + test_noise.to(device), + test_sid.to(device) + ) + + # SVCVITS = torch.jit.script(SVCVITS) + SVCVITS(test_hidden_unit.to(device), + test_pitch.to(device), + test_mel2ph.to(device), + test_uv.to(device), + test_noise.to(device), + test_sid.to(device), + test_vol.to(device)) + + SVCVITS.dec.OnnxExport() + + torch.onnx.export( + SVCVITS, + test_inputs, + f"checkpoints/{path}/{path}_SoVits.onnx", + dynamic_axes=daxes, + do_constant_folding=False, + opset_version=16, + verbose=False, + input_names=input_names, + output_names=output_names + ) + + vec_lay = "layer-12" if SVCVITS.gin_channels == 768 else "layer-9" + spklist = [] + for key in hps.spk.keys(): + spklist.append(key) + + MoeVSConf = { + "Folder" : f"{path}", + "Name" : f"{path}", + "Type" : "SoVits", + "Rate" : hps.data.sampling_rate, + "Hop" : hps.data.hop_length, + "Hubert": f"vec-{SVCVITS.gin_channels}-{vec_lay}", + "SoVits4": True, + "SoVits3": False, + "CharaMix": export_mix, + "Volume": SVCVITS.vol_embedding, + "HiddenSize": SVCVITS.gin_channels, + "Characters": spklist, + "Cluster": "" + } + + with open(f"checkpoints/{path}.json", 'w') as MoeVsConfFile: + json.dump(MoeVSConf, MoeVsConfFile, indent = 4) + + +if __name__ == '__main__': + parser.add_argument('-n', '--model_name', type=str, default="TransformerFlow", help='模型文件夹名(根目录下新建ckeckpoints文件夹,在此文件夹下建立一个新的文件夹,放置模型,该文件夹名即为此项)') + args = parser.parse_args() + path = args.model_name + OnnxExport(path) diff --git a/onnx_export_old.py b/onnx_export_old.py new file mode 100644 index 0000000000000000000000000000000000000000..27f49ddedcb9e0381e70ff51003412358eb566ab --- /dev/null +++ b/onnx_export_old.py @@ -0,0 +1,56 @@ +import torch + +import utils +from onnxexport.model_onnx import SynthesizerTrn + + +def main(NetExport): + path = "SoVits4.0" + if NetExport: + device = torch.device("cpu") + hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") + SVCVITS = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model) + _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None) + _ = SVCVITS.eval().to(device) + for i in SVCVITS.parameters(): + i.requires_grad = False + + n_frame = 10 + test_hidden_unit = torch.rand(1, n_frame, 256) + test_pitch = torch.rand(1, n_frame) + test_mel2ph = torch.arange(0, n_frame, dtype=torch.int64)[None] # torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0) + test_uv = torch.ones(1, n_frame, dtype=torch.float32) + test_noise = torch.randn(1, 192, n_frame) + test_sid = torch.LongTensor([0]) + input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"] + output_names = ["audio", ] + + torch.onnx.export(SVCVITS, + ( + test_hidden_unit.to(device), + test_pitch.to(device), + test_mel2ph.to(device), + test_uv.to(device), + test_noise.to(device), + test_sid.to(device) + ), + f"checkpoints/{path}/model.onnx", + dynamic_axes={ + "c": [0, 1], + "f0": [1], + "mel2ph": [1], + "uv": [1], + "noise": [2], + }, + do_constant_folding=False, + opset_version=16, + verbose=False, + input_names=input_names, + output_names=output_names) + + +if __name__ == '__main__': + main(True) diff --git a/onnxexport/model_onnx.py b/onnxexport/model_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..0f83c03d8da1cc44972e09349a539f3da114f425 --- /dev/null +++ b/onnxexport/model_onnx.py @@ -0,0 +1,333 @@ +import torch +from torch import nn +from torch.nn import Conv1d, Conv2d +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +import modules.attentions as attentions +import modules.commons as commons +import modules.modules as modules +import utils +from modules.commons import get_padding +from utils import f0_to_coarse +from vdecoder.hifigan.models import Generator + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, + gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + # print(x.shape,x_lengths.shape) + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + kernel_size, + n_layers, + gin_channels=0, + filter_channels=None, + n_heads=None, + p_dropout=None): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.gin_channels = gin_channels + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + self.f0_emb = nn.Embedding(256, hidden_channels) + + self.enc_ = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + + def forward(self, x, x_mask, f0=None, z=None): + x = x + self.f0_emb(f0).transpose(1, 2) + x = self.enc_(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + z * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class F0Decoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=0): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.spk_channels = spk_channels + + self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1) + self.decoder = attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1) + self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) + + def forward(self, x, norm_f0, x_mask, spk_emb=None): + x = torch.detach(x) + if spk_emb is not None: + x = x + self.cond(spk_emb) + x += self.f0_prenet(norm_f0) + x = self.prenet(x) * x_mask + x = self.decoder(x * x_mask, x_mask) + x = self.proj(x) * x_mask + return x + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + n_speakers, + sampling_rate=44100, + **kwargs): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + self.ssl_dim = ssl_dim + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) + + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels=filter_channels, + n_heads=n_heads, + n_layers=n_layers, + kernel_size=kernel_size, + p_dropout=p_dropout + ) + hps = { + "sampling_rate": sampling_rate, + "inter_channels": inter_channels, + "resblock": resblock, + "resblock_kernel_sizes": resblock_kernel_sizes, + "resblock_dilation_sizes": resblock_dilation_sizes, + "upsample_rates": upsample_rates, + "upsample_initial_channel": upsample_initial_channel, + "upsample_kernel_sizes": upsample_kernel_sizes, + "gin_channels": gin_channels, + } + self.dec = Generator(h=hps) + self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + self.f0_decoder = F0Decoder( + 1, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=gin_channels + ) + self.emb_uv = nn.Embedding(2, hidden_channels) + self.predict_f0 = False + + def forward(self, c, f0, mel2ph, uv, noise=None, g=None): + + decoder_inp = F.pad(c, [0, 0, 1, 0]) + mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]]) + c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H] + + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + g = g.unsqueeze(0) + g = self.emb_g(g).transpose(1, 2) + x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) + x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + + if self.predict_f0: + lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 + norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) + pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) + f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) + + z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g, f0=f0) + return o diff --git a/onnxexport/model_onnx_speaker_mix.py b/onnxexport/model_onnx_speaker_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..5e95bce15a17560289e50e4d2344137e2771a1a9 --- /dev/null +++ b/onnxexport/model_onnx_speaker_mix.py @@ -0,0 +1,366 @@ +import torch +from torch import nn +from torch.nn import functional as F + +import modules.attentions as attentions +import modules.commons as commons +import modules.modules as modules +import utils +from utils import f0_to_coarse + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + share_parameter=False + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=gin_channels) if share_parameter else None + + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, + gin_channels=gin_channels, mean_only=True, wn_sharing_parameter=self.wn)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + +class TransformerCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + share_parameter=False + ): + + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = self.gin_channels) if share_parameter else None + + for i in range(n_flows): + self.flows.append( + modules.TransformerCouplingLayer(channels, hidden_channels, kernel_size, n_layers, n_heads, p_dropout, filter_channels, mean_only=True, wn_sharing_parameter=self.wn, gin_channels = self.gin_channels)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + # print(x.shape,x_lengths.shape) + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + kernel_size, + n_layers, + gin_channels=0, + filter_channels=None, + n_heads=None, + p_dropout=None): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.gin_channels = gin_channels + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + self.f0_emb = nn.Embedding(256, hidden_channels) + + self.enc_ = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + + def forward(self, x, x_mask, f0=None, z=None): + x = x + self.f0_emb(f0).transpose(1, 2) + x = self.enc_(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + z * torch.exp(logs)) * x_mask + + return z, m, logs, x_mask + + +class F0Decoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=0): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.spk_channels = spk_channels + + self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1) + self.decoder = attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1) + self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) + + def forward(self, x, norm_f0, x_mask, spk_emb=None): + x = torch.detach(x) + if (spk_emb is not None): + x = x + self.cond(spk_emb) + x += self.f0_prenet(norm_f0) + x = self.prenet(x) * x_mask + x = self.decoder(x * x_mask, x_mask) + x = self.proj(x) * x_mask + return x + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + n_speakers, + sampling_rate=44100, + vol_embedding=False, + vocoder_name = "nsf-hifigan", + use_depthwise_conv = False, + use_automatic_f0_prediction = True, + flow_share_parameter = False, + n_flow_layer = 4, + n_layers_trans_flow = 3, + use_transformer_flow = False, + **kwargs): + + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + self.ssl_dim = ssl_dim + self.vol_embedding = vol_embedding + self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.use_depthwise_conv = use_depthwise_conv + self.use_automatic_f0_prediction = use_automatic_f0_prediction + self.n_layers_trans_flow = n_layers_trans_flow + if vol_embedding: + self.emb_vol = nn.Linear(1, hidden_channels) + + self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) + + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels=filter_channels, + n_heads=n_heads, + n_layers=n_layers, + kernel_size=kernel_size, + p_dropout=p_dropout + ) + hps = { + "sampling_rate": sampling_rate, + "inter_channels": inter_channels, + "resblock": resblock, + "resblock_kernel_sizes": resblock_kernel_sizes, + "resblock_dilation_sizes": resblock_dilation_sizes, + "upsample_rates": upsample_rates, + "upsample_initial_channel": upsample_initial_channel, + "upsample_kernel_sizes": upsample_kernel_sizes, + "gin_channels": gin_channels, + "use_depthwise_conv":use_depthwise_conv + } + + modules.set_Conv1dModel(self.use_depthwise_conv) + + if vocoder_name == "nsf-hifigan": + from vdecoder.hifigan.models import Generator + self.dec = Generator(h=hps) + elif vocoder_name == "nsf-snake-hifigan": + from vdecoder.hifiganwithsnake.models import Generator + self.dec = Generator(h=hps) + else: + print("[?] Unkown vocoder: use default(nsf-hifigan)") + from vdecoder.hifigan.models import Generator + self.dec = Generator(h=hps) + + self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + if use_transformer_flow: + self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter=flow_share_parameter) + else: + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter=flow_share_parameter) + if self.use_automatic_f0_prediction: + self.f0_decoder = F0Decoder( + 1, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=gin_channels + ) + self.emb_uv = nn.Embedding(2, hidden_channels) + self.predict_f0 = False + self.speaker_map = [] + self.export_mix = False + + def export_chara_mix(self, speakers_mix): + self.speaker_map = torch.zeros((len(speakers_mix), 1, 1, self.gin_channels)) + i = 0 + for key in speakers_mix.keys(): + spkidx = speakers_mix[key] + self.speaker_map[i] = self.emb_g(torch.LongTensor([[spkidx]])) + i = i + 1 + self.speaker_map = self.speaker_map.unsqueeze(0) + self.export_mix = True + + def forward(self, c, f0, mel2ph, uv, noise=None, g=None, vol = None): + decoder_inp = F.pad(c, [0, 0, 1, 0]) + mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]]) + c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H] + + if self.export_mix: # [N, S] * [S, B, 1, H] + g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1] + g = g * self.speaker_map # [N, S, B, 1, H] + g = torch.sum(g, dim=1) # [N, 1, B, 1, H] + g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N] + else: + if g.dim() == 1: + g = g.unsqueeze(0) + g = self.emb_g(g).transpose(1, 2) + + x_mask = torch.unsqueeze(torch.ones_like(f0), 1).to(c.dtype) + # vol proj + + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 + + x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol + + if self.use_automatic_f0_prediction and self.predict_f0: + lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 + norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) + pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) + f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) + + z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g, f0=f0) + return o + diff --git a/preprocess_flist_config.py b/preprocess_flist_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1335d7d7258a33f4020f67a4840c99f92b51affd --- /dev/null +++ b/preprocess_flist_config.py @@ -0,0 +1,119 @@ +import argparse +import json +import os +import re +import wave +from random import shuffle + +from loguru import logger +from tqdm import tqdm + +import diffusion.logger.utils as du + +pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$') + +def get_wav_duration(file_path): + try: + with wave.open(file_path, 'rb') as wav_file: + # 获取音频帧数 + n_frames = wav_file.getnframes() + # 获取采样率 + framerate = wav_file.getframerate() + # 计算时长(秒) + return n_frames / float(framerate) + except Exception as e: + logger.error(f"Reading {file_path}") + raise e + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list") + parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list") + parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir") + parser.add_argument("--speech_encoder", type=str, default="vec768l12", help="choice a speech encoder|'vec768l12','vec256l9','hubertsoft','whisper-ppg','cnhubertlarge','dphubert','whisper-ppg-large','wavlmbase+'") + parser.add_argument("--vol_aug", action="store_true", help="Whether to use volume embedding and volume augmentation") + parser.add_argument("--tiny", action="store_true", help="Whether to train sovits tiny") + args = parser.parse_args() + + config_template = json.load(open("configs_template/config_tiny_template.json")) if args.tiny else json.load(open("configs_template/config_template.json")) + train = [] + val = [] + idx = 0 + spk_dict = {} + spk_id = 0 + + for speaker in tqdm(os.listdir(args.source_dir)): + spk_dict[speaker] = spk_id + spk_id += 1 + wavs = [] + + for file_name in os.listdir(os.path.join(args.source_dir, speaker)): + if not file_name.endswith("wav"): + continue + if file_name.startswith("."): + continue + + file_path = "/".join([args.source_dir, speaker, file_name]) + + if not pattern.match(file_name): + logger.warning("Detected non-ASCII file name: " + file_path) + + if get_wav_duration(file_path) < 0.3: + logger.info("Skip too short audio: " + file_path) + continue + + wavs.append(file_path) + + shuffle(wavs) + train += wavs[2:] + val += wavs[:2] + + shuffle(train) + shuffle(val) + + logger.info("Writing " + args.train_list) + with open(args.train_list, "w") as f: + for fname in tqdm(train): + wavpath = fname + f.write(wavpath + "\n") + + logger.info("Writing " + args.val_list) + with open(args.val_list, "w") as f: + for fname in tqdm(val): + wavpath = fname + f.write(wavpath + "\n") + + + d_config_template = du.load_config("configs_template/diffusion_template.yaml") + d_config_template["model"]["n_spk"] = spk_id + d_config_template["data"]["encoder"] = args.speech_encoder + d_config_template["spk"] = spk_dict + + config_template["spk"] = spk_dict + config_template["model"]["n_speakers"] = spk_id + config_template["model"]["speech_encoder"] = args.speech_encoder + + if args.speech_encoder == "vec768l12" or args.speech_encoder == "dphubert" or args.speech_encoder == "wavlmbase+": + config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768 + d_config_template["data"]["encoder_out_channels"] = 768 + elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft': + config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256 + d_config_template["data"]["encoder_out_channels"] = 256 + elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge': + config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024 + d_config_template["data"]["encoder_out_channels"] = 1024 + elif args.speech_encoder == "whisper-ppg-large": + config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1280 + d_config_template["data"]["encoder_out_channels"] = 1280 + + if args.vol_aug: + config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True + + if args.tiny: + config_template["model"]["filter_channels"] = 512 + + logger.info("Writing to configs/config.json") + with open("configs/config.json", "w") as f: + json.dump(config_template, f, indent=2) + logger.info("Writing to configs/diffusion.yaml") + du.save_config("configs/diffusion.yaml",d_config_template) diff --git a/preprocess_hubert_f0.py b/preprocess_hubert_f0.py new file mode 100644 index 0000000000000000000000000000000000000000..0c482104ca889aea8e0e65af23337988e034d684 --- /dev/null +++ b/preprocess_hubert_f0.py @@ -0,0 +1,172 @@ +import argparse +import logging +import os +import random +from concurrent.futures import ProcessPoolExecutor +from glob import glob +from random import shuffle + +import librosa +import numpy as np +import torch +import torch.multiprocessing as mp +from loguru import logger +from tqdm import tqdm + +import diffusion.logger.utils as du +import utils +from diffusion.vocoder import Vocoder +from modules.mel_processing import spectrogram_torch + +logging.getLogger("numba").setLevel(logging.WARNING) +logging.getLogger("matplotlib").setLevel(logging.WARNING) + +hps = utils.get_hparams_from_file("configs/config.json") +dconfig = du.load_config("configs/diffusion.yaml") +sampling_rate = hps.data.sampling_rate +hop_length = hps.data.hop_length +speech_encoder = hps["model"]["speech_encoder"] + + +def process_one(filename, hmodel, f0p, device, diff=False, mel_extractor=None): + wav, sr = librosa.load(filename, sr=sampling_rate) + audio_norm = torch.FloatTensor(wav) + audio_norm = audio_norm.unsqueeze(0) + soft_path = filename + ".soft.pt" + if not os.path.exists(soft_path): + wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000) + wav16k = torch.from_numpy(wav16k).to(device) + c = hmodel.encoder(wav16k) + torch.save(c.cpu(), soft_path) + + f0_path = filename + ".f0.npy" + if not os.path.exists(f0_path): + f0_predictor = utils.get_f0_predictor(f0p,sampling_rate=sampling_rate, hop_length=hop_length,device=None,threshold=0.05) + f0,uv = f0_predictor.compute_f0_uv( + wav + ) + np.save(f0_path, np.asanyarray((f0,uv),dtype=object)) + + + spec_path = filename.replace(".wav", ".spec.pt") + if not os.path.exists(spec_path): + # Process spectrogram + # The following code can't be replaced by torch.FloatTensor(wav) + # because load_wav_to_torch return a tensor that need to be normalized + + if sr != hps.data.sampling_rate: + raise ValueError( + "{} SR doesn't match target {} SR".format( + sr, hps.data.sampling_rate + ) + ) + + #audio_norm = audio / hps.data.max_wav_value + + spec = spectrogram_torch( + audio_norm, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_path) + + if diff or hps.model.vol_embedding: + volume_path = filename + ".vol.npy" + volume_extractor = utils.Volume_Extractor(hop_length) + if not os.path.exists(volume_path): + volume = volume_extractor.extract(audio_norm) + np.save(volume_path, volume.to('cpu').numpy()) + + if diff: + mel_path = filename + ".mel.npy" + if not os.path.exists(mel_path) and mel_extractor is not None: + mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate) + mel = mel_t.squeeze().to('cpu').numpy() + np.save(mel_path, mel) + aug_mel_path = filename + ".aug_mel.npy" + aug_vol_path = filename + ".aug_vol.npy" + max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5 + max_shift = min(1, np.log10(1/max_amp)) + log10_vol_shift = random.uniform(-1, max_shift) + keyshift = random.uniform(-5, 5) + if mel_extractor is not None: + aug_mel_t = mel_extractor.extract(audio_norm * (10 ** log10_vol_shift), sampling_rate, keyshift = keyshift) + aug_mel = aug_mel_t.squeeze().to('cpu').numpy() + aug_vol = volume_extractor.extract(audio_norm * (10 ** log10_vol_shift)) + if not os.path.exists(aug_mel_path): + np.save(aug_mel_path,np.asanyarray((aug_mel,keyshift),dtype=object)) + if not os.path.exists(aug_vol_path): + np.save(aug_vol_path,aug_vol.to('cpu').numpy()) + + +def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"): + logger.info("Loading speech encoder for content...") + rank = mp.current_process()._identity + rank = rank[0] if len(rank) > 0 else 0 + if torch.cuda.is_available(): + gpu_id = rank % torch.cuda.device_count() + device = torch.device(f"cuda:{gpu_id}") + logger.info(f"Rank {rank} uses device {device}") + hmodel = utils.get_speech_encoder(speech_encoder, device=device) + logger.info(f"Loaded speech encoder for rank {rank}") + for filename in tqdm(file_chunk, position = rank): + process_one(filename, hmodel, f0p, device, diff, mel_extractor) + +def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device): + with ProcessPoolExecutor(max_workers=num_processes) as executor: + tasks = [] + for i in range(num_processes): + start = int(i * len(filenames) / num_processes) + end = int((i + 1) * len(filenames) / num_processes) + file_chunk = filenames[start:end] + tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device)) + for task in tqdm(tasks, position = 0): + task.result() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--device', type=str, default=None) + parser.add_argument( + "--in_dir", type=str, default="dataset/44k", help="path to input dir" + ) + parser.add_argument( + '--use_diff',action='store_true', help='Whether to use the diffusion model' + ) + parser.add_argument( + '--f0_predictor', type=str, default="rmvpe", help='Select F0 predictor, can select crepe,pm,dio,harvest,rmvpe,fcpe|default: pm(note: crepe is original F0 using mean filter)' + ) + parser.add_argument( + '--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores' + ) + args = parser.parse_args() + f0p = args.f0_predictor + device = args.device + if device is None: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + print(speech_encoder) + logger.info("Using device: " + str(device)) + logger.info("Using SpeechEncoder: " + speech_encoder) + logger.info("Using extractor: " + f0p) + logger.info("Using diff Mode: " + str(args.use_diff)) + + if args.use_diff: + print("use_diff") + print("Loading Mel Extractor...") + mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device) + print("Loaded Mel Extractor.") + else: + mel_extractor = None + filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10] + shuffle(filenames) + mp.set_start_method("spawn", force=True) + + num_processes = args.num_processes + if num_processes == 0: + num_processes = os.cpu_count() + + parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor, device) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f5d55ec9320fa61c7f04db0abdfed2864afd420 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +ffmpeg-python +Flask +Flask_Cors +gradio>=3.7.0 +numpy==1.23.5 +pyworld +scipy==1.10.0 +SoundFile==0.12.1 +torch +torchaudio +torchcrepe +tqdm +rich +loguru +scikit-maad +praat-parselmouth +onnx +onnxsim +onnxoptimizer +fairseq==0.12.2 +librosa==0.9.1 +tensorboard +tensorboardX +transformers +edge_tts +langdetect +pyyaml +pynvml +faiss-cpu +einops +local_attention \ No newline at end of file diff --git a/requirements_onnx_encoder.txt b/requirements_onnx_encoder.txt new file mode 100644 index 0000000000000000000000000000000000000000..cfde17cef5a6d38d33d41f61590b2b24cfb8ccd6 --- /dev/null +++ b/requirements_onnx_encoder.txt @@ -0,0 +1,29 @@ +Flask +Flask_Cors +gradio>=3.7.0 +numpy==1.23.0 +pyworld==0.2.5 +scipy==1.10.0 +SoundFile==0.12.1 +torch==1.13.1 +torchaudio==0.13.1 +torchcrepe +tqdm +rich.progress +loguru +scikit-maad +praat-parselmouth +onnx +onnxsim +onnxoptimizer +onnxruntime-gpu +librosa==0.9.1 +tensorboard +tensorboardX +edge_tts +langdetect +pyyaml +pynvml +transformers +ffmpeg-python +faiss-cpu \ No newline at end of file diff --git a/requirements_win.txt b/requirements_win.txt new file mode 100644 index 0000000000000000000000000000000000000000..461a9921a1e73ee6fa7981ecafb357b6aa32d05f --- /dev/null +++ b/requirements_win.txt @@ -0,0 +1,33 @@ +librosa==0.9.1 +fairseq==0.12.2 +ffmpeg-python +Flask==2.1.2 +Flask_Cors==3.0.10 +gradio>=3.7.0 +numpy +playsound==1.3.0 +PyAudio==0.2.12 +pydub==0.25.1 +pyworld==0.3.0 +requests==2.28.1 +scipy==1.7.3 +sounddevice==0.4.5 +SoundFile==0.10.3.post1 +starlette==0.19.1 +tqdm==4.63.0 +rich +loguru +torchcrepe +scikit-maad +praat-parselmouth +onnx +onnxsim +onnxoptimizer +tensorboard +tensorboardX +transformers +edge_tts +langdetect +pyyaml +pynvml +faiss-cpu diff --git a/resample.py b/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..af421fdebeaebc0fe1a335ab32bc2389d8201e3f --- /dev/null +++ b/resample.py @@ -0,0 +1,98 @@ +import argparse +import concurrent.futures +import os +from concurrent.futures import ProcessPoolExecutor +from multiprocessing import cpu_count + +import librosa +import numpy as np +from rich.progress import track +from scipy.io import wavfile + + +def load_wav(wav_path): + return librosa.load(wav_path, sr=None) + + +def trim_wav(wav, top_db=40): + return librosa.effects.trim(wav, top_db=top_db) + + +def normalize_peak(wav, threshold=1.0): + peak = np.abs(wav).max() + if peak > threshold: + wav = 0.98 * wav / peak + return wav + + +def resample_wav(wav, sr, target_sr): + return librosa.resample(wav, orig_sr=sr, target_sr=target_sr) + + +def save_wav_to_path(wav, save_path, sr): + wavfile.write( + save_path, + sr, + (wav * np.iinfo(np.int16).max).astype(np.int16) + ) + + +def process(item): + spkdir, wav_name, args = item + speaker = spkdir.replace("\\", "/").split("/")[-1] + + wav_path = os.path.join(args.in_dir, speaker, wav_name) + if os.path.exists(wav_path) and '.wav' in wav_path: + os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True) + + wav, sr = load_wav(wav_path) + wav, _ = trim_wav(wav) + wav = normalize_peak(wav) + resampled_wav = resample_wav(wav, sr, args.sr2) + + if not args.skip_loudnorm: + resampled_wav /= np.max(np.abs(resampled_wav)) + + save_path2 = os.path.join(args.out_dir2, speaker, wav_name) + save_wav_to_path(resampled_wav, save_path2, args.sr2) + + +""" +def process_all_speakers(): + process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) + + with ThreadPoolExecutor(max_workers=process_count) as executor: + for speaker in speakers: + spk_dir = os.path.join(args.in_dir, speaker) + if os.path.isdir(spk_dir): + print(spk_dir) + futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] + for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): + pass +""" +# multi process + + +def process_all_speakers(): + process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) + with ProcessPoolExecutor(max_workers=process_count) as executor: + for speaker in speakers: + spk_dir = os.path.join(args.in_dir, speaker) + if os.path.isdir(spk_dir): + print(spk_dir) + futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] + for _ in track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"): + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sr2", type=int, default=44100, help="sampling rate") + parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir") + parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir") + parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it") + args = parser.parse_args() + + print(f"CPU count: {cpu_count()}") + speakers = os.listdir(args.in_dir) + process_all_speakers() diff --git a/shadowdiffusion.png b/shadowdiffusion.png new file mode 100644 index 0000000000000000000000000000000000000000..dedec9d787f156ba2d2ca6675cfc6c9d4287fe04 Binary files /dev/null and b/shadowdiffusion.png differ diff --git a/sovits4_for_colab.ipynb b/sovits4_for_colab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7669a53a146d1472e391c9570582a5d24aca9c0d --- /dev/null +++ b/sovits4_for_colab.ipynb @@ -0,0 +1,718 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "2q0l56aFQhAM" + }, + "source": [ + "# Terms of Use\n", + "\n", + "### Please solve the authorization problem of the dataset on your own. You shall be solely responsible for any problems caused by the use of non-authorized datasets for training and all consequences thereof.The repository and its maintainer, svc develop team, have nothing to do with the consequences!\n", + "\n", + "1. This project is established for academic exchange purposes only and is intended for communication and learning purposes. It is not intended for production environments.\n", + "2. Any videos based on sovits that are published on video platforms must clearly indicate in the description that they are used for voice changing and specify the input source of the voice or audio, for example, using videos or audios published by others and separating the vocals as input source for conversion, which must provide clear original video or music links. If your own voice or other synthesized voices from other commercial vocal synthesis software are used as the input source for conversion, you must also explain it in the description.\n", + "3. You shall be solely responsible for any infringement problems caused by the input source. When using other commercial vocal synthesis software as input source, please ensure that you comply with the terms of use of the software. Note that many vocal synthesis engines clearly state in their terms of use that they cannot be used for input source conversion.\n", + "4. Continuing to use this project is deemed as agreeing to the relevant provisions stated in this repository README. This repository README has the obligation to persuade, and is not responsible for any subsequent problems that may arise.\n", + "5. If you distribute this repository's code or publish any results produced by this project publicly (including but not limited to video sharing platforms), please indicate the original author and code source (this repository).\n", + "6. If you use this project for any other plan, please contact and inform the author of this repository in advance. Thank you very much.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "M_RcDbVPhivj" + }, + "source": [ + "## **Note:**\n", + "## **Make sure there is no a directory named `sovits4data` in your google drive at the first time you use this notebook.**\n", + "## **It will be created to store some necessary files.** \n", + "## **For sure you can change it to another directory by modifying `sovits_data_dir` variable.**" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "fHaw6hGEa_Nk" + }, + "source": [ + "# **Initialize environment**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0gQcIZ8RsOkn" + }, + "outputs": [], + "source": [ + "#@title Connect to colab runtime and check GPU\n", + "\n", + "#@markdown # Connect to colab runtime and check GPU\n", + "\n", + "#@markdown\n", + "\n", + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0YUGpYrXhMck" + }, + "outputs": [], + "source": [ + "#@title Clone repository and install requirements\n", + "\n", + "#@markdown # Clone repository and install requirements\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown ### After the execution is completed, the runtime will **automatically restart**\n", + "\n", + "#@markdown\n", + "\n", + "!git clone https://github.com/svc-develop-team/so-vits-svc -b 4.1-Stable\n", + "%cd /content/so-vits-svc\n", + "%pip install --upgrade pip setuptools\n", + "%pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118\n", + "exit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wmUkpUmfn_Hs" + }, + "outputs": [], + "source": [ + "#@title Mount google drive and select which directories to sync with google drive\n", + "\n", + "#@markdown # Mount google drive and select which directories to sync with google drive\n", + "\n", + "#@markdown\n", + "\n", + "from google.colab import drive\n", + "drive.mount(\"/content/drive\")\n", + "\n", + "#@markdown Directory to store **necessary files**, dont miss the slash at the end👇.\n", + "sovits_data_dir = \"/content/drive/MyDrive/sovits4data/\" #@param {type:\"string\"}\n", + "#@markdown By default it will create a `sovits4data/` folder in your google drive.\n", + "RAW_DIR = sovits_data_dir + \"raw/\"\n", + "RESULTS_DIR = sovits_data_dir + \"results/\"\n", + "FILELISTS_DIR = sovits_data_dir + \"filelists/\"\n", + "CONFIGS_DIR = sovits_data_dir + \"configs/\"\n", + "LOGS_DIR = sovits_data_dir + \"logs/44k/\"\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown ### These folders will be synced with your google drvie\n", + "\n", + "#@markdown ### **Strongly recommend to check all.**\n", + "\n", + "#@markdown Sync **input audios** and **output audios**\n", + "sync_raw_and_results = True #@param {type:\"boolean\"}\n", + "if sync_raw_and_results:\n", + " !mkdir -p {RAW_DIR}\n", + " !mkdir -p {RESULTS_DIR}\n", + " !rm -rf /content/so-vits-svc/raw\n", + " !rm -rf /content/so-vits-svc/results\n", + " !ln -s {RAW_DIR} /content/so-vits-svc/raw\n", + " !ln -s {RESULTS_DIR} /content/so-vits-svc/results\n", + "\n", + "#@markdown Sync **config** and **models**\n", + "sync_configs_and_logs = True #@param {type:\"boolean\"}\n", + "if sync_configs_and_logs:\n", + " !mkdir -p {FILELISTS_DIR}\n", + " !mkdir -p {CONFIGS_DIR}\n", + " !mkdir -p {LOGS_DIR}\n", + " !rm -rf /content/so-vits-svc/filelists\n", + " !rm -rf /content/so-vits-svc/configs\n", + " !rm -rf /content/so-vits-svc/logs/44k\n", + " !ln -s {FILELISTS_DIR} /content/so-vits-svc/filelists\n", + " !ln -s {CONFIGS_DIR} /content/so-vits-svc/configs\n", + " !ln -s {LOGS_DIR} /content/so-vits-svc/logs/44k" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "G_PMPCN6wvgZ" + }, + "outputs": [], + "source": [ + "#@title Get pretrained model(Optional but strongly recommend).\n", + "\n", + "#@markdown # Get pretrained model(Optional but strongly recommend).\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown - Pre-trained model files: `G_0.pth` `D_0.pth`\n", + "#@markdown - Place them under /sovits4data/logs/44k/ in your google drive manualy\n", + "\n", + "#@markdown Get them from svc-develop-team(TBD) or anywhere else.\n", + "\n", + "#@markdown Although the pretrained model generally does not cause any copyright problems, please pay attention to it. For example, ask the author in advance, or the author has indicated the feasible use in the description clearly.\n", + "\n", + "download_pretrained_model = True #@param {type:\"boolean\"}\n", + "D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_D_320000.pth\"] {allow-input: true}\n", + "G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_G_320000.pth\"] {allow-input: true}\n", + "\n", + "download_pretrained_diffusion_model = True #@param {type:\"boolean\"}\n", + "diff_model_URL = \"https://huggingface.co/datasets/ms903/Diff-SVC-refactor-pre-trained-model/resolve/main/fix_pitch_add_vctk_600k/model_0.pt\" #@param {type:\"string\"}\n", + "\n", + "%cd /content/so-vits-svc\n", + "\n", + "if download_pretrained_model:\n", + " !curl -L {D_0_URL} -o logs/44k/D_0.pth\n", + " !md5sum logs/44k/D_0.pth\n", + " !curl -L {G_0_URL} -o logs/44k/G_0.pth\n", + " !md5sum logs/44k/G_0.pth\n", + "\n", + "if download_pretrained_diffusion_model:\n", + " !mkdir -p logs/44k/diffusion\n", + " !curl -L {diff_model_URL} -o logs/44k/diffusion/model_0.pt\n", + " !md5sum logs/44k/diffusion/model_0.pt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "k1qadJBFehMo" + }, + "source": [ + "# **Dataset preprocessing**" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "kBlju6Q3lSM6" + }, + "source": [ + "Pack and upload your raw dataset(dataset_raw/) to your google drive.\n", + "\n", + "Makesure the file structure in your zip file looks like this:\n", + "\n", + "```\n", + "YourZIPforSingleSpeakers.zip\n", + "└───speaker\n", + " ├───xxx1-xxx1.wav\n", + " ├───...\n", + " └───Lxx-0xx8.wav\n", + "```\n", + "\n", + "```\n", + "YourZIPforMultipleSpeakers.zip\n", + "├───speaker0\n", + "│ ├───xxx1-xxx1.wav\n", + "│ ├───...\n", + "│ └───Lxx-0xx8.wav\n", + "└───speaker1\n", + " ├───xx2-0xxx2.wav\n", + " ├───...\n", + " └───xxx7-xxx007.wav\n", + "```\n", + "\n", + "**Even if there is only one speaker, a folder named `{speaker_name}` is needed.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "U05CXlAipvJR" + }, + "outputs": [], + "source": [ + "#@title Get raw dataset from google drive\n", + "\n", + "#@markdown # Get raw dataset from google drive\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown Directory where **your zip file** located in, dont miss the slash at the end👇.\n", + "sovits_data_dir = \"/content/drive/MyDrive/sovits4data/\" #@param {type:\"string\"}\n", + "#@markdown Filename of **your zip file**, do NOT be \"dataset.zip\"\n", + "zip_filename = \"YourZIPFilenameofRawDataset.zip\" #@param {type:\"string\"}\n", + "ZIP_PATH = sovits_data_dir + zip_filename\n", + "\n", + "!unzip -od /content/so-vits-svc/dataset_raw {ZIP_PATH}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_ThKTzYs5CfL" + }, + "outputs": [], + "source": [ + "#@title Resample to 44100Hz and mono\n", + "\n", + "#@markdown # Resample to 44100Hz and mono\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "!python resample.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "svITReeL5N8K" + }, + "outputs": [], + "source": [ + "#@title Divide filelists and generate config.json\n", + "\n", + "#@markdown # Divide filelists and generate config.json\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "\n", + "speech_encoder = \"vec768l12\" #@param [\"vec768l12\", \"vec256l9\", \"hubertsoft\", \"whisper-ppg\", \"whisper-ppg-large\"]\n", + "use_vol_aug = False #@param {type:\"boolean\"}\n", + "vol_aug = \"--vol_aug\" if use_vol_aug else \"\"\n", + "\n", + "from pretrain.meta import download_dict\n", + "download_dict = download_dict()\n", + "\n", + "url = download_dict[speech_encoder][\"url\"]\n", + "output = download_dict[speech_encoder][\"output\"]\n", + "\n", + "import os\n", + "if not os.path.exists(output):\n", + " !curl -L {url} -o {output}\n", + " !md5sum {output}\n", + "\n", + "!python preprocess_flist_config.py --speech_encoder={speech_encoder} {vol_aug}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xHUXMi836DMe" + }, + "outputs": [], + "source": [ + "#@title Generate hubert and f0\n", + "\n", + "#@markdown # Generate hubert and f0\n", + "\n", + "#@markdown\n", + "%cd /content/so-vits-svc\n", + "\n", + "f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n", + "use_diff = True #@param {type:\"boolean\"}\n", + "\n", + "import os\n", + "if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n", + " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n", + "\n", + "if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n", + " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n", + "\n", + "\n", + "diff_param = \"\"\n", + "if use_diff:\n", + " diff_param = \"--use_diff\"\n", + "\n", + " if not os.path.exists(\"./pretrain/nsf_hifigan/model\"):\n", + " !curl -L https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip -o nsf_hifigan_20221211.zip\n", + " !md5sum nsf_hifigan_20221211.zip\n", + " !unzip nsf_hifigan_20221211.zip\n", + " !rm -rf pretrain/nsf_hifigan\n", + " !mv -v nsf_hifigan pretrain\n", + "\n", + "!python preprocess_hubert_f0.py --f0_predictor={f0_predictor} {diff_param}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Wo4OTmTAUXgj" + }, + "outputs": [], + "source": [ + "#@title Save the preprocessed dataset to google drive\n", + "\n", + "#@markdown # Save the preprocessed dataset to google drive\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown You can save the dataset and related files to your google drive for the next training\n", + "\n", + "#@markdown **Directory for saving**, dont miss the slash at the end👇.\n", + "sovits_data_dir = \"/content/drive/MyDrive/sovits4data/\" #@param {type:\"string\"}\n", + "\n", + "#@markdown There will be a `dataset.zip` contained `dataset/` in your google drive, which is preprocessed data.\n", + "\n", + "!mkdir -p {sovits_data_dir}\n", + "!zip -r dataset.zip /content/so-vits-svc/dataset\n", + "!cp -vr dataset.zip \"{sovits_data_dir}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P2G6v_6zblWK" + }, + "outputs": [], + "source": [ + "#@title Unzip preprocessed dataset from google drive directly if you have preprocessed already.\n", + "\n", + "#@markdown # Unzip preprocessed dataset from google drive directly if you have preprocessed already.\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown Directory where **your preprocessed dataset** located in, dont miss the slash at the end👇.\n", + "sovits_data_dir = \"/content/drive/MyDrive/sovits4data/\" #@param {type:\"string\"}\n", + "CONFIG = sovits_data_dir + \"configs/\"\n", + "FILELISTS = sovits_data_dir + \"filelists/\"\n", + "DATASET = sovits_data_dir + \"dataset.zip\"\n", + "\n", + "!cp -vr {CONFIG} /content/so-vits-svc/\n", + "!cp -vr {FILELISTS} /content/so-vits-svc/\n", + "!unzip {DATASET} -d /" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "ENoH-pShel7w" + }, + "source": [ + "# **Trainning**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-hEFFTCfZf57" + }, + "outputs": [], + "source": [ + "#@title Start training\n", + "\n", + "#@markdown # Start training\n", + "\n", + "#@markdown If you want to use pre-trained models, upload them to /sovits4data/logs/44k/ in your google drive manualy.\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "\n", + "#@markdown Whether to enable tensorboard\n", + "tensorboard_on = True #@param {type:\"boolean\"}\n", + "\n", + "if tensorboard_on:\n", + " %load_ext tensorboard\n", + " %tensorboard --logdir logs/44k\n", + "\n", + "config_path = \"configs/config.json\"\n", + "\n", + "from pretrain.meta import get_speech_encoder\n", + "url, output = get_speech_encoder(config_path)\n", + "\n", + "import os\n", + "if not os.path.exists(output):\n", + " !curl -L {url} -o {output}\n", + "\n", + "!python train.py -c {config_path} -m 44k" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZThaMxmIJgWy" + }, + "outputs": [], + "source": [ + "#@title Train cluster model (Optional)\n", + "\n", + "#@markdown # Train cluster model (Optional)\n", + "\n", + "#@markdown #### Details see [README.md#cluster-based-timbre-leakage-control](https://github.com/svc-develop-team/so-vits-svc#cluster-based-timbre-leakage-control)\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "!python cluster/train_cluster.py --gpu" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Train index model (Optional)\n", + "\n", + "#@markdown # Train index model (Optional)\n", + "\n", + "#@markdown #### Details see [README.md#feature-retrieval](https://github.com/svc-develop-team/so-vits-svc#feature-retrieval)\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "!python train_index.py -c configs/config.json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Train diffusion model (Optional)\n", + "\n", + "#@markdown # Train diffusion model (Optional)\n", + "\n", + "#@markdown #### Details see [README.md#-about-shallow-diffusion](https://github.com/svc-develop-team/so-vits-svc#-about-shallow-diffusion)\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "\n", + "import os\n", + "if not os.path.exists(\"./pretrain/nsf_hifigan/model\"):\n", + " !curl -L https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip -o nsf_hifigan_20221211.zip\n", + " !unzip nsf_hifigan_20221211.zip\n", + " !rm -rf pretrain/nsf_hifigan\n", + " !mv -v nsf_hifigan pretrain\n", + "\n", + "#@markdown Whether to enable tensorboard\n", + "tensorboard_on = True #@param {type:\"boolean\"}\n", + "\n", + "if tensorboard_on:\n", + " %load_ext tensorboard\n", + " %tensorboard --logdir logs/44k\n", + "\n", + "!python train_diff.py -c configs/diffusion.yaml" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# keep colab alive\n", + "Open the devtools and copy & paste to run the scrips.\n", + "\n", + "\n", + "```JavaScript\n", + "const ping = () => {\n", + " const btn = document.querySelector(\"colab-connect-button\");\n", + " const inner_btn = btn.shadowRoot.querySelector(\"#connect\");\n", + " if (inner_btn) {\n", + " inner_btn.click();\n", + " console.log(\"Clicked on connect button\");\n", + " } else {\n", + " console.log(\"connect button not found\");\n", + " }\n", + "\n", + " const nextTime = 50000 + Math.random() * 10000;\n", + "\n", + " setTimeout(ping, nextTime);\n", + "};\n", + "\n", + "ping();\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "oCnbX-OT897k" + }, + "source": [ + "# **Inference**\n", + "### Upload wav files from this notebook\n", + "### **OR**\n", + "### Upload to `sovits4data/raw/` in your google drive manualy (should be faster)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#title Download nsf_hifigan if you need it\n", + "\n", + "%cd /content/so-vits-svc\n", + "!curl -L https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip -o /content/so-vits-svc/nsf_hifigan_20221211.zip\n", + "!unzip nsf_hifigan_20221211.zip\n", + "!rm -rf pretrain/nsf_hifigan\n", + "!mv -v nsf_hifigan pretrain\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "executionInfo": { + "elapsed": 94633, + "status": "ok", + "timestamp": 1678591088790, + "user": { + "displayName": "謬紗特", + "userId": "09445825975794260265" + }, + "user_tz": -480 + }, + "id": "XUsmGkgCMD_Q", + "outputId": "8bbfde13-030a-4ba0-bbdb-7eb6b89c02b4" + }, + "outputs": [], + "source": [ + "#@title Upload wav files, the filename should not contain any special symbols like `#` `$` `(` `)`\n", + "\n", + "#@markdown # Upload wav files, the filename should not contain any special symbols like `#` `$` `(` `)`\n", + "\n", + "#@markdown\n", + "\n", + "%cd /content/so-vits-svc\n", + "%run wav_upload.py --type audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dYnKuKTIj3z1" + }, + "outputs": [], + "source": [ + "#@title Start inference (and download)\n", + "\n", + "#@markdown # Start inference (and download)\n", + "\n", + "#@markdown Parameters see [README.MD#Inference](https://github.com/svc-develop-team/so-vits-svc#-inference)\n", + "\n", + "#@markdown\n", + "\n", + "wav_filename = \"YourWAVFile.wav\" #@param {type:\"string\"}\n", + "model_filename = \"G_210000.pth\" #@param {type:\"string\"}\n", + "model_path = \"/content/so-vits-svc/logs/44k/\" + model_filename\n", + "speaker = \"YourSpeaker\" #@param {type:\"string\"}\n", + "trans = \"0\" #@param {type:\"string\"}\n", + "cluster_infer_ratio = \"0\" #@param {type:\"string\"}\n", + "auto_predict_f0 = False #@param {type:\"boolean\"}\n", + "apf = \"\"\n", + "if auto_predict_f0:\n", + " apf = \" -a \"\n", + "\n", + "f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n", + "\n", + "enhance = False #@param {type:\"boolean\"}\n", + "ehc = \"\"\n", + "if enhance:\n", + " ehc = \" -eh \"\n", + "#@markdown\n", + "\n", + "#@markdown Generally keep default:\n", + "config_filename = \"config.json\" #@param {type:\"string\"}\n", + "config_path = \"/content/so-vits-svc/configs/\" + config_filename\n", + "\n", + "from pretrain.meta import get_speech_encoder\n", + "url, output = get_speech_encoder(config_path)\n", + "\n", + "import os\n", + "\n", + "if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n", + " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n", + "\n", + "if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n", + " !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n", + "\n", + "if not os.path.exists(output):\n", + " !curl -L {url} -o {output}\n", + "\n", + "kmeans_filenname = \"kmeans_10000.pt\" #@param {type:\"string\"}\n", + "kmeans_path = \"/content/so-vits-svc/logs/44k/\" + kmeans_filenname\n", + "slice_db = \"-40\" #@param {type:\"string\"}\n", + "wav_format = \"flac\" #@param {type:\"string\"}\n", + "\n", + "key = \"auto\" if auto_predict_f0 else f\"{trans}key\"\n", + "cluster_name = \"\" if cluster_infer_ratio == \"0\" else f\"_{cluster_infer_ratio}\"\n", + "isdiffusion = \"sovits\"\n", + "wav_output = f\"/content/so-vits-svc/results/{wav_filename}_{key}_{speaker}{cluster_name}_{isdiffusion}_{f0_predictor}.{wav_format}\"\n", + "\n", + "%cd /content/so-vits-svc\n", + "!python inference_main.py -n {wav_filename} -m {model_path} -s {speaker} -t {trans} -cr {cluster_infer_ratio} -c {config_path} -cm {kmeans_path} -sd {slice_db} -wf {wav_format} {apf} --f0_predictor={f0_predictor} {ehc}\n", + "\n", + "#@markdown\n", + "\n", + "#@markdown If you dont want to download from here, uncheck this.\n", + "download_after_inference = True #@param {type:\"boolean\"}\n", + "\n", + "if download_after_inference:\n", + " from google.colab import files\n", + " files.download(wav_output)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [ + { + "file_id": "19fxpo-ZoL_ShEUeZIZi6Di-YioWrEyhR", + "timestamp": 1678516497580 + }, + { + "file_id": "1rCUOOVG7-XQlVZuWRAj5IpGrMM8t07pE", + "timestamp": 1673086970071 + }, + { + "file_id": "1Ul5SmzWiSHBj0MaKA0B682C-RZKOycwF", + "timestamp": 1670483515921 + } + ] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/spkmix.py b/spkmix.py new file mode 100644 index 0000000000000000000000000000000000000000..1d266e017859aca3c48727c5acbef9c8da8c1411 --- /dev/null +++ b/spkmix.py @@ -0,0 +1,11 @@ +# 角色混合轨道 编写规则: +# 角色ID : [[起始时间1, 终止时间1, 起始数值1, 起始数值1], [起始时间2, 终止时间2, 起始数值2, 起始数值2]] +# 起始时间和前一个的终止时间必须相同,第一个起始时间必须为0,最后一个终止时间必须为1 (时间的范围为0-1) +# 全部角色必须填写,不使用的角色填[[0., 1., 0., 0.]]即可 +# 融合数值可以随便填,在指定的时间段内从起始数值线性变化为终止数值,内部会自动确保线性组合为1,可以放心使用 + +spk_mix_map = { + 0 : [[0., 0.5, 1, 0.5], [0.5, 1, 0.5, 1]], + 1 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]], + 2 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]] +} \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8487f177df1db43951e015802a8ea05166c56b8c --- /dev/null +++ b/train.py @@ -0,0 +1,329 @@ +import logging +import multiprocessing +import os +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn import functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import modules.commons as commons +import utils +from data_utils import TextAudioCollate, TextAudioSpeakerLoader +from models import ( + MultiPeriodDiscriminator, + SynthesizerTrn, +) +from modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss +from modules.mel_processing import mel_spectrogram_torch, spec_to_mel_torch + +logging.getLogger('matplotlib').setLevel(logging.WARNING) +logging.getLogger('numba').setLevel(logging.WARNING) + +torch.backends.cudnn.benchmark = True +global_step = 0 +start_time = time.time() + +# os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' + + +def main(): + """Assume Single Node Multi GPUs Training Only""" + assert torch.cuda.is_available(), "CPU training is not allowed." + hps = utils.get_hparams() + + n_gpus = torch.cuda.device_count() + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = hps.train.port + + mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) + + +def run(rank, n_gpus, hps): + global global_step + if rank == 0: + logger = utils.get_logger(hps.model_dir) + logger.info(hps) + utils.check_git_hash(hps.model_dir) + writer = SummaryWriter(log_dir=hps.model_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) + + # for pytorch on win, backend use gloo + dist.init_process_group(backend= 'gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus, rank=rank) + torch.manual_seed(hps.train.seed) + torch.cuda.set_device(rank) + collate_fn = TextAudioCollate() + all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training. + train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem) + num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count() + if all_in_mem: + num_workers = 0 + train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True, + batch_size=hps.train.batch_size, collate_fn=collate_fn) + if rank == 0: + eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps, all_in_mem=all_in_mem,vol_aug = False) + eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False, + batch_size=1, pin_memory=False, + drop_last=False, collate_fn=collate_fn) + + net_g = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model).cuda(rank) + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) + optim_g = torch.optim.AdamW( + net_g.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps) + optim_d = torch.optim.AdamW( + net_d.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps) + net_g = DDP(net_g, device_ids=[rank]) # , find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank]) + + skip_optimizer = False + try: + _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, + optim_g, skip_optimizer) + _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, + optim_d, skip_optimizer) + epoch_str = max(epoch_str, 1) + name=utils.latest_checkpoint_path(hps.model_dir, "D_*.pth") + global_step=int(name[name.rfind("_")+1:name.rfind(".")])+1 + #global_step = (epoch_str - 1) * len(train_loader) + except Exception: + print("load old checkpoint failed...") + epoch_str = 1 + global_step = 0 + if skip_optimizer: + epoch_str = 1 + global_step = 0 + + warmup_epoch = hps.train.warmup_epochs + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + + scaler = GradScaler(enabled=hps.train.fp16_run) + + for epoch in range(epoch_str, hps.train.epochs + 1): + # set up warm-up learning rate + if epoch <= warmup_epoch: + for param_group in optim_g.param_groups: + param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch + for param_group in optim_d.param_groups: + param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch + # training + if rank == 0: + train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, + [train_loader, eval_loader], logger, [writer, writer_eval]) + else: + train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, + [train_loader, None], None, None) + # update learning rate + scheduler_g.step() + scheduler_d.step() + + +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): + net_g, net_d = nets + optim_g, optim_d = optims + scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16 + + # train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, items in enumerate(train_loader): + c, f0, spec, y, spk, lengths, uv,volume = items + g = spk.cuda(rank, non_blocking=True) + spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) + c = c.cuda(rank, non_blocking=True) + f0 = f0.cuda(rank, non_blocking=True) + uv = uv.cuda(rank, non_blocking=True) + lengths = lengths.cuda(rank, non_blocking=True) + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax) + + with autocast(enabled=hps.train.fp16_run, dtype=half_type): + y_hat, ids_slice, z_mask, \ + (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths, + spec_lengths=lengths,vol = volume) + + y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax + ) + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + + with autocast(enabled=False, dtype=half_type): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) + loss_disc_all = loss_disc + + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + + with autocast(enabled=hps.train.fp16_run, dtype=half_type): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False, dtype=half_type): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0 + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0 + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]['lr'] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + reference_loss=0 + for i in losses: + reference_loss += i + logger.info('Train Epoch: {} [{:.0f}%]'.format( + epoch, + 100. * batch_idx / len(train_loader))) + logger.info(f"Losses: {[x.item() for x in losses]}, step: {global_step}, lr: {lr}, reference_loss: {reference_loss}") + + scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, + "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} + scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, + "loss/g/lf0": loss_lf0}) + + # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), + "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()) + } + + if net_g.module.use_automatic_f0_prediction: + image_dict.update({ + "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), + pred_lf0[0, 0, :].detach().cpu().numpy()), + "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), + norm_lf0[0, 0, :].detach().cpu().numpy()) + }) + + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict + ) + + if global_step % hps.train.eval_interval == 0: + evaluate(hps, net_g, eval_loader, writer_eval) + utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) + utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + keep_ckpts = getattr(hps.train, 'keep_ckpts', 0) + if keep_ckpts > 0: + utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) + + global_step += 1 + + if rank == 0: + global start_time + now = time.time() + durtaion = format(now - start_time, '.2f') + logger.info(f'====> Epoch: {epoch}, cost {durtaion} s') + start_time = now + + +def evaluate(hps, generator, eval_loader, writer_eval): + generator.eval() + image_dict = {} + audio_dict = {} + with torch.no_grad(): + for batch_idx, items in enumerate(eval_loader): + c, f0, spec, y, spk, _, uv,volume = items + g = spk[:1].cuda(0) + spec, y = spec[:1].cuda(0), y[:1].cuda(0) + c = c[:1].cuda(0) + f0 = f0[:1].cuda(0) + uv= uv[:1].cuda(0) + if volume is not None: + volume = volume[:1].cuda(0) + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax) + y_hat,_ = generator.module.infer(c, f0, uv, g=g,vol = volume) + + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax + ) + + audio_dict.update({ + f"gen/audio_{batch_idx}": y_hat[0], + f"gt/audio_{batch_idx}": y[0] + }) + image_dict.update({ + "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), + "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()) + }) + utils.summarize( + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=hps.data.sampling_rate + ) + generator.train() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_diff.py b/train_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..65ba3820e180eacbb712755aac7f8cd4fe99bce7 --- /dev/null +++ b/train_diff.py @@ -0,0 +1,77 @@ +import argparse + +import torch +from loguru import logger +from torch.optim import lr_scheduler + +from diffusion.data_loaders import get_data_loaders +from diffusion.logger import utils +from diffusion.solver import train +from diffusion.unit2mel import Unit2Mel +from diffusion.vocoder import Vocoder + + +def parse_args(args=None, namespace=None): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config", + type=str, + required=True, + help="path to the config file") + return parser.parse_args(args=args, namespace=namespace) + + +if __name__ == '__main__': + # parse commands + cmd = parse_args() + + # load config + args = utils.load_config(cmd.config) + logger.info(' > config:'+ cmd.config) + logger.info(' > exp:'+ args.env.expdir) + + # load vocoder + vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device) + + # load model + model = Unit2Mel( + args.data.encoder_out_channels, + args.model.n_spk, + args.model.use_pitch_aug, + vocoder.dimension, + args.model.n_layers, + args.model.n_chans, + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max + ) + + logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}') + + # load parameters + optimizer = torch.optim.AdamW(model.parameters()) + initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device) + for param_group in optimizer.param_groups: + param_group['initial_lr'] = args.train.lr + param_group['lr'] = args.train.lr * (args.train.gamma ** max(((initial_global_step-2)//args.train.decay_step),0) ) + param_group['weight_decay'] = args.train.weight_decay + scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma,last_epoch=initial_global_step-2) + + # device + if args.device == 'cuda': + torch.cuda.set_device(args.env.gpu_id) + model.to(args.device) + + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.to(args.device) + + # datas + loader_train, loader_valid = get_data_loaders(args, whole_audio=False) + + # run + train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid) + diff --git a/train_index.py b/train_index.py new file mode 100644 index 0000000000000000000000000000000000000000..13d66d3ebf39cac113191278d6641407c223b957 --- /dev/null +++ b/train_index.py @@ -0,0 +1,30 @@ +import argparse +import os +import pickle + +import utils + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--root_dir", type=str, default="dataset/44k", help="path to root dir" + ) + parser.add_argument('-c', '--config', type=str, default="./configs/config.json", + help='JSON file for configuration') + parser.add_argument( + "--output_dir", type=str, default="logs/44k", help="path to output dir" + ) + + args = parser.parse_args() + + hps = utils.get_hparams_from_file(args.config) + spk_dic = hps.spk + result = {} + + for k,v in spk_dic.items(): + print(f"now, index {k} feature...") + index = utils.train_index(k,args.root_dir) + result[v] = index + + with open(os.path.join(args.output_dir,"feature_and_index.pkl"),"wb") as f: + pickle.dump(result,f) \ No newline at end of file diff --git a/trained/put_trained_checkpoints_here b/trained/put_trained_checkpoints_here new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95b6d8882867a81bc638237957dd3141b7bc1210 --- /dev/null +++ b/utils.py @@ -0,0 +1,572 @@ +import argparse +import glob +import json +import logging +import os +import re +import subprocess +import sys +import traceback +from multiprocessing import cpu_count + +import faiss +import librosa +import numpy as np +import torch +from scipy.io.wavfile import read +from sklearn.cluster import MiniBatchKMeans +from torch.nn import functional as F + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.WARN) +logger = logging + +f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + +def normalize_f0(f0, x_mask, uv, random_scale=True): + # calculate means based on x_mask + uv_sum = torch.sum(uv, dim=1, keepdim=True) + uv_sum[uv_sum == 0] = 9999 + means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum + + if random_scale: + factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device) + else: + factor = torch.ones(f0.shape[0], 1).to(f0.device) + # normalize f0 based on means and factor + f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) + if torch.isnan(f0_norm).any(): + exit(0) + return f0_norm * x_mask +def plot_data_to_numpy(x, y): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + plt.plot(x) + plt.plot(y) + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def f0_to_coarse(f0): + f0_mel = 1127 * (1 + f0 / 700).log() + a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) + b = f0_mel_min * a - 1. + f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) + # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) + f0_coarse = torch.round(f0_mel).long() + f0_coarse = f0_coarse * (f0_coarse > 0) + f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) + f0_coarse = f0_coarse * (f0_coarse < f0_bin) + f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) + return f0_coarse + +def get_content(cmodel, y): + with torch.no_grad(): + c = cmodel.extract_features(y.squeeze(1))[0] + c = c.transpose(1, 2) + return c + +def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs): + if f0_predictor == "pm": + from modules.F0Predictor.PMF0Predictor import PMF0Predictor + f0_predictor_object = PMF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) + elif f0_predictor == "crepe": + from modules.F0Predictor.CrepeF0Predictor import CrepeF0Predictor + f0_predictor_object = CrepeF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,device=kargs["device"],threshold=kargs["threshold"]) + elif f0_predictor == "harvest": + from modules.F0Predictor.HarvestF0Predictor import HarvestF0Predictor + f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) + elif f0_predictor == "dio": + from modules.F0Predictor.DioF0Predictor import DioF0Predictor + f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) + elif f0_predictor == "rmvpe": + from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor + f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) + elif f0_predictor == "fcpe": + from modules.F0Predictor.FCPEF0Predictor import FCPEF0Predictor + f0_predictor_object = FCPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) + else: + raise Exception("Unknown f0 predictor") + return f0_predictor_object + +def get_speech_encoder(speech_encoder,device=None,**kargs): + if speech_encoder == "vec768l12": + from vencoder.ContentVec768L12 import ContentVec768L12 + speech_encoder_object = ContentVec768L12(device = device) + elif speech_encoder == "vec256l9": + from vencoder.ContentVec256L9 import ContentVec256L9 + speech_encoder_object = ContentVec256L9(device = device) + elif speech_encoder == "vec256l9-onnx": + from vencoder.ContentVec256L9_Onnx import ContentVec256L9_Onnx + speech_encoder_object = ContentVec256L9_Onnx(device = device) + elif speech_encoder == "vec256l12-onnx": + from vencoder.ContentVec256L12_Onnx import ContentVec256L12_Onnx + speech_encoder_object = ContentVec256L12_Onnx(device = device) + elif speech_encoder == "vec768l9-onnx": + from vencoder.ContentVec768L9_Onnx import ContentVec768L9_Onnx + speech_encoder_object = ContentVec768L9_Onnx(device = device) + elif speech_encoder == "vec768l12-onnx": + from vencoder.ContentVec768L12_Onnx import ContentVec768L12_Onnx + speech_encoder_object = ContentVec768L12_Onnx(device = device) + elif speech_encoder == "hubertsoft-onnx": + from vencoder.HubertSoft_Onnx import HubertSoft_Onnx + speech_encoder_object = HubertSoft_Onnx(device = device) + elif speech_encoder == "hubertsoft": + from vencoder.HubertSoft import HubertSoft + speech_encoder_object = HubertSoft(device = device) + elif speech_encoder == "whisper-ppg": + from vencoder.WhisperPPG import WhisperPPG + speech_encoder_object = WhisperPPG(device = device) + elif speech_encoder == "cnhubertlarge": + from vencoder.CNHubertLarge import CNHubertLarge + speech_encoder_object = CNHubertLarge(device = device) + elif speech_encoder == "dphubert": + from vencoder.DPHubert import DPHubert + speech_encoder_object = DPHubert(device = device) + elif speech_encoder == "whisper-ppg-large": + from vencoder.WhisperPPGLarge import WhisperPPGLarge + speech_encoder_object = WhisperPPGLarge(device = device) + elif speech_encoder == "wavlmbase+": + from vencoder.WavLMBasePlus import WavLMBasePlus + speech_encoder_object = WavLMBasePlus(device = device) + else: + raise Exception("Unknown speech encoder") + return speech_encoder_object + +def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + iteration = checkpoint_dict['iteration'] + learning_rate = checkpoint_dict['learning_rate'] + if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: + optimizer.load_state_dict(checkpoint_dict['optimizer']) + saved_state_dict = checkpoint_dict['model'] + model = model.to(list(saved_state_dict.values())[0].dtype) + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + # assert "dec" in k or "disc" in k + # print("load", k) + new_state_dict[k] = saved_state_dict[k] + assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) + except Exception: + if "enc_q" not in k or "emb_g" not in k: + print("%s is not in the checkpoint,please check your checkpoint.If you're using pretrain model,just ignore this warning." % k) + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, 'module'): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + print("load ") + logger.info("Loaded checkpoint '{}' (iteration {})".format( + checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info("Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path)) + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save({'model': state_dict, + 'iteration': iteration, + 'optimizer': optimizer.state_dict(), + 'learning_rate': learning_rate}, checkpoint_path) + +def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True): + """Freeing up space by deleting saved ckpts + + Arguments: + path_to_models -- Path to the model directory + n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth + sort_by_time -- True -> chronologically delete ckpts + False -> lexicographically delete ckpts + """ + ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] + def name_key(_f): + return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) + def time_key(_f): + return os.path.getmtime(os.path.join(path_to_models, _f)) + sort_key = time_key if sort_by_time else name_key + def x_sorted(_x): + return sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], key=sort_key) + to_del = [os.path.join(path_to_models, fn) for fn in + (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])] + def del_info(fn): + return logger.info(f".. Free up space by deleting ckpt {fn}") + def del_routine(x): + return [os.remove(x), del_info(x)] + [del_routine(fn) for fn in to_del] + +def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats='HWC') + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10,2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', + interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding='utf-8') as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True): + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default="./configs/config.json", + help='JSON file for configuration') + parser.add_argument('-m', '--model', type=str, required=True, + help='Model name') + + args = parser.parse_args() + model_dir = os.path.join("./logs", args.model) + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + config_path = args.config + config_save_path = os.path.join(model_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams =HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path, infer_mode = False): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + hparams =HParams(**config) if not infer_mode else InferHParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + )) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn("git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8])) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +def repeat_expand_2d(content, target_len, mode = 'left'): + # content : [h, t] + return repeat_expand_2d_left(content, target_len) if mode == 'left' else repeat_expand_2d_other(content, target_len, mode) + + + +def repeat_expand_2d_left(content, target_len): + # content : [h, t] + + src_len = content.shape[-1] + target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device) + temp = torch.arange(src_len+1) * target_len / src_len + current_pos = 0 + for i in range(target_len): + if i < temp[current_pos+1]: + target[:, i] = content[:, current_pos] + else: + current_pos += 1 + target[:, i] = content[:, current_pos] + + return target + + +# mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area' +def repeat_expand_2d_other(content, target_len, mode = 'nearest'): + # content : [h, t] + content = content[None,:,:] + target = F.interpolate(content,size=target_len,mode=mode)[0] + return target + + +def mix_model(model_paths,mix_rate,mode): + mix_rate = torch.FloatTensor(mix_rate)/100 + model_tem = torch.load(model_paths[0]) + models = [torch.load(path)["model"] for path in model_paths] + if mode == 0: + mix_rate = F.softmax(mix_rate,dim=0) + for k in model_tem["model"].keys(): + model_tem["model"][k] = torch.zeros_like(model_tem["model"][k]) + for i,model in enumerate(models): + model_tem["model"][k] += model[k]*mix_rate[i] + torch.save(model_tem,os.path.join(os.path.curdir,"output.pth")) + return os.path.join(os.path.curdir,"output.pth") + +def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出音频,rate是2的占比 from RVC + # print(data1.max(),data2.max()) + rms1 = librosa.feature.rms( + y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2 + ) # 每半秒一个点 + rms2 = librosa.feature.rms(y=data2.detach().cpu().numpy(), frame_length=sr2 // 2 * 2, hop_length=sr2 // 2) + rms1 = torch.from_numpy(rms1).to(data2.device) + rms1 = F.interpolate( + rms1.unsqueeze(0), size=data2.shape[0], mode="linear" + ).squeeze() + rms2 = torch.from_numpy(rms2).to(data2.device) + rms2 = F.interpolate( + rms2.unsqueeze(0), size=data2.shape[0], mode="linear" + ).squeeze() + rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6) + data2 *= ( + torch.pow(rms1, torch.tensor(1 - rate)) + * torch.pow(rms2, torch.tensor(rate - 1)) + ) + return data2 + +def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI + n_cpu = cpu_count() + print("The feature index is constructing.") + exp_dir = os.path.join(root_dir,spk_name) + listdir_res = [] + for file in os.listdir(exp_dir): + if ".wav.soft.pt" in file: + listdir_res.append(os.path.join(exp_dir,file)) + if len(listdir_res) == 0: + raise Exception("You need to run preprocess_hubert_f0.py!") + npys = [] + for name in sorted(listdir_res): + phone = torch.load(name)[0].transpose(-1,-2).numpy() + npys.append(phone) + big_npy = np.concatenate(npys, 0) + big_npy_idx = np.arange(big_npy.shape[0]) + np.random.shuffle(big_npy_idx) + big_npy = big_npy[big_npy_idx] + if big_npy.shape[0] > 2e5: + # if(1): + info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0] + print(info) + try: + big_npy = ( + MiniBatchKMeans( + n_clusters=10000, + verbose=True, + batch_size=256 * n_cpu, + compute_labels=False, + init="random", + ) + .fit(big_npy) + .cluster_centers_ + ) + except Exception: + info = traceback.format_exc() + print(info) + n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) + index = faiss.index_factory(big_npy.shape[1] , "IVF%s,Flat" % n_ivf) + index_ivf = faiss.extract_index_ivf(index) # + index_ivf.nprobe = 1 + index.train(big_npy) + batch_size_add = 8192 + for i in range(0, big_npy.shape[0], batch_size_add): + index.add(big_npy[i : i + batch_size_add]) + # faiss.write_index( + # index, + # f"added_{spk_name}.index" + # ) + print("Successfully build index") + return index + + +class HParams(): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + def get(self,index): + return self.__dict__.get(index) + + +class InferHParams(HParams): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = InferHParams(**v) + self[k] = v + + def __getattr__(self,index): + return self.get(index) + + +class Volume_Extractor: + def __init__(self, hop_size = 512): + self.hop_size = hop_size + + def extract(self, audio): # audio: 2d tensor array + if not isinstance(audio,torch.Tensor): + audio = torch.Tensor(audio) + n_frames = int(audio.size(-1) // self.hop_size) + audio2 = audio ** 2 + audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect') + volume = torch.nn.functional.unfold(audio2[:,None,None,:],(1,self.hop_size),stride=self.hop_size)[:,:,:n_frames].mean(dim=1)[0] + volume = torch.sqrt(volume) + return volume diff --git a/vdecoder/__init__.py b/vdecoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vdecoder/hifigan/env.py b/vdecoder/hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ b/vdecoder/hifigan/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/hifigan/models.py b/vdecoder/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..107553368ff1798f72df21c6d5a965260f5a60fd --- /dev/null +++ b/vdecoder/hifigan/models.py @@ -0,0 +1,557 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +def padDiff(x): + return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.onnx = False + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0, upp=None): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化 + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp=None): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.num_upsamples = len(h["upsample_rates"]) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"])) + self.m_source = SourceModuleHnNSF( + sampling_rate=h["sampling_rate"], + harmonic_num=8) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)) + resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])): + c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), + k, u, padding=(k - u +1 ) // 2))) + if i + 1 < len(h["upsample_rates"]): # + stride_f0 = np.prod(h["upsample_rates"][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h["upsample_initial_channel"] // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True + + def forward(self, x, f0, g=None): + # print(1,x.shape,f0.shape,f0[:, None].shape) + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + # print(2,f0.shape) + har_source, noi_source, uv = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + x = x + self.cond(g) + # print(124,x.shape,har_source.shape) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + # print(3,x.shape) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + # print(4,x_source.shape,har_source.shape,x.shape) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/hifigan/nvSTFT.py b/vdecoder/hifigan/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 --- /dev/null +++ b/vdecoder/hifigan/nvSTFT.py @@ -0,0 +1,109 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 32000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 32000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + if fmax not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + # print(222,spec) + spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/vdecoder/hifigan/utils.py b/vdecoder/hifigan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e519e2b7ed8fe5f93266d21d727a30173699f88b --- /dev/null +++ b/vdecoder/hifigan/utils.py @@ -0,0 +1,68 @@ +import glob +import os + +# matplotlib.use("Agg") +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vdecoder/hifiganwithsnake/alias/__init__.py b/vdecoder/hifiganwithsnake/alias/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be97a33248ae6378c6736586774abda11cfbdeba --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .act import * # noqa: F403 +from .filter import * # noqa: F403 +from .resample import * # noqa: F403 diff --git a/vdecoder/hifiganwithsnake/alias/act.py b/vdecoder/hifiganwithsnake/alias/act.py new file mode 100644 index 0000000000000000000000000000000000000000..e46b3467b73b90df51c1d19032b90d26595aca6e --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/act.py @@ -0,0 +1,130 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import pow, sin +from torch.nn import Parameter + +from .resample import DownSample1d, UpSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta = x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze( + 0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + return x + + +class Mish(nn.Module): + """ + Mish activation function is proposed in "Mish: A Self + Regularized Non-Monotonic Neural Activation Function" + paper, https://arxiv.org/abs/1908.08681. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SnakeAlias(nn.Module): + def __init__(self, + channels, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + C = None): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = SnakeBeta(channels, alpha_logscale=True) + self.upsample = UpSample1d(up_ratio, up_kernel_size, C) + self.downsample = DownSample1d(down_ratio, down_kernel_size, C) + + # x: [B,C,T] + def forward(self, x, C=None): + x = self.upsample(x, C) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/filter.py b/vdecoder/hifiganwithsnake/alias/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3942eb3ae547a2f500d5c47defdd70cd29ea4655 --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/filter.py @@ -0,0 +1,110 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12, + C=None): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + self.conv1d_block = None + if C is not None: + self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),] + self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1)) + self.conv1d_block[0].requires_grad_(False) + + #input [B, C, T] + def forward(self, x): + if self.conv1d_block[0].weight.device != x.device: + self.conv1d_block[0] = self.conv1d_block[0].to(x.device) + if self.conv1d_block is None: + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + else: + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = self.conv1d_block[0](x) + + return out \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/resample.py b/vdecoder/hifiganwithsnake/alias/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a364403f0977bc8bcffbb4764081e4bd3619467a --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/resample.py @@ -0,0 +1,72 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, C=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + self.conv_transpose1d_block = None + if C is not None: + self.conv_transpose1d_block = [nn.ConvTranspose1d(C, + C, + kernel_size=self.kernel_size, + stride=self.stride, + groups=C, + bias=False + ),] + self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone()) + self.conv_transpose1d_block[0].requires_grad_(False) + + + + # x: [B, C, T] + def forward(self, x, C=None): + if self.conv_transpose1d_block[0].weight.device != x.device: + self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device) + if self.conv_transpose1d_block is None: + if C is None: + _, C, _ = x.shape + # print("snake.conv_t.in:",x.shape) + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + # print("snake.conv_t.out:",x.shape) + x = x[..., self.pad_left:-self.pad_right] + else: + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * self.conv_transpose1d_block[0](x) + x = x[..., self.pad_left:-self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, C=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + C=C) + + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/env.py b/vdecoder/hifiganwithsnake/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ b/vdecoder/hifiganwithsnake/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py new file mode 100644 index 0000000000000000000000000000000000000000..08bbda9b77b095d81ca8d8a9e5e8ebe20fa9bcfa --- /dev/null +++ b/vdecoder/hifiganwithsnake/models.py @@ -0,0 +1,576 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from vdecoder.hifiganwithsnake.alias.act import SnakeAlias + +from .env import AttrDict +from .utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) + self.activations = nn.ModuleList([ + SnakeAlias(channels, C=C) for _ in range(self.num_layers) + ]) + + def forward(self, x, DIM=None): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x, DIM) + xt = c1(xt) + xt = a2(xt, DIM) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) + self.activations = nn.ModuleList([ + SnakeAlias(channels, C=C) for _ in range(self.num_layers) + ]) + + def forward(self, x, DIM=None): + for c,a in zip(self.convs, self.activations): + xt = a(x, DIM) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +def padDiff(x): + return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.onnx = False + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0, upp=None): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化 + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp=None): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.num_upsamples = len(h["upsample_rates"]) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"])) + self.m_source = SourceModuleHnNSF( + sampling_rate=h["sampling_rate"], + harmonic_num=8) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)) + resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])): + c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), + k, u, padding=(k - u + 1) // 2))) + if i + 1 < len(h["upsample_rates"]): # + stride_f0 = np.prod(h["upsample_rates"][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+ 1) // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + self.snakes = nn.ModuleList() + for i in range(len(self.ups)): + ch = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i)) + for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1))) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups)) + self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True + + def forward(self, x, f0, g=None): + # print(1,x.shape,f0.shape,f0[:, None].shape) + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + # print(2,f0.shape) + har_source, noi_source, uv = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + x = x + self.cond(g) + # print(124,x.shape,har_source.shape) + for i in range(self.num_upsamples): + # print(f"self.snakes.{i}.pre:", x.shape) + x = self.snakes[i](x) + # print(f"self.snakes.{i}.after:", x.shape) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + # print(4,x_source.shape,har_source.shape,x.shape) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + # print(f"self.resblocks.{i}.after:", xs.shape) + x = xs / self.num_kernels + x = self.snake_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/hifiganwithsnake/nvSTFT.py b/vdecoder/hifiganwithsnake/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 --- /dev/null +++ b/vdecoder/hifiganwithsnake/nvSTFT.py @@ -0,0 +1,109 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 32000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 32000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + if fmax not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + # print(222,spec) + spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/vdecoder/hifiganwithsnake/utils.py b/vdecoder/hifiganwithsnake/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e519e2b7ed8fe5f93266d21d727a30173699f88b --- /dev/null +++ b/vdecoder/hifiganwithsnake/utils.py @@ -0,0 +1,68 @@ +import glob +import os + +# matplotlib.use("Agg") +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vdecoder/nsf_hifigan/env.py b/vdecoder/nsf_hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ b/vdecoder/nsf_hifigan/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/nsf_hifigan/models.py b/vdecoder/nsf_hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8a35b134d814008c3990d019d1de502ff10dd86f --- /dev/null +++ b/vdecoder/nsf_hifigan/models.py @@ -0,0 +1,441 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + h = load_config(model_path) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path, map_location=device) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + +def load_config(model_path): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + json_config = json.loads(data) + h = AttrDict(json_config) + return h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + @torch.no_grad() + def forward(self, f0, upp): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0 = f0.unsqueeze(-1) + fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))) + rad_values = (fn / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + is_half = rad_values.dtype is not torch.float32 + tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 + if is_half: + tmp_over_one = tmp_over_one.half() + else: + tmp_over_one = tmp_over_one.float() + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), scale_factor=upp, + mode='linear', align_corners=True + ).transpose(2, 1) + rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + rad_values = rad_values.double() + cumsum_shift = cumsum_shift.double() + sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) + if is_half: + sine_waves = sine_waves.half() + else: + sine_waves = sine_waves.float() + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp): + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=h.sampling_rate, + harmonic_num=8 + ) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + c_cur = h.upsample_initial_channel // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2))) + if i + 1 < len(h.upsample_rates): # + stride_f0 = int(np.prod(h.upsample_rates[i + 1:])) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + ch = h.upsample_initial_channel + for i in range(len(self.ups)): + ch //= 2 + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.upp = int(np.prod(h.upsample_rates)) + + def forward(self, x, f0): + har_source = self.m_source(f0, self.upp).transpose(1, 2) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/nsf_hifigan/nvSTFT.py b/vdecoder/nsf_hifigan/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..e756cca561a45bde435f36447e6681bfa17e34aa --- /dev/null +++ b/vdecoder/nsf_hifigan/nvSTFT.py @@ -0,0 +1,132 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 48000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 48000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, keyshift=0, speed=1, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(n_fft * factor)) + win_size_new = int(np.round(win_size * factor)) + hop_length_new = int(np.round(hop_length * speed)) + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + mel_basis_key = str(fmax)+'_'+str(y.device) + if mel_basis_key not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) + + keyshift_key = str(keyshift)+'_'+str(y.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) + + pad_left = (win_size_new - hop_length_new) //2 + pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) + if pad_right < y.size(-1): + mode = 'reflect' + else: + mode = 'constant' + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) + y = y.squeeze(1) + + spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + if keyshift != 0: + size = n_fft // 2 + 1 + resize = spec.size(1) + if resize < size: + spec = F.pad(spec, (0, 0, 0, size-resize)) + spec = spec[:, :size, :] * win_size / win_size_new + + # print(222,spec) + spec = torch.matmul(self.mel_basis[mel_basis_key], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/vdecoder/nsf_hifigan/utils.py b/vdecoder/nsf_hifigan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58d0e701d377e318fe0302743c27bdb4d6e089ec --- /dev/null +++ b/vdecoder/nsf_hifigan/utils.py @@ -0,0 +1,70 @@ +import glob +import os + +import matplotlib +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vencoder/CNHubertLarge.py b/vencoder/CNHubertLarge.py new file mode 100644 index 0000000000000000000000000000000000000000..f43694762f92c5d839d358825f157f5d1a4ff6f6 --- /dev/null +++ b/vencoder/CNHubertLarge.py @@ -0,0 +1,36 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class CNHubertLarge(SpeechEncoder): + def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 1024 + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device) + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) \ No newline at end of file diff --git a/vencoder/ContentVec256L12_Onnx.py b/vencoder/ContentVec256L12_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..466e6c128b88acdfb94392662086e6752d503a27 --- /dev/null +++ b/vencoder/ContentVec256L12_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec256L12_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/ContentVec256L9.py b/vencoder/ContentVec256L9.py new file mode 100644 index 0000000000000000000000000000000000000000..c973090dd4cdaa3d8ca07d9007c26633883c36a7 --- /dev/null +++ b/vencoder/ContentVec256L9.py @@ -0,0 +1,38 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class ContentVec256L9(SpeechEncoder): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device), + "output_layer": 9, # layer 9 + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + feats = self.model.final_proj(logits[0]) + return feats.transpose(1, 2) diff --git a/vencoder/ContentVec256L9_Onnx.py b/vencoder/ContentVec256L9_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..a27e1f76655d9dc9fcc41d05d11b4a1ac5d85b90 --- /dev/null +++ b/vencoder/ContentVec256L9_Onnx.py @@ -0,0 +1,32 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec256L9_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + if device == 'cpu' or device == torch.device("cpu") or device is None: + providers = ['CPUExecutionProvider'] + elif device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) + \ No newline at end of file diff --git a/vencoder/ContentVec768L12.py b/vencoder/ContentVec768L12.py new file mode 100644 index 0000000000000000000000000000000000000000..066b824b68447b5c860730c9f11b7be415068b46 --- /dev/null +++ b/vencoder/ContentVec768L12.py @@ -0,0 +1,37 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class ContentVec768L12(SpeechEncoder): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 768 + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device), + "output_layer": 12, # layer 12 + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) diff --git a/vencoder/ContentVec768L12_Onnx.py b/vencoder/ContentVec768L12_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..e737594526fd09f19353b85c11d4c357a325af48 --- /dev/null +++ b/vencoder/ContentVec768L12_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec768L12_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 768 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/ContentVec768L9_Onnx.py b/vencoder/ContentVec768L9_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd0f337bbf5fa261ea43adfab2377fced7c9e7c --- /dev/null +++ b/vencoder/ContentVec768L9_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec768L9_Onnx(SpeechEncoder): + def __init__(self,vec_path = "pretrain/vec-768-layer-9.onnx",device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 768 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/DPHubert.py b/vencoder/DPHubert.py new file mode 100644 index 0000000000000000000000000000000000000000..130064ff3ea5c24017be2f0faa204fc4c7dbd078 --- /dev/null +++ b/vencoder/DPHubert.py @@ -0,0 +1,29 @@ +import torch + +from vencoder.dphubert.model import wav2vec2_model +from vencoder.encoder import SpeechEncoder + + +class DPHubert(SpeechEncoder): + def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + ckpt = torch.load(vec_path) + self.hidden_dim = 768 + self.model = wav2vec2_model(**ckpt["config"]).to(self.dev) + self.model.load_state_dict(ckpt["state_dict"], strict=False) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats[None, :] + with torch.no_grad(): + with torch.inference_mode(): + units = self.model(feats)[0] + return units.transpose(1,2) diff --git a/vencoder/HubertSoft.py b/vencoder/HubertSoft.py new file mode 100644 index 0000000000000000000000000000000000000000..423c159c44f0e5cb820a911a47b71ae1478d725d --- /dev/null +++ b/vencoder/HubertSoft.py @@ -0,0 +1,28 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.hubert import hubert_model + + +class HubertSoft(SpeechEncoder): + def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + hubert_soft = hubert_model.hubert_soft(vec_path) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.hidden_dim = 256 + self.model = hubert_soft.to(self.dev) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats[None,None,:] + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.units(feats) + return units.transpose(1,2) diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..038d78e8ffa0804cb63b146f8122b3f2bba2f637 --- /dev/null +++ b/vencoder/HubertSoft_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class HubertSoft_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/WavLMBasePlus.py b/vencoder/WavLMBasePlus.py new file mode 100644 index 0000000000000000000000000000000000000000..99df15be73c0c4774cea83a376f79fb68405bfa1 --- /dev/null +++ b/vencoder/WavLMBasePlus.py @@ -0,0 +1,32 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.wavlm.WavLM import WavLM, WavLMConfig + + +class WavLMBasePlus(SpeechEncoder): + def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + checkpoint = torch.load(vec_path) + self.cfg = WavLMConfig(checkpoint['cfg']) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.hidden_dim = self.cfg.encoder_embed_dim + self.model = WavLM(self.cfg) + self.model.load_state_dict(checkpoint['model']) + self.model.to(self.dev).eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + if self.cfg.normalize: + feats = torch.nn.functional.layer_norm(feats, feats.shape) + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.extract_features(feats[None, :])[0] + return units.transpose(1, 2) diff --git a/vencoder/WhisperPPG.py b/vencoder/WhisperPPG.py new file mode 100644 index 0000000000000000000000000000000000000000..86af53e69b5f60f143a4acce0949c24812e327d1 --- /dev/null +++ b/vencoder/WhisperPPG.py @@ -0,0 +1,31 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper + + +class WhisperPPG(SpeechEncoder): + def __init__(self, vec_path="pretrain/medium.pt", device=None): + super().__init__() + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + checkpoint = torch.load(vec_path, map_location=device) + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + self.hidden_dim = dims + self.model = model.to(self.dev) + + def encoder(self, wav): + audio = wav + audln = audio.shape[0] + ppgln = audln // 320 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio).to(self.dev) + with torch.no_grad(): + ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/WhisperPPGLarge.py b/vencoder/WhisperPPGLarge.py new file mode 100644 index 0000000000000000000000000000000000000000..cd0ff76ba56815c1af0178e0d949e02a9a80f5fb --- /dev/null +++ b/vencoder/WhisperPPGLarge.py @@ -0,0 +1,31 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper + + +class WhisperPPGLarge(SpeechEncoder): + def __init__(self, vec_path="pretrain/large-v2.pt", device=None): + super().__init__() + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + checkpoint = torch.load(vec_path, map_location=device) + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + self.hidden_dim = dims + self.model = model.to(self.dev) + + def encoder(self, wav): + audio = wav + audln = audio.shape[0] + ppgln = audln // 320 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio).to(self.dev) + with torch.no_grad(): + ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/__init__.py b/vencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/__init__.py b/vencoder/dphubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/components.py b/vencoder/dphubert/components.py new file mode 100644 index 0000000000000000000000000000000000000000..be5cc8ce28f11f4f1339578a9d2658740f103283 --- /dev/null +++ b/vencoder/dphubert/components.py @@ -0,0 +1,1410 @@ +"""Building blocks for speech SSL models supporting pruning. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py + +""" + +import math +from collections import defaultdict +from typing import List, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.nn import Module + +from .hardconcrete import HardConcrete +from .pruning_utils import ( + prune_conv1d_layer, + prune_layer_norm, + prune_linear_layer, +) + + +def _init_transformer_params(module): + """ + Initialize the weights of Transformer module in Wav2Vec2/HuBERT. + + If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02. + If ``bias`` is set to ``True`` in the module, set ``bias`` to 0. + + If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02. + If ``padding_idx`` is not None, set the weight of padding to 0. + + Note: + Ths method corresponds to + `init_bert_params + `__ + in the original ``fairseq`` implementation. + """ + + def normal_(data): + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class LayerNorm(nn.LayerNorm): + """Layer norm with transpose""" + + def forward(self, input: Tensor) -> Tensor: + x = input.transpose(-2, -1) + x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.transpose(-2, -1) + return x + + +class ConvLayerBlock(Module): + """Convolution unit of FeatureExtractor""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias: bool, + layer_norm: Optional[Module], + prune_conv_channels: bool = False, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.layer_norm = layer_norm + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + if prune_conv_channels: + self.hard_concrete = HardConcrete(n_in=out_channels, init_mean=0.01) + else: + self.hard_concrete = None + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Shape: ``[batch, in_channels, in_frame]``. + length (Tensor or None, optional): Shape ``[batch, ]``. + Returns: + Tensor: Shape ``[batch, out_channels, out_frames]``. + Optional[Tensor]: Shape ``[batch, ]``. + """ + x = self.conv(x) + if self.layer_norm is not None: + x = self.layer_norm(x) + x = nn.functional.gelu(x) + + if self.hard_concrete is not None: + channel_mask = self.hard_concrete() # hard concrete mask, (out_channels,) + x = x * channel_mask.unsqueeze(-1) + + if length is not None: + length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1 + # When input length is 0, the resulting length can be negative. So fix it here. + length = torch.max(torch.zeros_like(length), length) + return x, length + + def get_num_params_and_out_channels(self, in_channels): + if self.hard_concrete is not None: + out_channels = self.hard_concrete.l0_norm() + else: + out_channels = self.conv.out_channels + + num_params = in_channels * out_channels * self.kernel_size + if self.conv.bias is not None: + num_params += out_channels + if self.layer_norm is not None: + num_params += out_channels * 2 + + return num_params, out_channels + + +class FeatureExtractor(Module): + """Extract features from audio + + Args: + conv_layers (nn.ModuleList): + convolution layers + """ + + def __init__( + self, + conv_layers: nn.ModuleList, + ): + super().__init__() + self.conv_layers = conv_layers + + # NOTE: a dummy weight used to save the soft mask of the last conv layer + self.dummy_weight = nn.Parameter( + torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32), + requires_grad=False + ) + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): + Input Tensor representing a batch of audio, + shape: ``[batch, time]``. + length (Tensor or None, optional): + Valid length of each input sample. shape: ``[batch, ]``. + + Returns: + Tensor: + The resulting feature, shape: ``[batch, frame, feature]`` + Optional[Tensor]: + Valid length of each output sample. shape: ``[batch, ]``. + """ + if x.ndim != 2: + raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}") + + x = x.unsqueeze(1) # (batch, channel==1, frame) + for layer in self.conv_layers: + x, length = layer(x, length) # (batch, feature, frame) + x = x.transpose(1, 2) # (batch, frame, feature) + x = x * self.dummy_weight + return x, length + + def get_num_params_and_final_out_channels(self): + in_channels = 1 + num_params = 0 + for layer in self.conv_layers: + layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels) + num_params += layer_params + + num_params += in_channels # dummy weight + + return num_params, in_channels + + def prune(self): + """"Prune conv layers and dummy weight based on hardconcrete parameters. + This is an in-place operation. + """ + new_config = [] # [(output_channel, kernel_size, stride), ...] + for idx, layer in enumerate(self.conv_layers): + if layer.hard_concrete is not None: + assert not layer.hard_concrete.training + mask = layer.hard_concrete() # (out_features,) + index = mask.nonzero().squeeze(-1) # 2D -> 1D + assert len(index) > 0, f"Conv channels pruned to zero at index {idx}" + new_config.append( + (len(index), layer.kernel_size, layer.stride) + ) + + # prune the current layer + prune_conv1d_layer(layer.conv, index, "output") + if layer.layer_norm is not None: + prune_layer_norm(layer.layer_norm, index) + + # prune the next layer + if idx == len(self.conv_layers) - 1: + self.dummy_weight.data *= mask + self.dummy_weight = nn.Parameter( + self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False + ) + else: + self.conv_layers[idx+1].conv.weight.data *= mask.unsqueeze(-1) + prune_conv1d_layer(self.conv_layers[idx+1].conv, index, dim="input") + + layer.hard_concrete = None + else: + new_config.append( + (layer.conv.out_channels, layer.kernel_size, layer.stride) + ) + index = torch.arange(layer.conv.out_channels, dtype=torch.long) + + return new_config, index + + +class FeatureProjection(Module): + """Layer that connects FeatureExtractor and Encoder + + Projects features to encoder dimension. + + Args: + in_features (int): Input feature dim. + out_features (int): Output feature dim. + dropout (float): Dropout probability. + """ + + def __init__( + self, + in_features: int, + out_features: int, + dropout: float, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(in_features) + self.projection = nn.Linear( + in_features, + out_features, + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x (Tensor): + Feature Tensor. shape: ``[batch, frame, in_feature]`` + Returns: + Tensor: Projected features. ``[batch, frame, out_feature]``. + """ + x = self.layer_norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + def get_num_params(self, in_features): + return in_features * 2 + (in_features + 1) * self.projection.out_features + + +class ConvolutionalPositionalEmbedding(Module): + """Positional embedding which is placed at the beginning of Transformer. + + Args: + embed_dim (int): Feature dimension of the input Tensor. + kernel_size (int): The number of frames to be use. + groups (int): The number of groups in feature dimensions. + """ + + def __init__( + self, + embed_dim: int, + kernel_size: int, + groups: int, + ): + super().__init__() + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self.conv = nn.Conv1d( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.num_remove: int = 1 if kernel_size % 2 == 0 else 0 + + def __prepare_scriptable__(self): + for hook in self.conv._forward_pre_hooks.values(): + # The hook we want to remove is an instance of WeightNorm class, so + # normally we would do `if isinstance(...)` but this class is not accessible + # because of shadowing, so we check the module name directly. + # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 + if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm": + torch.nn.utils.remove_weight_norm(self.conv) + return self + + def forward(self, x): + """ + Args: + x (Tensor): shape ``[batch, frame, feature]``. + + Returns: + Tensor: The resulting feature. Shape ``[batch, frame, feature]``. + """ + x = x.transpose(-2, -1) + x = self.conv(x) + if self.num_remove > 0: + x = x[..., : -self.num_remove] + x = torch.nn.functional.gelu(x) + x = x.transpose(-2, -1) + return x + + +class SelfAttention(Module): + """Multihead Self Attention module + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): + Dropout probability on attn_output_weights. Default: ``0.0`` + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + prune_heads: bool = False, # whether to prune attention heads + prune_layer: bool = False, # whether to prune entire attention layers + ): + super().__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = torch.nn.Dropout(dropout) + + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=True) + + if prune_heads: + self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01) + else: + self.hard_concrete_for_heads = None + + if prune_layer: + self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) + else: + self.hard_concrete_for_layer = None + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``. + attention_mask (Tensor or ``None``, optional): + shape: ``[batch_size, 1, sequence_length, sequence_length]`` + position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`. + key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with + :py:class:`WavLMSelfAttention`. + Returns: + (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility + with :py:class:`WavLMSelAttention`). + Attention output shape: ``[batch, sequence_length, embed_dim]``. + """ + if x.ndim != 3 or x.shape[2] != self.embed_dim: + raise ValueError( + f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}." + ) + batch_size, length, embed_dim = x.size() + + shape = (batch_size, length, self.num_heads, self.head_dim) + q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L + v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + + # scale down q to avoid value overflow. + weights = (self.scaling * q) @ k # B, nH, L, L + if attention_mask is not None: + weights += attention_mask + # subtracting a constant value from the tensor won't change the output of softmax. + # apply the subtraction to avoid value overflow in torch.nn.functional.softmax. + # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778 + weights = weights - weights.max(dim=-1, keepdim=True)[0] + + weights = torch.nn.functional.softmax(weights, dim=-1) + weights = self.dropout(weights) + + output = weights @ v # B, nH, L, Hd + + if self.hard_concrete_for_heads is not None: + head_mask = self.hard_concrete_for_heads() # (nH,) + output = output * head_mask.unsqueeze(-1).unsqueeze(-1) + + output = output.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim) + + output = self.out_proj(output) + + if self.hard_concrete_for_layer is not None: + layer_mask = self.hard_concrete_for_layer() # (1,) + output = output * layer_mask + + return output, None # Necessary for compatibility with WavLMSelAttention + + def get_num_params(self): + if self.hard_concrete_for_heads is not None: + num_heads = self.hard_concrete_for_heads.l0_norm() + else: + num_heads = self.num_heads + num_params = (self.embed_dim + 1) * num_heads * self.head_dim * 3 \ + + (num_heads * self.head_dim + 1) * self.embed_dim + + if self.hard_concrete_for_layer is not None: + num_params *= self.hard_concrete_for_layer.l0_norm() + + return num_params + + def prune(self): + new_config = { + "use_attention": True, + "num_heads": self.num_heads, + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() # (1,) + self.out_proj.weight.data *= layer_mask + self.out_proj.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_attention"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_heads is not None: + assert not self.hard_concrete_for_heads.training + head_mask = self.hard_concrete_for_heads() # (num_heads,) + new_config["num_heads"] = len(head_mask.nonzero()) + if new_config["num_heads"] == 0: + new_config["use_attention"] = False + else: + full_mask = head_mask.repeat_interleave(self.head_dim) + full_index = full_mask.nonzero().squeeze(-1) # 1D + + prune_linear_layer(self.k_proj, full_index, "output") + prune_linear_layer(self.v_proj, full_index, "output") + prune_linear_layer(self.q_proj, full_index, "output") + + self.out_proj.weight.data *= full_mask + prune_linear_layer(self.out_proj, full_index, "input") + self.hard_concrete_for_heads = None + + return new_config + + +class WavLMSelfAttention(SelfAttention): + """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`. + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``) + bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``) + has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding. + Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``) + num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``) + max_distance (int, optional): Naximum distance for relative position embedding. (Default: ``128``) + gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``) + """ + + def __init__( + self, + embed_dim: int, + total_num_heads: int, + remaining_heads: Optional[List[int]] = None, + dropout: float = 0.0, + bias: bool = True, + has_relative_attention_bias: bool = False, + num_buckets: int = 32, + max_distance: int = 128, + gru_rel_pos: bool = True, + prune_heads: bool = False, + prune_layer: bool = False, + ): + self.total_num_heads = total_num_heads + if remaining_heads is None: + self.remaining_heads = list(range(total_num_heads)) + else: + self.remaining_heads = remaining_heads # list of indices + + self.head_dim = embed_dim // total_num_heads + + super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + + if has_relative_attention_bias: + self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads) + else: + self.rel_attn_embed = None + + # override linear layers to customize bias + self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias) + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) + self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1)) + self.has_position_bias = True + + def compute_bias(self, query_length: int, key_length: int) -> Tensor: + """Compute relative position embeddings for WavLM model. + Args: + query_length (int): Query position can take values between 0 and ``query_length - 1``. + key_length (int): Key position can take values between 0 and ``key_length - 1``. + Returns: + Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings + """ + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # Shape (query_length, key_length) + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) + values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]) + return values + + def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True): + """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM + paper :cite:`chen2022wavlm`. + Args: + relative_positions (Tensor): Relative offsets between query and key positions, + of shape ``(query_length, key_length)``. + bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting + matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set + to zero. (Default ``True``) + Returns: + Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions. + """ + num_buckets = self.num_buckets + max_distance = self.max_distance + # Shape (query_length, key_length) + relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long) + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def forward( + self, + query: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``. + key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape + `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``) + attn_mask: Needs to be ``None``. The argument exists for compatibility with + ``EncoderLayer``. (Default: ``None``) + position_bias (Tensor or None, optional): Position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be + generated in the first layer and then passed from each encoder layer to the next one. + (Default: ``None``) + Returns: + attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``. + position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. + """ + bsz, seq_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert key_padding_mask is None + + # only for the first layer + if self.rel_attn_embed is not None and position_bias is None: + position_bias = self.compute_bias(seq_len, seq_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len) + + attn_mask_rel_pos: Optional[Tensor] = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: # Apply gating on relative position bias + query_layer = query.view(bsz, seq_len, self.total_num_heads, -1) + query_layer = query_layer.permute(0, 2, 1, 3) + + gate_a, gate_b = torch.sigmoid( + self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len)) + attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :] + + attn_mask = attn_mask_rel_pos + if attention_mask is not None: + attn_mask = attn_mask + attention_mask + if key_padding_mask is not None: + attn_mask = attn_mask.masked_fill( + key_padding_mask.reshape(bsz, 1, 1, seq_len), + float("-inf") + ) + attn_output, _ = super().forward(query, attention_mask=attn_mask) + + return attn_output, position_bias + + def prune(self): + new_config = { + "use_attention": True, + "remaining_heads": self.remaining_heads, + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() # (1,) + self.out_proj.weight.data *= layer_mask + self.out_proj.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_attention"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_heads is not None: + assert not self.hard_concrete_for_heads.training + head_mask = self.hard_concrete_for_heads() # (num_heads,) + new_config["remaining_heads"] = head_mask.nonzero().squeeze(-1).tolist() + if len(new_config["remaining_heads"]) == 0: + new_config["use_attention"] = False + else: + full_mask = head_mask.repeat_interleave(self.head_dim) + full_index = full_mask.nonzero().squeeze(-1) # 1D + + prune_linear_layer(self.k_proj, full_index, "output") + prune_linear_layer(self.v_proj, full_index, "output") + prune_linear_layer(self.q_proj, full_index, "output") + + self.out_proj.weight.data *= full_mask + prune_linear_layer(self.out_proj, full_index, "input") + self.hard_concrete_for_heads = None + + return new_config + + +class FeedForward(Module): + """Layer that follows attention layer in encoder layer.""" + + def __init__( + self, + io_features: int, + intermediate_features: int, + intermediate_dropout: float, + output_dropout: float, + prune_intermediate: bool = False, + prune_layer: bool = False, + ): + super().__init__() + self.intermediate_dense = nn.Linear(io_features, intermediate_features) + self.intermediate_dropout = nn.Dropout(intermediate_dropout) + self.output_dense = nn.Linear(intermediate_features, io_features) + self.output_dropout = nn.Dropout(output_dropout) + + if prune_intermediate: + self.hard_concrete_for_intermediate = HardConcrete( + n_in=intermediate_features, init_mean=0.5 + ) + else: + self.hard_concrete_for_intermediate = None + + if prune_layer: + self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) + else: + self.hard_concrete_for_layer = None + + def forward(self, x): + """ + Args: + x (Tensor): shape: `(batch, sequence_length, io_features)` + Returns: + x (Tensor): shape: `(batch, sequence_length, io_features)` + """ + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.intermediate_dropout(x) + + if self.hard_concrete_for_intermediate is not None: + intermediate_mask = self.hard_concrete_for_intermediate() # (intermediate_features,) + x = x * intermediate_mask + + x = self.output_dense(x) + x = self.output_dropout(x) + + if self.hard_concrete_for_layer is not None: + layer_mask = self.hard_concrete_for_layer() # (1,) + x = x * layer_mask + + return x + + def get_num_params(self): + io_features = self.intermediate_dense.in_features + if self.hard_concrete_for_intermediate is not None: + intermediate_features = self.hard_concrete_for_intermediate.l0_norm() + else: + intermediate_features = self.intermediate_dense.out_features + num_params = (io_features + 1) * intermediate_features + (intermediate_features + 1) * io_features + + if self.hard_concrete_for_layer is not None: + num_params *= self.hard_concrete_for_layer.l0_norm() + + return num_params + + def prune(self): + new_config = { + "use_feed_forward": True, + "ff_interm_features": self.intermediate_dense.out_features + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() + self.output_dense.weight.data *= layer_mask + self.output_dense.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_feed_forward"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_intermediate is not None: + assert not self.hard_concrete_for_intermediate.training + interm_mask = self.hard_concrete_for_intermediate() + interm_index = interm_mask.nonzero().squeeze(-1) # NOTE: must specify dim=-1 + new_config["ff_interm_features"] = len(interm_index) + if new_config["ff_interm_features"] == 0: + new_config["use_feed_forward"] = False + else: + prune_linear_layer(self.intermediate_dense, interm_index, "output") + + self.output_dense.weight.data *= interm_mask + prune_linear_layer(self.output_dense, interm_index, "input") + self.hard_concrete_for_intermediate = None + + return new_config + + +class EncoderLayer(Module): + """A layer unit in encoder. Combines multihead self attention and feed forward.""" + + def __init__( + self, + attention: Optional[Module], # can be None if the entire layer is pruned + dropout: float, + layer_norm_first: bool, + feed_forward: Optional[Module], # can be None if the entire layer is pruned + embed_dim: int, + ): + super().__init__() + self.attention = attention + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(embed_dim) + self.layer_norm_first = layer_norm_first + self.feed_forward = feed_forward + self.final_layer_norm = nn.LayerNorm(embed_dim) + self.embed_dim = embed_dim + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``. + attention_mask (Tensor or ``None``, optional): attention mask + of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``) + position_bias (Tensor or ``None``, optional): position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. + Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``) + key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``. + Only used for WavLM model, ignored otherwise. (Default: ``None``) + Returns: + (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model, + ``None`` otherwise. + """ + if self.attention is not None: + residual = x + + if self.layer_norm_first: + x = self.layer_norm(x) + + x, position_bias = self.attention( + x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask + ) + + x = self.dropout(x) + x = residual + x + + if self.layer_norm_first: + if self.feed_forward is not None: + x = x + self.feed_forward(self.final_layer_norm(x)) + else: + # NOTE: for post norm, the layer norms should always be applied even if the layers are pruned. + x = self.layer_norm(x) + if self.feed_forward is not None: + x = x + self.feed_forward(x) + x = self.final_layer_norm(x) + return x, position_bias + + def get_num_params(self): + num_params = self.embed_dim * 2 * 2 # two layer norms + if self.attention is not None: + num_params += self.attention.get_num_params() + if self.feed_forward is not None: + num_params += self.feed_forward.get_num_params() + return num_params + + +class Transformer(Module): + def __init__( + self, + pos_conv_embed: Module, + dropout: float, + layers: Module, + layer_norm_first: bool, + layer_drop: float, + ): + super().__init__() + self.pos_conv_embed = pos_conv_embed + self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim) + self.layer_norm_first = layer_norm_first + self.layer_drop = layer_drop + self.dropout = nn.Dropout(dropout) + self.layers = layers + + def _preprocess(self, x: Tensor): + x = x + self.pos_conv_embed(x) + + if self.layer_norm_first: + x = self.layer_norm(x) + + x = self.dropout(x) + return x + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tensor: + x = self._preprocess(x) + for layer in self.layers: + if not (self.training and torch.rand(1).item() <= self.layer_drop): + x, position_bias = layer(x, attention_mask, position_bias=position_bias) + + if not self.layer_norm_first: + x = self.layer_norm(x) + return x + + def get_intermediate_outputs( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + num_layers: Optional[int] = None, + position_bias: Optional[Tensor] = None, + ) -> List[Tensor]: + if num_layers is not None: + if not 0 < num_layers <= len(self.layers): + raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]") + + ret: List[Tensor] = [] + x = self._preprocess(x) + for layer in self.layers: + x, position_bias = layer(x, attention_mask, position_bias=position_bias) + ret.append(x) + if num_layers is not None and len(ret) >= num_layers: + return ret + return ret + + def get_num_params(self): + # pos_conv_embed and layer_norm + num_params = sum(p.numel() for p in self.pos_conv_embed.parameters()) + self.pos_conv_embed.embed_dim * 2 + for layer in self.layers: + num_params += layer.get_num_params() + return num_params + + def prune(self): + new_config = defaultdict(list) + for layer in self.layers: + attention_config = layer.attention.prune() + new_config["use_attention"].append(attention_config["use_attention"]) + if "remaining_heads" in attention_config: + new_config["remaining_heads"].append(attention_config["remaining_heads"]) + else: + new_config["num_heads"].append(attention_config["num_heads"]) + + if not attention_config["use_attention"]: + layer.attention = None + + ff_config = layer.feed_forward.prune() + new_config["use_feed_forward"].append(ff_config["use_feed_forward"]) + new_config["ff_interm_features"].append(ff_config["ff_interm_features"]) + if not ff_config["use_feed_forward"]: + layer.feed_forward = None + + return new_config + + +class Encoder(Module): + def __init__( + self, + feature_projection: Module, + transformer: Module, + ): + super().__init__() + self.feature_projection = feature_projection + self.transformer = transformer + + def _preprocess( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + x = self.feature_projection(features) + + mask: Optional[Tensor] = None + if lengths is not None: + batch_size, max_len, _ = x.shape + # create mask for padded elements and zero-out them + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + x[mask] = 0.0 + # extend the mask to attention shape and set weight + mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype) + mask = mask.expand(batch_size, 1, max_len, max_len) + return x, mask + + def forward( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tensor: + x, mask = self._preprocess(features, lengths) + x = self.transformer(x, attention_mask=mask) + return x + + def extract_features( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> List[Tensor]: + x, masks = self._preprocess(features, lengths) + interm = self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers) + return [x] + interm + + def get_num_params(self, in_features): + """Calculate the current model size.""" + feature_projection_size = self.feature_projection.get_num_params(in_features) + transformer_size = self.transformer.get_num_params() + return feature_projection_size + transformer_size + + def prune(self, conv_out_index): + """In-place pruning of submodules.""" + prune_layer_norm(self.feature_projection.layer_norm, conv_out_index) + prune_linear_layer(self.feature_projection.projection, conv_out_index, "input") + transformer_config = self.transformer.prune() + return transformer_config + + +################################################################################ +def _get_feature_extractor( + norm_mode: str, + shapes: List[Tuple[int, int, int]], + bias: bool, + prune_conv_channels: bool = False, +) -> FeatureExtractor: + """ + Args: + norm_mode (str): + Either "group_norm" or "layer_norm". + If "group_norm", then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + This option corresponds to "extractor_mode" from fairseq. + Expected values are "group_norm" for Base arch, and + "layer_norm" for Large arch. + shapes (list of tuple of int): + Configuration of convolution layers. List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + This option corresponds to "conv_feature_layers" from fairseq. + Expected values are + ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` + for all the architectures. + bias (bool): + Whether to include bias term to each convolution operation. + This option corresponds to "conv_bias" from fairseq. + Expected values are False for Base arch, and True for Large arch. + + See Also: + * Original implementation + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733 + * "extractor_mode" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52 + * "conv_feature_layers" + - Def, base and large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100 + * "conv_bias" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61 + """ + if norm_mode not in ["group_norm", "layer_norm"]: + raise ValueError("Invalid norm mode") + blocks = [] + in_channels = 1 + for i, (out_channels, kernel_size, stride) in enumerate(shapes): + normalization = None + if norm_mode == "group_norm" and i == 0: + normalization = nn.GroupNorm( + num_groups=out_channels, + num_channels=out_channels, + affine=True, + ) + elif norm_mode == "layer_norm": + normalization = LayerNorm( + normalized_shape=out_channels, + elementwise_affine=True, + ) + blocks.append( + ConvLayerBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + layer_norm=normalization, + prune_conv_channels=prune_conv_channels, + ) + ) + in_channels = out_channels + return FeatureExtractor(nn.ModuleList(blocks)) + + +def _get_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + use_attention: List[bool], + use_feed_forward: List[bool], + num_heads: List[int], + head_dim: int, + attention_dropout: float, + ff_interm_features: List[int], + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, + prune_attention_heads: bool = False, + prune_attention_layer: bool = False, + prune_feed_forward_intermediate: bool = False, + prune_feed_forward_layer: bool = False, +) -> Encoder: + """ + Args: + in_features (int): The number of input features. + embed_dim (int): + The dimension of embedding. + This option corresponds to "encoder_embed_dim" from fairseq. + Expected values are 768 for Base arch, and 1024 for Large arch. + dropout_input (float): + The dropout probability applied after the input feature is projected + to ``embed_dim``. + This option corresponds to "dropout_input" from fairseq. + Expected values are 0.1 for both Base and Large arch. + pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + This option corresponds to "conv_pos" from fairseq. + Expected values are 128 for both Base and Large arch. + pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + This option corresponds to "conv_pos_groups" from fairseq. + Expected values are 16 for both Base and Large arch. + num_layers (int): + The number of self attention layers in transformer block. + This option corresponds to "encoder_layers" from fairseq. + Expected values are 12 for Base and 24 for Large arch. + num_heads (int): + The number of heads in self attention layers. + This option corresponds to "encoder_attention_heads" from fairseq. + Expected values are 12 for Base and 16 for Large arch. + attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + This option corresponds to "attention_dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + ff_interm_features (int): + The dimension of hidden features in feed forward layer. + This option corresponds to "encoder_ffn_embed_dim" from fairseq. + Expected values are 3072 for Base and 4096 for Large arch. + ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + This option correspinds to "activation_dropout" from fairseq. + Expected values are 0.1 for both Base and Large arch. + dropout (float): + The dropout probability applied at the end of feed forward layer. + This option corresponds to "dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + This option corresponds to "layer_norm_first" from fairseq. + Expected values are False for Base and True for Large arch. + layer_drop (float): + Probability to drop each encoder layer during training. + This option corresponds to "layerdrop" from fairseq. + Expected values are 0.1 for both Base and Large arch. + + See Also: + * "encoder_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64 + * "dropout_input" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78 + * "conv_pos" + - Def, base and large + NOTE: The description is wrong. + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207 + - Usage + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756 + * "conv_pos_groups" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211 + * "encoder_layers" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63 + * "encoder_attention_heads" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66 + * "attention_dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60 + * "encoder_ffn_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65 + * "activation_dropout" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55 + * "dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59 + * "layer_norm_first" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53 + * "layerdrop" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for idx in range(num_layers): + if use_attention[idx]: + attention = SelfAttention( + embed_dim=embed_dim, + num_heads=num_heads[idx], + head_dim=head_dim, + dropout=attention_dropout, + prune_heads=prune_attention_heads, + prune_layer=prune_attention_layer, + ) + else: + attention = None + if use_feed_forward[idx]: + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features[idx], + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + prune_intermediate=prune_feed_forward_intermediate, + prune_layer=prune_feed_forward_layer, + ) + else: + feed_forward = None + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + embed_dim=embed_dim, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_wavlm_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + use_attention: List[bool], + use_feed_forward: List[bool], + total_num_heads: List[int], + remaining_heads: List[List[int]], + num_buckets: int, + max_distance: int, + attention_dropout: float, + ff_interm_features: List[int], + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, + prune_attention_heads: bool = False, + prune_attention_layer: bool = False, + prune_feed_forward_intermediate: bool = False, + prune_feed_forward_layer: bool = False, +) -> Encoder: + """ + Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the argments are + the same as in :py:func:`_get_encoder` so refer there for documentation. The only difference from Wav2Vec2 encoder + is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and + `max_distance`. + Args: + in_features (int): See :py:func:`_get_encoder`. + embed_dim (int): See :py:func:`_get_encoder`. + dropout_input (float): See :py:func:`_get_encoder`. + pos_conv_kernel (int): See :py:func:`_get_encoder`. + pos_conv_groups (int): See :py:func:`_get_encoder`. + num_layers (int): See :py:func:`_get_encoder`. + num_heads (int): See :py:func:`_get_encoder`. + num_buckets (int): Number of buckets for relative position embedding. + max_distance (int): Maximum distance for relative position embedding. + attention_dropout (float): See :py:func:`_get_encoder`. + ff_interm_features (int): See :py:func:`_get_encoder`. + ff_interm_dropout (float): See :py:func:`_get_encoder`. + dropout (float): See :py:func:`_get_encoder`. + layer_norm_first (bool): See :py:func:`_get_encoder`. + layer_drop (float): See :py:func:`_get_encoder`. + + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for i in range(num_layers): + if use_attention[i]: + attention = WavLMSelfAttention( + embed_dim=embed_dim, + total_num_heads=total_num_heads[i], + remaining_heads=remaining_heads[i], + dropout=attention_dropout, + has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer. + num_buckets=num_buckets, + max_distance=max_distance, + prune_heads=prune_attention_heads, + prune_layer=prune_attention_layer, + ) + else: + attention = None + if use_feed_forward[i]: + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features[i], + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + prune_intermediate=prune_feed_forward_intermediate, + prune_layer=prune_feed_forward_layer, + ) + else: + feed_forward = None + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + embed_dim=embed_dim, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: + """Generate the padding mask given the padded input and the lengths Tensors. + Args: + input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. + lengths (Tensor): The lengths Tensor of dimension `[batch,]`. + + Returns: + (Tensor): The padding mask. + """ + batch_size, max_len, _ = input.shape + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + return mask + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/vencoder/dphubert/hardconcrete.py b/vencoder/dphubert/hardconcrete.py new file mode 100644 index 0000000000000000000000000000000000000000..468a30d1eccdf20ee7493e71792c46e48449c4e3 --- /dev/null +++ b/vencoder/dphubert/hardconcrete.py @@ -0,0 +1,122 @@ +"""Implementation of the hard Concrete distribution. + +Originally from: +https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py + +""" + +import math + +import torch +import torch.nn as nn + + +class HardConcrete(nn.Module): + """A HarcConcrete module. + Use this module to create a mask of size N, which you can + then use to perform L0 regularization. + + To obtain a mask, simply run a forward pass through the module + with no input data. The mask is sampled in training mode, and + fixed during evaluation mode, e.g.: + + >>> module = HardConcrete(n_in=100) + >>> mask = module() + >>> norm = module.l0_norm() + """ + + def __init__( + self, + n_in: int, + init_mean: float = 0.5, + init_std: float = 0.01, + temperature: float = 2/3, # from CoFi + stretch: float = 0.1, + eps: float = 1e-6 + ) -> None: + """Initialize the HardConcrete module. + Parameters + ---------- + n_in : int + The number of hard concrete variables in this mask. + init_mean : float, optional + Initial drop rate for hard concrete parameter, + by default 0.5., + init_std: float, optional + Used to initialize the hard concrete parameters, + by default 0.01. + temperature : float, optional + Temperature used to control the sharpness of the + distribution, by default 1.0 + stretch : float, optional + Stretch the sampled value from [0, 1] to the interval + [-stretch, 1 + stretch], by default 0.1. + """ + super().__init__() + + self.n_in = n_in + self.limit_l = -stretch + self.limit_r = 1.0 + stretch + self.log_alpha = nn.Parameter(torch.zeros(n_in)) + self.beta = temperature + self.init_mean = init_mean + self.init_std = init_std + self.bias = -self.beta * math.log(-self.limit_l / self.limit_r) + + self.eps = eps + self.compiled_mask = None + self.reset_parameters() + + def reset_parameters(self): + """Reset the parameters of this module.""" + self.compiled_mask = None + mean = math.log(1 - self.init_mean) - math.log(self.init_mean) + self.log_alpha.data.normal_(mean, self.init_std) + + def l0_norm(self) -> torch.Tensor: + """Compute the expected L0 norm of this mask. + Returns + ------- + torch.Tensor + The expected L0 norm. + """ + return (self.log_alpha + self.bias).sigmoid().sum() + + def forward(self) -> torch.Tensor: + """Sample a hard concrete mask. + Returns + ------- + torch.Tensor + The sampled binary mask + """ + if self.training: + # Reset the compiled mask + self.compiled_mask = None + # Sample mask dynamically + u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps) + s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta) + s = s * (self.limit_r - self.limit_l) + self.limit_l + mask = s.clamp(min=0., max=1.) + + else: + # Compile new mask if not cached + if self.compiled_mask is None: + # Get expected sparsity + expected_num_zeros = self.n_in - self.l0_norm().item() + num_zeros = round(expected_num_zeros) + # Approximate expected value of each mask variable z; + # We use an empirically validated magic number 0.8 + soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8) + # Prune small values to set to 0 + _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) + soft_mask[indices] = 0. + self.compiled_mask = soft_mask + mask = self.compiled_mask + + return mask + + def extra_repr(self) -> str: + return str(self.n_in) + + def __repr__(self) -> str: + return "{}({})".format(self.__class__.__name__, self.extra_repr()) diff --git a/vencoder/dphubert/model.py b/vencoder/dphubert/model.py new file mode 100644 index 0000000000000000000000000000000000000000..348ede2c3edc3e5588ee75760085dee9eafd9d68 --- /dev/null +++ b/vencoder/dphubert/model.py @@ -0,0 +1,966 @@ +"""Speech SSL models supporting pruning. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/model.py + +""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module + +from . import components + + +class Wav2Vec2Model(Module): + """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`. + + Note: + To build the model, please use one of the factory functions. + :py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`, + :py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`, + and :py:func:`hubert_xlarge`. + + See Also: + * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning) + * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models. + + Args: + feature_extractor (torch.nn.Module): + Feature extractor that extracts feature vectors from raw audio Tensor. + + encoder (torch.nn.Module): + Encoder that converts the audio features into the sequence of probability + distribution (in negative log-likelihood) over labels. + + aux (torch.nn.Module or None, optional): + Auxiliary module. If provided, the output from encoder is passed to this module. + """ # noqa: E501 + + def __init__( + self, + normalize_waveform: bool, + feature_extractor: Module, + encoder: Module, + aux: Optional[Module] = None, + ): + super().__init__() + self.normalize_waveform = normalize_waveform + self.feature_extractor = feature_extractor + self.encoder = encoder + self.aux = aux + + @torch.jit.export + def extract_features( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> Tuple[List[Tensor], Optional[Tensor]]: + """Extract feature vectors from raw waveforms + + This returns the list of outputs from the intermediate layers of + transformer block in encoder. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that the entire audio waveform + length is valid. + num_layers (int or None, optional): + If given, limit the number of intermediate layers to go through. + Providing `1` will stop the computation after going through one + intermediate layers. If not given, the outputs from all the + intermediate layers are returned. + + Returns: + (List[Tensor], Optional[Tensor]): + List of Tensors + Features from requested layers. + Each Tensor is of shape: `(batch, time frame, feature dimension)` + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of each feature Tensor. + """ + if self.normalize_waveform: + if lengths is not None: + waveforms = [ + F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) + ] + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + else: + waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) + + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder.extract_features(x, lengths, num_layers) # (num_layers+1,), including the input + return x, lengths + + def get_num_params(self): + """Calculate the current size.""" + feature_extractor_size, encoder_in_features = self.feature_extractor.get_num_params_and_final_out_channels() + encoder_size = self.encoder.get_num_params(encoder_in_features) + return feature_extractor_size + encoder_size + + def prune(self): + self.eval() # must be in eval mode + conv_config, conv_out_index = self.feature_extractor.prune() # [(output_channel, kernel_size, stride), ...] + transformer_config = self.encoder.prune(conv_out_index) # NOTE: this is a defaultdict(list) + use_attention = transformer_config["use_attention"] + use_feed_forward = transformer_config["use_feed_forward"] + num_heads = transformer_config["num_heads"] # can be [] + remaining_heads = transformer_config["remaining_heads"] # can be [] + ff_interm_features = transformer_config["ff_interm_features"] + + return conv_config, use_attention, use_feed_forward, num_heads, remaining_heads, ff_interm_features + + def forward( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Compute the sequence of probability distribution over labels. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The sequences of probability distribution (in logit) over labels. + Shape: `(batch, frames, num labels)`. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. + """ + if self.normalize_waveform: + if lengths is not None: + waveforms = [ + F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) + ] + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + else: + waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) + + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder(x, lengths) + if self.aux is not None: + x = self.aux(x) + return x, lengths + + +def wav2vec2_model(**configs) -> Wav2Vec2Model: + """Wraps the original wav2vec2_model and wavlm_model.""" + + if "encoder_remaining_heads" in configs: + return wavlm_model(**configs) + + return wav2vec2_model_original(**configs) + + +def wav2vec2_model_original( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_use_attention: List[bool], + encoder_use_feed_forward: List[bool], + encoder_num_heads: List[int], + encoder_head_dim: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: List[int], + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], + normalize_waveform: bool, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`. + + Note: + The "feature extractor" below corresponds to + `ConvFeatureExtractionModel `__ + in the original ``fairseq`` implementation. + This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* + :cite:`baevski2020wav2vec` paper. + + The "encoder" below corresponds to `TransformerEncoder `__, + and this is referred as "Transformer" in the paper. + + Args: + extractor_mode (str): Operation mode of feature extractor. + Valid values are ``"group_norm"`` or ``"layer_norm"``. + If ``"group_norm"``, then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + + This option corresponds to ``extractor_mode`` from ``fairseq``. + extractor_conv_layer_config (list of integer tuples or None): + Configuration of convolution layers in feature extractor. + List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + + If ``None`` is provided, then the following default value is used. + + .. code-block:: python + + [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ] + + This option corresponds to ``conv_feature_layers`` from ``fairseq``. + + extractor_conv_bias (bool): + Whether to include bias term to each convolution operation. + + This option corresponds to ``conv_bias`` from ``fairseq``. + + encoder_embed_dim (int): + The dimension of embedding in encoder. + + This option corresponds to ``encoder_embed_dim`` from ``fairseq``. + + encoder_projection_dropout (float): + The dropout probability applied after the input feature is projected + to ``encoder_embed_dim``. + + This option corresponds to ``dropout_input`` from ``fairseq``. + + encoder_pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + + This option corresponds to ``conv_pos`` from ``fairseq``. + + encoder_pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + + This option corresponds to ``conv_pos_groups`` from ``fairseq``. + + encoder_num_layers (int): + The number of self attention layers in transformer block. + + This option corresponds to ``encoder_layers`` from ``fairseq``. + + encoder_num_heads (int): + The number of heads in self attention layers. + + This option corresponds to ``encoder_attention_heads`` from ``fairseq``. + + encoder_attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + + This option corresponds to ``attention_dropout`` from ``fairseq``. + + encoder_ff_interm_features (int): + The dimension of hidden features in feed forward layer. + + This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. + + encoder_ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + + This option correspinds to ``activation_dropout`` from ``fairseq``. + + encoder_dropout (float): + The dropout probability applied at the end of feed forward layer. + + This option corresponds to ``dropout`` from ``fairseq``. + + encoder_layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + + This option corresponds to ``layer_norm_first`` from ``fairseq``. + + encoder_layer_drop (float): + Probability to drop each encoder layer during training. + + This option corresponds to ``layerdrop`` from ``fairseq``. + + aux_num_out (int or None): + When provided, attach an extra linear layer on top of encoder, which can be + used for fine-tuning. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias, + prune_conv_channels=extractor_prune_conv_channels, + ) + encoder = components._get_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + use_attention=encoder_use_attention, + use_feed_forward=encoder_use_feed_forward, + num_heads=encoder_num_heads, + head_dim=encoder_head_dim, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + prune_attention_heads=encoder_prune_attention_heads, + prune_attention_layer=encoder_prune_attention_layer, + prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) + + +def wav2vec2_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def wav2vec2_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def wav2vec2_large_lv60k( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.05, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "base" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_use_attention=[True] * 12, + encoder_use_feed_forward=[True] * 12, + encoder_num_heads=[12] * 12, + encoder_head_dim=64, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=[3072] * 12, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_large( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_xlarge( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "extra large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def _init_hubert_pretrain_model(module): + if isinstance(module, components.LayerNorm): + torch.nn.init.kaiming_normal_(module.conv.weight) + elif isinstance(module, components.ConvolutionalPositionalEmbedding): + # normalize the weight to normal distribution. + std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size)) + torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std) + torch.nn.init.constant_(module.conv.bias, 0.0) + elif isinstance(module, components.SelfAttention): + # normalize the query, key, value, and out_proj parameters in self attention module. + torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.out_proj.weight) + torch.nn.init.constant_(module.out_proj.bias, 0.0) + elif isinstance(module, components.Transformer): + module.apply(components._init_transformer_params) + else: + pass + + +def wavlm_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_use_attention: List[bool], + encoder_use_feed_forward: List[bool], + encoder_total_num_heads: List[int], + encoder_remaining_heads: List[List[int]], + encoder_num_buckets: int, + encoder_max_distance: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: List[int], + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], + normalize_waveform: bool, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is + :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning + as in :py:func:`wav2vec2_model` so please refer there for documentation. + + Args: + extractor_mode (str): Operation mode of feature extractor. + See :py:func:`wav2vec2_model`. + + extractor_conv_layer_config (list of integer tuples or None): + See :py:func:`wav2vec2_model`. + + extractor_conv_bias (bool): + See :py:func:`wav2vec2_model`. + + encoder_embed_dim (int): + See :py:func:`wav2vec2_model`. + + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_pos_conv_kernel (int): + See :py:func:`wav2vec2_model`. + + encoder_pos_conv_groups (int): + See :py:func:`wav2vec2_model`. + + encoder_num_layers (int): + See :py:func:`wav2vec2_model`. + + encoder_num_heads (int): + See :py:func:`wav2vec2_model`. + + encoder_num_buckets (int): + Number of buckets for relative position embedding. + encoder_max_distance (int): + Maximum distance for relative position embedding. + + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_ff_interm_features (int): + See :py:func:`wav2vec2_model`. + + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_layer_norm_first (bool): + See :py:func:`wav2vec2_model`. + + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + + aux_num_out (int or None): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias, + prune_conv_channels=extractor_prune_conv_channels, + ) + encoder = components._get_wavlm_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + use_attention=encoder_use_attention, + use_feed_forward=encoder_use_feed_forward, + total_num_heads=encoder_total_num_heads, + remaining_heads=encoder_remaining_heads, + num_buckets=encoder_num_buckets, + max_distance=encoder_max_distance, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + prune_attention_heads=encoder_prune_attention_heads, + prune_attention_layer=encoder_prune_attention_layer, + prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) + + +def wavlm_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wavlm_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) diff --git a/vencoder/dphubert/pruning_utils.py b/vencoder/dphubert/pruning_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac185980c2c3da716bf3ce402a541ffe70776acf --- /dev/null +++ b/vencoder/dphubert/pruning_utils.py @@ -0,0 +1,51 @@ +"""Utility functions for pruning.""" + +from typing import Union + +import torch +import torch.nn as nn + + +def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str): + "Prune linear layer in place." + # NOTE: weight: (out_features, in_features), bias: (out_features,) + if dim == "input": + dim = 1 + layer.in_features = len(index) + elif dim == "output": + dim = 0 + layer.out_features = len(index) + else: + raise ValueError + + layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) + if layer.bias is not None and dim == 0: + layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) + + +def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str): + """Prune conv1d in place.""" + # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,) + if dim == "input": + dim = 1 + layer.in_channels = len(index) + elif dim == "output": + dim = 0 + layer.out_channels = len(index) + else: + raise ValueError + + layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) + if layer.bias is not None and dim == 0: + layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) + + +def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor): + """Prune layer norm or group norm in place.""" + layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach()) + layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach()) + if isinstance(layernorm, nn.LayerNorm): + layernorm.normalized_shape = (len(index),) + elif isinstance(layernorm, nn.GroupNorm): + layernorm.num_groups = len(index) + layernorm.num_channels = len(index) diff --git a/vencoder/dphubert/utils/__init__.py b/vencoder/dphubert/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/utils/import_huggingface_wavlm.py b/vencoder/dphubert/utils/import_huggingface_wavlm.py new file mode 100644 index 0000000000000000000000000000000000000000..24a3f38ae9cc08e19010b2876b19dc9082873377 --- /dev/null +++ b/vencoder/dphubert/utils/import_huggingface_wavlm.py @@ -0,0 +1,129 @@ +"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/utils/import_huggingface.py + +""" + +import logging +from typing import Any, Dict + +from torch.nn import Module + +from ..model import Wav2Vec2Model, wav2vec2_model, wavlm_model + +_LG = logging.getLogger(__name__) + + +def _get_config(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_num_heads": cfg.num_attention_heads, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": cfg.intermediate_size, + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + } + return config + + +def _get_config_wavlm(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_use_attention": [True] * cfg.num_hidden_layers, + "encoder_use_feed_forward": [True] * cfg.num_hidden_layers, + "encoder_total_num_heads": [cfg.num_attention_heads for _ in range(cfg.num_hidden_layers)], + "encoder_remaining_heads": [list(range(cfg.num_attention_heads)) for _ in range(cfg.num_hidden_layers)], + "encoder_num_buckets": cfg.num_buckets, + "encoder_max_distance": cfg.max_bucket_distance, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": [cfg.intermediate_size for _ in range(cfg.num_hidden_layers)], + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + "normalize_waveform": cfg.feat_extract_norm == "layer", + } + return config + + +def _build(config, original): + is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"] + if is_for_ctc: + aux_num_out = original.config.vocab_size + wav2vec2 = original.wav2vec2 + else: + _LG.warning( + "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.' + ) + aux_num_out = None + wav2vec2 = original + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + imported = wavlm_model(**config, aux_num_out=aux_num_out) + else: + imported = wav2vec2_model(**config, aux_num_out=aux_num_out) + print(imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict(), strict=False)) + print(imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict(), strict=False)) + encoder_state_dict = wav2vec2.encoder.state_dict() + if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model + transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"]) + print(imported.encoder.transformer.load_state_dict(encoder_state_dict, strict=False)) + if is_for_ctc: + imported.aux.load_state_dict(original.lm_head.state_dict()) + return imported + + +def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int): + """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and + biases to align with the structure of ``torch.nn.MultiheadAttention``. + """ + pass + + +def import_huggingface_model(original: Module) -> Wav2Vec2Model: + """Builds :class:`Wav2Vec2Model` from the corresponding model object of + `Transformers `_. + + Args: + original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``. + + Returns: + Wav2Vec2Model: Imported model. + + Example + >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model + >>> + >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = import_huggingface_model(original) + >>> + >>> waveforms, _ = torchaudio.load("audio.wav") + >>> logits, _ = model(waveforms) + """ + _LG.info("Importing model.") + _LG.info("Loading model configuration.") + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + config = _get_config_wavlm(original.config) + else: + config = _get_config(original.config) + _LG.debug(" - config: %s", config) + _LG.info("Building model.") + imported = _build(config, original) + return imported diff --git a/vencoder/encoder.py b/vencoder/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad120da34893d64b47b8ebeeaaed1f822a2e0be --- /dev/null +++ b/vencoder/encoder.py @@ -0,0 +1,13 @@ +class SpeechEncoder(object): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + self.model = None # This is Model + self.hidden_dim = 768 + pass + + + def encoder(self, wav): + """ + input: wav:[signal_length] + output: embedding:[batchsize,hidden_dim,wav_frame] + """ + pass diff --git a/vencoder/hubert/__init__.py b/vencoder/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/hubert/hubert_model.py b/vencoder/hubert/hubert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb642d89b07ca60792debab18e3454f52d8f357 --- /dev/null +++ b/vencoder/hubert/hubert_model.py @@ -0,0 +1,222 @@ +import copy +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as t_func +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + + +class Hubert(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = nn.LayerNorm(768) + self.dropout = nn.Dropout(0.1) + self.encoder = TransformerEncoder( + nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = nn.Linear(768, 256) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + def logits(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.cosine_similarity( + x.unsqueeze(2), + self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + dim=-1, + ) + return logits / 0.1 + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x, mask = self.encode(x) + x = self.proj(x) + logits = self.logits(x) + return logits, mask + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + @torch.inference_mode() + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = nn.GroupNorm(512, 512) + self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = t_func.gelu(self.norm0(self.conv0(x))) + x = t_func.gelu(self.conv1(x)) + x = t_func.gelu(self.conv2(x)) + x = t_func.gelu(self.conv3(x)) + x = t_func.gelu(self.conv4(x)) + x = t_func.gelu(self.conv5(x)) + x = t_func.gelu(self.conv6(x)) + return x + + +class FeatureProjection(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(512) + self.projection = nn.Linear(512, 768) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x.transpose(1, 2)) + x = t_func.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + +class TransformerEncoder(nn.Module): + def __init__( + self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +def _compute_mask( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + + +def hubert_soft( + path: str, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. + Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft() + checkpoint = torch.load(path) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert diff --git a/vencoder/hubert/hubert_model_onnx.py b/vencoder/hubert/hubert_model_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..d18f3c2a0fc29592a573a9780308d38f059640b9 --- /dev/null +++ b/vencoder/hubert/hubert_model_onnx.py @@ -0,0 +1,217 @@ +import copy +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as t_func +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + + +class Hubert(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = nn.LayerNorm(768) + self.dropout = nn.Dropout(0.1) + self.encoder = TransformerEncoder( + nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = nn.Linear(768, 256) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + def logits(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.cosine_similarity( + x.unsqueeze(2), + self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + dim=-1, + ) + return logits / 0.1 + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + def forward(self, x): + return self.units(x) + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = nn.GroupNorm(512, 512) + self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = t_func.gelu(self.norm0(self.conv0(x))) + x = t_func.gelu(self.conv1(x)) + x = t_func.gelu(self.conv2(x)) + x = t_func.gelu(self.conv3(x)) + x = t_func.gelu(self.conv4(x)) + x = t_func.gelu(self.conv5(x)) + x = t_func.gelu(self.conv6(x)) + return x + + +class FeatureProjection(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(512) + self.projection = nn.Linear(512, 768) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x.transpose(1, 2)) + x = t_func.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + +class TransformerEncoder(nn.Module): + def __init__( + self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +def _compute_mask( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + + +def hubert_soft( + path: str, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. + Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft() + checkpoint = torch.load(path) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert diff --git a/vencoder/wavlm/WavLM.py b/vencoder/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3986fdcc00033a9e8f1bfcd25df3799f40ed90 --- /dev/null +++ b/vencoder/wavlm/WavLM.py @@ -0,0 +1,741 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from vencoder.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/vencoder/wavlm/modules.py b/vencoder/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..add4a1aa0042cbcbf5c3b28d4d72f017b507717d --- /dev/null +++ b/vencoder/wavlm/modules.py @@ -0,0 +1,828 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/vencoder/whisper/__init__.py b/vencoder/whisper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/whisper/audio.py b/vencoder/whisper/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..05890dc195a376181c21072eb0a8af24cf29928a --- /dev/null +++ b/vencoder/whisper/audio.py @@ -0,0 +1,123 @@ +from functools import lru_cache +from typing import Union + +import ffmpeg +import numpy as np +import torch +import torch.nn.functional as F +from librosa.filters import mel as librosa_mel_fn + +from .utils import exact_div + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + try: + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + return torch.from_numpy(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)).to(device) + + +def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/vencoder/whisper/decoding.py b/vencoder/whisper/decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..45e50b1c33c2c8f9ca6572e6175b8d6051ae02ee --- /dev/null +++ b/vencoder/whisper/decoding.py @@ -0,0 +1,712 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import Categorical + +from .audio import CHUNK_LENGTH +from .tokenizer import Tokenizer, get_tokenizer +from .utils import compression_ratio + +if TYPE_CHECKING: + from .model import Whisper + + +@torch.no_grad() +def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]: + """ + Detect the spoken language in the audio, and return them as list of strings, along with the ids + of the most probable language tokens and the probability distribution over all language tokens. + This is performed outside the main decode loop in order to not interfere with kv-caching. + + Returns + ------- + language_tokens : Tensor, shape = (n_audio,) + ids of the most probable language tokens, which appears after the startoftranscript token. + language_probs : List[Dict[str, float]], length = n_audio + list of dictionaries containing the probability distribution over all languages. + """ + if tokenizer is None: + tokenizer = get_tokenizer(model.is_multilingual) + if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: + raise ValueError("This model doesn't have language tokens so it can't perform lang id") + + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] + logits = model.logits(x, mel)[:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) + mask[list(tokenizer.all_language_tokens)] = False + logits[:, mask] = -np.inf + language_tokens = logits.argmax(dim=-1) + language_token_probs = logits.softmax(dim=-1).cpu() + language_probs = [ + { + c: language_token_probs[i, j].item() + for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) + } + for i in range(n_audio) + ] + + if single: + language_tokens = language_tokens[0] + language_probs = language_probs[0] + + return language_tokens, language_probs + + +@dataclass(frozen=True) +class DecodingOptions: + task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" + language: Optional[str] = None # language that the audio is in; uses detected language if None + + # sampling-related options + temperature: float = 0.0 + sample_len: Optional[int] = None # maximum number of tokens to sample + best_of: Optional[int] = None # number of independent samples to collect, when t > 0 + beam_size: Optional[int] = None # number of beams in beam search, when t == 0 + patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) + + # options for ranking generations (either beams or best-of-N samples) + length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm + + # prompt, prefix, and token suppression + prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context + prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context + suppress_blank: bool = True # this will suppress blank outputs + + # list of tokens ids (or comma-separated token ids) to suppress + # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` + suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" + + # timestamp sampling options + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this + + # implementation details + fp16: bool = True # use fp16 for most of the calculation + + +@dataclass(frozen=True) +class DecodingResult: + audio_features: Tensor + language: str + language_probs: Optional[Dict[str, float]] = None + tokens: List[int] = field(default_factory=list) + text: str = "" + avg_logprob: float = np.nan + no_speech_prob: float = np.nan + temperature: float = np.nan + compression_ratio: float = np.nan + + +class Inference: + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + """Perform a forward pass on the decoder and return per-token logits""" + raise NotImplementedError + + def rearrange_kv_cache(self, source_indices) -> None: + """Update the key-value cache according to the updated beams""" + raise NotImplementedError + + def cleanup_caching(self) -> None: + """Clean up any resources or hooks after decoding is finished""" + pass + + +class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length + self.kv_cache = {} + self.hooks = [] + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + if not self.kv_cache: + self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() + + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] + + return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) + + def cleanup_caching(self): + for hook in self.hooks: + hook.remove() + + self.kv_cache = {} + self.hooks = [] + + def rearrange_kv_cache(self, source_indices): + for module, tensor in self.kv_cache.items(): + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = tensor[source_indices].detach() + + +class SequenceRanker: + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: + """ + Given a list of groups of samples and their cumulative log probabilities, + return the indices of the samples in each group to select as the final result + """ + raise NotImplementedError + + +class MaximumLikelihoodRanker(SequenceRanker): + """ + Select the sample with the highest log probabilities, penalized using either + a simple length normalization or Google NMT paper's length penalty + """ + + def __init__(self, length_penalty: Optional[float]): + self.length_penalty = length_penalty + + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): + def scores(logprobs, lengths): + result = [] + for logprob, length in zip(logprobs, lengths): + if self.length_penalty is None: + penalty = length + else: + # from the Google NMT paper + penalty = ((5 + length) / 6) ** self.length_penalty + result.append(logprob / penalty) + return result + + # get the sequence with the highest score + lengths = [[len(t) for t in s] for s in tokens] + return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] + + +class TokenDecoder: + def reset(self): + """Initialize any stateful variables for decoding a new sequence""" + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + """Specify how to select the next token, based on the current trace and logits + + Parameters + ---------- + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + sum_logprobs : Tensor, shape = (n_batch) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Tensor, shape = (n_batch, current_sequence_length + 1) + the tokens, appended with the selected next token + + completed : bool + True if all sequences has reached the end of text + + """ + raise NotImplementedError + + def finalize( + self, tokens: Tensor, sum_logprobs: Tensor + ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: + """Finalize search and return the final candidate sequences + + Parameters + ---------- + tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence + + sum_logprobs : Tensor, shape = (n_audio, n_group) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Sequence[Sequence[Tensor]], length = n_audio + sequence of Tensors containing candidate token sequences, for each audio input + + sum_logprobs : List[List[float]], length = n_audio + sequence of cumulative log probabilities corresponding to the above + + """ + raise NotImplementedError + + +class GreedyDecoder(TokenDecoder): + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + temperature = self.temperature + if temperature == 0: + next_tokens = logits.argmax(dim=-1) + else: + next_tokens = Categorical(logits=logits / temperature).sample() + + logprobs = F.log_softmax(logits.float(), dim=-1) + current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] + sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + + next_tokens[tokens[:, -1] == self.eot] = self.eot + tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) + + completed = (tokens[:, -1] == self.eot).all() + return tokens, completed + + def finalize(self, tokens: Tensor, sum_logprobs: Tensor): + # make sure each sequence has at least one EOT token at the end + tokens = F.pad(tokens, (0, 1), value=self.eot) + return tokens, sum_logprobs.tolist() + + +class BeamSearchDecoder(TokenDecoder): + def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences = None + + assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + self.finished_sequences = None + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + if tokens.shape[0] % self.beam_size != 0: + raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") + + n_audio = tokens.shape[0] // self.beam_size + if self.finished_sequences is None: # for the first update + self.finished_sequences = [{} for _ in range(n_audio)] + + logprobs = F.log_softmax(logits.float(), dim=-1) + next_tokens, source_indices, finished_sequences = [], [], [] + for i in range(n_audio): + scores, sources, finished = {}, {}, {} + + # STEP 1: calculate the cumulative log probabilities for possible candidates + for j in range(self.beam_size): + idx = i * self.beam_size + j + prefix = tokens[idx].tolist() + for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): + new_logprob = (sum_logprobs[idx] + logprob).item() + sequence = tuple(prefix + [token.item()]) + scores[sequence] = new_logprob + sources[sequence] = idx + + # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + saved = 0 + for sequence in sorted(scores, key=scores.get, reverse=True): + if sequence[-1] == self.eot: + finished[sequence] = scores[sequence] + else: + sum_logprobs[len(next_tokens)] = scores[sequence] + next_tokens.append(sequence) + source_indices.append(sources[sequence]) + + saved += 1 + if saved == self.beam_size: + break + + finished_sequences.append(finished) + + tokens = torch.tensor(next_tokens, device=tokens.device) + self.inference.rearrange_kv_cache(source_indices) + + # add newly finished sequences to self.finished_sequences + assert len(self.finished_sequences) == len(finished_sequences) + for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): + for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): + if len(previously_finished) >= self.max_candidates: + break # the candidate list is full + previously_finished[seq] = newly_finished[seq] + + # mark as completed if all audio has enough number of samples + completed = all( + len(sequences) >= self.max_candidates for sequences in self.finished_sequences + ) + return tokens, completed + + def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): + # collect all finished sequences, including patience, and add unfinished ones if not enough + sum_logprobs = sum_logprobs.cpu() + for i, sequences in enumerate(self.finished_sequences): + if len(sequences) < self.beam_size: # when not enough sequences are finished + for j in list(np.argsort(sum_logprobs[i]))[::-1]: + sequence = preceding_tokens[i, j].tolist() + [self.eot] + sequences[tuple(sequence)] = sum_logprobs[i][j].item() + if len(sequences) >= self.beam_size: + break + + tokens: List[List[Tensor]] = [ + [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences + ] + sum_logprobs: List[List[float]] = [ + list(sequences.values()) for sequences in self.finished_sequences + ] + return tokens, sum_logprobs + + +class LogitFilter: + def apply(self, logits: Tensor, tokens: Tensor) -> None: + """Apply any filtering or masking to logits in-place + + Parameters + ---------- + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + """ + raise NotImplementedError + + +class SuppressBlank(LogitFilter): + def __init__(self, tokenizer: Tokenizer, sample_begin: int): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + + def apply(self, logits: Tensor, tokens: Tensor): + if tokens.shape[1] == self.sample_begin: + logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf + + +class SuppressTokens(LogitFilter): + def __init__(self, suppress_tokens: Sequence[int]): + self.suppress_tokens = list(suppress_tokens) + + def apply(self, logits: Tensor, tokens: Tensor): + logits[:, self.suppress_tokens] = -np.inf + + +class ApplyTimestampRules(LogitFilter): + def __init__( + self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int] + ): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + self.max_initial_timestamp_index = max_initial_timestamp_index + + def apply(self, logits: Tensor, tokens: Tensor): + # suppress <|notimestamps|> which is handled by without_timestamps + if self.tokenizer.no_timestamps is not None: + logits[:, self.tokenizer.no_timestamps] = -np.inf + + # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + for k in range(tokens.shape[0]): + seq = [t for t in tokens[k, self.sample_begin :].tolist()] + last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin + penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + logits[k, self.tokenizer.timestamp_begin :] = -np.inf + else: # cannot be normal text tokens + logits[k, : self.tokenizer.eot] = -np.inf + + if tokens.shape[1] == self.sample_begin: + # suppress generating non-timestamp tokens at the beginning + logits[:, : self.tokenizer.timestamp_begin] = -np.inf + + # apply the `max_initial_timestamp` option + if self.max_initial_timestamp_index is not None: + last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + logits[:, last_allowed + 1 :] = -np.inf + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = F.log_softmax(logits.float(), dim=-1) + for k in range(tokens.shape[0]): + timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) + max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() + if timestamp_logprob > max_text_token_logprob: + logits[k, : self.tokenizer.timestamp_begin] = -np.inf + + +class DecodingTask: + inference: Inference + sequence_ranker: SequenceRanker + decoder: TokenDecoder + logit_filters: List[LogitFilter] + + def __init__(self, model: "Whisper", options: DecodingOptions): + self.model = model + + language = options.language or "en" + tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task) + self.tokenizer: Tokenizer = tokenizer + self.options: DecodingOptions = self._verify_options(options) + + self.n_group: int = options.beam_size or options.best_of or 1 + self.n_ctx: int = model.dims.n_text_ctx + self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 + + self.sot_sequence: Tuple[int] = tokenizer.sot_sequence + if self.options.without_timestamps: + self.sot_sequence = tokenizer.sot_sequence_including_notimestamps + + self.initial_tokens: Tuple[int] = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.sot_index: int = self.initial_tokens.index(tokenizer.sot) + + # inference: implements the forward pass through the decoder, including kv caching + self.inference = PyTorchInference(model, len(self.initial_tokens)) + + # sequence ranker: implements how to rank a group of sampled sequences + self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) + + # decoder: implements how to select the next tokens, given the autoregressive distribution + if options.beam_size is not None: + self.decoder = BeamSearchDecoder( + options.beam_size, tokenizer.eot, self.inference, options.patience + ) + else: + self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) + + # logit filters: applies various rules to suppress or penalize certain tokens + self.logit_filters = [] + if self.options.suppress_blank: + self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin)) + if self.options.suppress_tokens: + self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) + if not options.without_timestamps: + precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds + max_initial_timestamp_index = None + if options.max_initial_timestamp: + max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) + self.logit_filters.append( + ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index) + ) + + def _verify_options(self, options: DecodingOptions) -> DecodingOptions: + if options.beam_size is not None and options.best_of is not None: + raise ValueError("beam_size and best_of can't be given together") + if options.temperature == 0: + if options.best_of is not None: + raise ValueError("best_of with greedy sampling (T=0) is not compatible") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") + if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): + raise ValueError("length_penalty (alpha) should be a value between 0 and 1") + + return options + + def _get_initial_tokens(self) -> Tuple[int]: + tokens = list(self.sot_sequence) + prefix = self.options.prefix + prompt = self.options.prompt + + if prefix: + prefix_tokens = ( + self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix + ) + if self.sample_len is not None: + max_prefix_len = self.n_ctx // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt: + prompt_tokens = ( + self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt + ) + tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens + + return tuple(tokens) + + def _get_suppress_tokens(self) -> Tuple[int]: + suppress_tokens = self.options.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" + + suppress_tokens.extend( + [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] + ) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) + + return tuple(sorted(set(suppress_tokens))) + + def _get_audio_features(self, mel: Tensor): + if self.options.fp16: + mel = mel.half() + + if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): + # encoded audio features are given; skip audio encoding + print("encoded audio features are given; skip audio encoding") + audio_features = mel + else: + print(mel.shape) + print("===============================") + audio_features = self.model.encoder(mel) + + if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): + return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") + + return audio_features + + def _detect_language(self, audio_features: Tensor, tokens: Tensor): + languages = [self.options.language] * audio_features.shape[0] + lang_probs = None + + if self.options.language is None or self.options.task == "lang_id": + lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) + languages = [max(probs, key=probs.get) for probs in lang_probs] + if self.options.language is None: + tokens[:, self.sot_index + 1] = lang_tokens # write language tokens + + return languages, lang_probs + + def _main_loop(self, audio_features: Tensor, tokens: Tensor): + assert audio_features.shape[0] == tokens.shape[0] + n_batch = tokens.shape[0] + sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) + no_speech_probs = [np.nan] * n_batch + + try: + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features) + + if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs + probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + + # now we need to consider the logits at the last token only + logits = logits[:, -1] + + # apply the logit filters, e.g. for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + + if completed or tokens.shape[-1] > self.n_ctx: + break + finally: + self.inference.cleanup_caching() + + return tokens, sum_logprobs, no_speech_probs + + @torch.no_grad() + def run(self, mel: Tensor) -> List[DecodingResult]: + self.decoder.reset() + tokenizer: Tokenizer = self.tokenizer + n_audio: int = mel.shape[0] + + audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass + tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) + + # detect language if requested, overwriting the language token + languages, language_probs = self._detect_language(audio_features, tokens) + if self.options.task == "lang_id": + return [ + DecodingResult(audio_features=features, language=language, language_probs=probs) + for features, language, probs in zip(audio_features, languages, language_probs) + ] + + # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling + audio_features = audio_features.repeat_interleave(self.n_group, dim=0) + tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) + + # call the main sampling loop + tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + + # reshape the tensors to have (n_audio, n_group) as the first two dimensions + audio_features = audio_features[:: self.n_group] + no_speech_probs = no_speech_probs[:: self.n_group] + assert audio_features.shape[0] == len(no_speech_probs) == n_audio + + tokens = tokens.reshape(n_audio, self.n_group, -1) + sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) + + # get the final candidates for each group, and slice between the first sampled token and EOT + tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens: List[List[Tensor]] = [ + [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens + ] + + # select the top-ranked sample in each group + selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] + texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] + + sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] + avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + + fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs) + if len(set(map(len, fields))) != 1: + raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") + + return [ + DecodingResult( + audio_features=features, + language=language, + tokens=tokens, + text=text, + avg_logprob=avg_logprob, + no_speech_prob=no_speech_prob, + temperature=self.options.temperature, + compression_ratio=compression_ratio(text), + ) + for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) + ] + + +@torch.no_grad() +def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]: + """ + Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). + + Parameters + ---------- + model: Whisper + the Whisper model instance + + mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000) + A tensor containing the Mel spectrogram(s) + + options: DecodingOptions + A dataclass that contains all necessary options for decoding 30-second segments + + Returns + ------- + result: Union[DecodingResult, List[DecodingResult]] + The result(s) of decoding contained in `DecodingResult` dataclass instance(s) + """ + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + result = DecodingTask(model, options).run(mel) + + if single: + result = result[0] + + return result diff --git a/vencoder/whisper/model.py b/vencoder/whisper/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f3de4d32cb9646964074401aad176dbef9ef2125 --- /dev/null +++ b/vencoder/whisper/model.py @@ -0,0 +1,268 @@ +from dataclasses import dataclass +from typing import Dict, Iterable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + len_x = x.shape[1] + len_e = self.positional_embedding.shape[0] + assert len_x <= len_e, "incorrect audio shape" + pos_e = self.positional_embedding[:len_x, :] + x = (x + pos_e).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] + ) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() + + return logits + + +class Whisper(nn.Module): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + + def embed_audio(self, mel: torch.Tensor): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): + return self.decoder(tokens, audio_features) + + def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: + return self.decoder(tokens, self.encoder(mel)) + + @property + def device(self): + return next(self.parameters()).device + + @property + def is_multilingual(self): + return self.dims.n_vocab == 51865 + + def install_kv_cache_hooks(self, cache: Optional[dict] = None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value + tensors calculated for the previous positions. This method returns a dictionary that stores + all caches, and the necessary hooks for the key and value projection modules that save the + intermediate tensors to be reused during later calculations. + + Returns + ------- + cache : Dict[nn.Module, torch.Tensor] + A dictionary object mapping the key/value projection modules to its cache + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects to stop the hooks to be called + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]: + cache[module] = output # save as-is, for the first token or cross attention + else: + cache[module] = torch.cat([cache[module], output], dim=1).detach() + return cache[module] + + def install_hooks(layer: nn.Module): + if isinstance(layer, MultiHeadAttention): + hooks.append(layer.key.register_forward_hook(save_to_cache)) + hooks.append(layer.value.register_forward_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks + + detect_language = detect_language_function + decode = decode_function diff --git a/vencoder/whisper/tokenizer.py b/vencoder/whisper/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b15645dc7e15ca9f601413076299b362293eae6d --- /dev/null +++ b/vencoder/whisper/tokenizer.py @@ -0,0 +1,331 @@ +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import GPT2TokenizerFast + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + + +@dataclass(frozen=True) +class Tokenizer: + """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" + + tokenizer: "GPT2TokenizerFast" + language: Optional[str] + sot_sequence: Tuple[int] + + def encode(self, text, **kwargs): + return self.tokenizer.encode(text, **kwargs) + + def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): + return self.tokenizer.decode(token_ids, **kwargs) + + def decode_with_timestamps(self, tokens) -> str: + """ + Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. + This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". + """ + outputs = [[]] + for token in tokens: + if token >= self.timestamp_begin: + timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" + outputs.append(timestamp) + outputs.append([]) + else: + outputs[-1].append(token) + outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] + return "".join(outputs) + + @property + @lru_cache() + def eot(self) -> int: + return self.tokenizer.eos_token_id + + @property + @lru_cache() + def sot(self) -> int: + return self._get_single_token_id("<|startoftranscript|>") + + @property + @lru_cache() + def sot_lm(self) -> int: + return self._get_single_token_id("<|startoflm|>") + + @property + @lru_cache() + def sot_prev(self) -> int: + return self._get_single_token_id("<|startofprev|>") + + @property + @lru_cache() + def no_speech(self) -> int: + return self._get_single_token_id("<|nospeech|>") + + @property + @lru_cache() + def no_timestamps(self) -> int: + return self._get_single_token_id("<|notimestamps|>") + + @property + @lru_cache() + def timestamp_begin(self) -> int: + return self.tokenizer.all_special_ids[-1] + 1 + + @property + @lru_cache() + def language_token(self) -> int: + """Returns the token id corresponding to the value of the `language` field""" + if self.language is None: + raise ValueError("This tokenizer does not have language token configured") + + additional_tokens = dict( + zip( + self.tokenizer.additional_special_tokens, + self.tokenizer.additional_special_tokens_ids, + ) + ) + candidate = f"<|{self.language}|>" + if candidate in additional_tokens: + return additional_tokens[candidate] + + raise KeyError(f"Language {self.language} not found in tokenizer.") + + @property + @lru_cache() + def all_language_tokens(self) -> Tuple[int]: + result = [] + for token, token_id in zip( + self.tokenizer.additional_special_tokens, + self.tokenizer.additional_special_tokens_ids, + ): + if token.strip("<|>") in LANGUAGES: + result.append(token_id) + return tuple(result) + + @property + @lru_cache() + def all_language_codes(self) -> Tuple[str]: + return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) + + @property + @lru_cache() + def sot_sequence_including_notimestamps(self) -> Tuple[int]: + return tuple(list(self.sot_sequence) + [self.no_timestamps]) + + @property + @lru_cache() + def non_speech_tokens(self) -> Tuple[int]: + """ + Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech + annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. + + - ♪♪♪ + - ( SPEAKING FOREIGN LANGUAGE ) + - [DAVID] Hey there, + + keeping basic punctuations like commas, periods, question marks, exclamation points, etc. + """ + symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") + symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() + + # symbols that may be a single token or multiple tokens depending on the tokenizer. + # In case they're multiple tokens, suppress the first token, which is safe because: + # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress + # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. + miscellaneous = set("♩♪♫♬♭♮♯") + assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) + + # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} + for symbol in symbols + list(miscellaneous): + for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: + if len(tokens) == 1 or symbol in miscellaneous: + result.add(tokens[0]) + + return tuple(sorted(result)) + + def _get_single_token_id(self, text) -> int: + tokens = self.tokenizer.encode(text) + assert len(tokens) == 1, f"{text} is not encoded as a single token" + return tokens[0] + + +@lru_cache(maxsize=None) +def build_tokenizer(name: str = "gpt2"): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + path = os.path.join(os.path.dirname(__file__), "assets", name) + tokenizer = GPT2TokenizerFast.from_pretrained(path) + + specials = [ + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + ] + + tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) + return tokenizer + + +@lru_cache(maxsize=None) +def get_tokenizer( + multilingual: bool, + *, + task: Optional[str] = None, # Literal["transcribe", "translate", None] + language: Optional[str] = None, +) -> Tokenizer: + if language is not None: + language = language.lower() + if language not in LANGUAGES: + if language in TO_LANGUAGE_CODE: + language = TO_LANGUAGE_CODE[language] + else: + raise ValueError(f"Unsupported language: {language}") + + if multilingual: + tokenizer_name = "multilingual" + task = task or "transcribe" + language = language or "en" + else: + tokenizer_name = "gpt2" + task = None + language = None + + tokenizer = build_tokenizer(name=tokenizer_name) + all_special_ids: List[int] = tokenizer.all_special_ids + sot: int = all_special_ids[1] + translate: int = all_special_ids[-6] + transcribe: int = all_special_ids[-5] + + langs = tuple(LANGUAGES.keys()) + sot_sequence = [sot] + if language is not None: + sot_sequence.append(sot + 1 + langs.index(language)) + if task is not None: + sot_sequence.append(transcribe if task == "transcribe" else translate) + + return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) diff --git a/vencoder/whisper/utils.py b/vencoder/whisper/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5dacc173c40bcd6e999d728862e29a968000b12e --- /dev/null +++ b/vencoder/whisper/utils.py @@ -0,0 +1,163 @@ +import json +import os +import sys +import zlib +from typing import Callable, TextIO + +system_encoding = sys.getdefaultencoding() + +if system_encoding != "utf-8": + def make_safe(string): + # replaces any character not representable using the system default encoding with an '?', + # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). + return string.encode(system_encoding, errors="replace").decode(system_encoding) +else: + def make_safe(string): + # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding + return string + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + + +def optional_int(string): + return None if string == "None" else int(string) + + +def optional_float(string): + return None if string == "None" else float(string) + + +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + + +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__(self, result: dict, audio_path: str): + audio_basename = os.path.basename(audio_path) + output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension) + + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f) + + def write_result(self, result: dict, file: TextIO): + raise NotImplementedError + + +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result(self, result: dict, file: TextIO): + for segment in result["segments"]: + print(segment['text'].strip(), file=file, flush=True) + + +class WriteVTT(ResultWriter): + extension: str = "vtt" + + def write_result(self, result: dict, file: TextIO): + print("WEBVTT\n", file=file) + for segment in result["segments"]: + print( + f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, + ) + + +class WriteSRT(ResultWriter): + extension: str = "srt" + + def write_result(self, result: dict, file: TextIO): + for i, segment in enumerate(result["segments"], start=1): + # write srt lines + print( + f"{i}\n" + f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " + f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, + ) + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. + """ + extension: str = "tsv" + + def write_result(self, result: dict, file: TextIO): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment['start']), file=file, end="\t") + print(round(1000 * segment['end']), file=file, end="\t") + print(segment['text'].strip().replace("\t", " "), file=file, flush=True) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result(self, result: dict, file: TextIO): + json.dump(result, file) + + +def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: + writers = { + "txt": WriteTXT, + "vtt": WriteVTT, + "srt": WriteSRT, + "tsv": WriteTSV, + "json": WriteJSON, + } + + if output_format == "all": + all_writers = [writer(output_dir) for writer in writers.values()] + + def write_all(result: dict, file: TextIO): + for writer in all_writers: + writer(result, file) + + return write_all + + return writers[output_format](output_dir) + diff --git a/wav_upload.py b/wav_upload.py new file mode 100644 index 0000000000000000000000000000000000000000..fffe12a6ff780ab4173c58009154232a4a10aa49 --- /dev/null +++ b/wav_upload.py @@ -0,0 +1,25 @@ +import argparse +import os +import shutil + +from google.colab import files + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--type", type=str, required=True, help="type of file to upload") + args = parser.parse_args() + file_type = args.type + + basepath = os.getcwd() + uploaded = files.upload() # 上传文件 + assert(file_type in ['zip', 'audio']) + if file_type == "zip": + upload_path = "./upload/" + for filename in uploaded.keys(): + #将上传的文件移动到指定的位置上 + shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, "userzip.zip")) + elif file_type == "audio": + upload_path = "./raw/" + for filename in uploaded.keys(): + #将上传的文件移动到指定的位置上 + shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, filename)) \ No newline at end of file