diff --git a/2.0 b/2.0 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/2.0' b/2.0' new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..aabf9130b0a67aca9beaac9f2cb1a40237a4468d --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,28 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). + +## [1.0.0] - 2023-08-02 + +Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion. +Added pretrained model for AudioGen and MultiBandDiffusion. + +## [0.0.2] - 2023-08-01 + +Improved demo, fixed top p (thanks @jnordberg). + +Compressor tanh on output to avoid clipping with some style (especially piano). +Now repeating the conditioning periodically if it is too short. + +More options when launching Gradio app locally (thanks @ashleykleynhans). + +Testing out PyTorch 2.0 memory efficient attention. + +Added extended generation (infinite length) by slowly moving the windows. +Note that other implementations exist: https://github.com/camenduru/MusicGen-colab. + +## [0.0.1] - 2023-06-09 + +Initial release, with model evaluation only. diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..83f431e8feeb7e80d571f39c9f6c1b96857b5f85 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. 
+ +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..a3e9507643d4439f509a8fc8b87dc73417ef9822 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,35 @@ +# Contributing to AudioCraft + +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests + +AudioCraft is the implementation of a research paper. +Therefore, we do not plan on accepting many pull requests for new features. +We certainly welcome them for bug fixes. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to AudioCraft, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b93be90515ccd0b9daedaa589e42bf5929693f1f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Meta Platforms, Inc. and affiliates.
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LICENSE_weights b/LICENSE_weights new file mode 100644 index 0000000000000000000000000000000000000000..108b5f002fc31efe11d881de2cd05329ebe8cc37 --- /dev/null +++ b/LICENSE_weights @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. 
Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. 
NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. 
Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. 
if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. 
+ +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..ac6828f0ab296c7e34e44548b14bce9df4f65a6c --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,9 @@ +include Makefile +include LICENSE +include LICENSE_weights +include *.md +include *.ini +include requirements.txt +include audiocraft/py.typed +include assets/*.mp3 +recursive-include conf *.yaml diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..3a4910066583dc22f06f5ec2d5711367c941c86b --- /dev/null +++ b/Makefile @@ -0,0 +1,40 @@ +INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \ + dataset.train.num_samples=10 dataset.valid.num_samples=10 \ + dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \ + logging.level=DEBUG +INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e +INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ + transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e +INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ + transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e +INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \ + checkpoint.save_last=false # Using compression model from 616d7b3c + +default: linter tests + +install: + pip install -U pip + pip install -U -e '.[dev]' + +linter: + flake8 audiocraft && mypy audiocraft + flake8 tests && mypy tests + +tests: + coverage run -m pytest tests + coverage report + +tests_integ: + $(INTEG_COMPRESSION) + $(INTEG_MBD) + $(INTEG_MUSICGEN) + $(INTEG_AUDIOGEN) + + +api_docs: + pdoc3 --html -o api_docs -f audiocraft + +dist: + python setup.py sdist + +.PHONY: linter tests api_docs dist diff --git a/README.md b/README.md index a9a87af2a56b6389e8123728818d0d33f1579409..0f5380e53c8725a5e0334e09a98958c7b078f49a 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,86 @@ ---- -title: Videoshop Backend -emoji: 📚 -colorFrom: green -colorTo: red -sdk: gradio -sdk_version: 3.41.2 -app_file: app.py -pinned: false -license: openrail ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# AudioCraft + +![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg) +![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg) +![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg) + +AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code +for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen. + +## Installation + +AudioCraft requires Python 3.9 and PyTorch 2.0.0. To install AudioCraft, you can run the following: + +```shell +# Best to make sure you have torch installed first, in particular before installing xformers. +# Don't run this if you already have PyTorch installed.
+pip install 'torch>=2.0' +# Then proceed to one of the following +pip install -U audiocraft # stable release +pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge +pip install -e . # or if you cloned the repo locally (mandatory if you want to train). + +``` + +We also recommend having `ffmpeg` installed, either through your system or Anaconda: + +```bash +sudo apt-get install ffmpeg +# Or if you are using Anaconda or Miniconda +conda install 'ffmpeg<5' -c conda-forge + +``` + +## Models + +At the moment, AudioCraft contains the training code and inference code for: + +* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model. +* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model. +* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high-fidelity neural audio codec. +* [Multi Band Diffusion](./docs/MBD.md): An EnCodec-compatible decoder using diffusion. + +## Training code + +AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models. +For a general introduction to AudioCraft design principles and instructions on developing your own training pipeline, refer to +the [AudioCraft training documentation](./docs/TRAINING.md). + +For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model, +which provide pointers to configuration, example grids, and model/task-specific information and FAQs. + +## API documentation + +We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft. + +## FAQ + +#### Is the training code available? + +Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [AudioGen](./docs/AUDIOGEN.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md). + +#### Where are the models stored? + +Hugging Face stores the models in a specific location, which can be overridden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable (see the usage sketch below). + +## License + +* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). +* The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). + +## Citation + +For the general framework of AudioCraft, please cite the following. + +```bibtex +@article{copet2023simple, + title={Simple and Controllable Music Generation}, + author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, + year={2023}, + journal={arXiv preprint arXiv:2306.05284}, +} + +``` + +When referring to a specific model, please cite as mentioned in the model-specific README, e.g. +[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
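The FAQ above mentions `AUDIOCRAFT_CACHE_DIR` without showing it in action. The sketch below is an editor-added illustration, not part of the upstream README: the cache path is a placeholder, and the model name and generation calls are assumptions based on the public `MusicGen` API shipped with this release.

```shell
# Illustrative only: download a pretrained MusicGen checkpoint into a custom
# cache directory and render a short clip. Path and model name are placeholders.
export AUDIOCRAFT_CACHE_DIR=/data/audiocraft_cache

python - <<'EOF'
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('facebook/musicgen-small')  # weights land in AUDIOCRAFT_CACHE_DIR
model.set_generation_params(duration=8)                     # seconds of audio per sample
wav = model.generate(['lo-fi beat with a mellow piano'])    # tensor of shape [batch, channels, time]
audio_write('musicgen_sample', wav[0].cpu(), model.sample_rate, strategy='loudness')
EOF
```

The same pattern should carry over to AudioGen through `audiocraft.models.AudioGen`, again assuming the documented pretrained checkpoints.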
diff --git a/config/conditioner/chroma2music.yaml b/config/conditioner/chroma2music.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91d37e758ef183678cff3f7a880b6bab2e36b03c --- /dev/null +++ b/config/conditioner/chroma2music.yaml @@ -0,0 +1,46 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.2 + inference_coef: 3.0 + +attribute_dropout: + args: + active_on_eval: false + text: {} + wav: + self_wav: 0.5 + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [self_wav, description] + cross: [] + input_interpolate: [] + +conditioners: + self_wav: + model: chroma_stem + chroma_stem: + sample_rate: ${sample_rate} + n_chroma: 12 + radix2_exp: 14 + argmax: true + match_len_on_eval: false + eval_wavs: null + n_eval_wavs: 100 + cache_path: null + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.2 + normalize_text: false + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/clapemb2music.yaml b/config/conditioner/clapemb2music.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8500a826e7379b4a8baaf67570e233f7bac7e5da --- /dev/null +++ b/config/conditioner/clapemb2music.yaml @@ -0,0 +1,44 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 + inference_coef: 3.0 + +attribute_dropout: + text: {} + wav: {} + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + input_interpolate: [] + +conditioners: + description: + model: clap + clap: + checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt + model_arch: 'HTSAT-base' + enable_fusion: false + sample_rate: 44100 + max_audio_length: 10 + audio_stride: 1 + dim: 512 + attribute: description + normalize: true + quantize: true # use RVQ quantization + n_q: 12 + bins: 1024 + kmeans_iters: 50 + text_p: 0. 
# probability of using text embed at train time + cache_path: null + +dataset: + joint_embed_attributes: [description] + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/none.yaml b/config/conditioner/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6055dc910cad46d80609aae57bb46b81f2663d70 --- /dev/null +++ b/config/conditioner/none.yaml @@ -0,0 +1,19 @@ +# @package __global__ + +# No conditioning + +classifier_free_guidance: + training_dropout: 0 + inference_coef: 1 + +attribute_dropout: + text: {} + wav: {} + +fuser: + sum: [] + prepend: [] + cross: [] + input_interpolate: [] + +conditioners: null diff --git a/config/conditioner/text2music.yaml b/config/conditioner/text2music.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d0fe6cfa3fb33bcdb4f9fd16bd5ab4034c68b7b --- /dev/null +++ b/config/conditioner/text2music.yaml @@ -0,0 +1,30 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 + inference_coef: 3.0 + +attribute_dropout: {} + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + input_interpolate: [] + +conditioners: + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.3 + normalize_text: false + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/text2sound.yaml b/config/conditioner/text2sound.yaml new file mode 100644 index 0000000000000000000000000000000000000000..555d4b7c3cecf0ec06c8cb25440b2f426c098ad2 --- /dev/null +++ b/config/conditioner/text2sound.yaml @@ -0,0 +1,24 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.1 + inference_coef: 3.0 + +attribute_dropout: {} + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + input_interpolate: [] + +conditioners: + description: + model: t5 + t5: + name: t5-large + finetune: false + word_dropout: 0. + normalize_text: false diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b0b7866eafac173fe7b056ad5920be1df57a947 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,75 @@ +# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft +# Please don't update this file directly. Instead use distinct configuration files +# to override the below configuration. +defaults: + - _self_ + - dset: default + - solver: default + +device: cuda +dtype: float32 +autocast: false +autocast_dtype: bfloat16 +seed: 2036 +show: false # just show the model and its size and exit +continue_from: # continue from a given sig or path +execute_only: # can be set to generate/evaluate/valid to run that stage +execute_inplace: false # don't enforce continue_from to be set + # to enable inplace execution of the stage. This assume + # that you know what you are doing and execute stage + # preserving the original xp sig. +benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them + +efficient_attention_backend: torch # can be torch or xformers. +num_threads: 1 # called with torch.set_num_thread. +mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server). + + +label: # use this if you want twice the same exp, with a name. 
+ +# logging parameters +logging: + level: INFO + log_updates: 10 + log_tensorboard: false + log_wandb: false +tensorboard: + with_media_logging: false + name: # optional name for the experiment + sub_dir: # optional sub directory to store tensorboard data +wandb: + with_media_logging: true + project: # project name + name: # optional name for the experiment + group: # optional group + +# SLURM launcher configuration. +slurm: + gpus: 4 # convenience parameter, number of GPUs to use. + mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`. + time: 3600 + constraint: + partition: + comment: + setup: [] + exclude: '' + +# dora parameters +dora: + # Output folder for all artifacts of an experiment. + dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs + # The following entries will be ignored by dora when computing the unique XP signature. + # Note that slurm.* and dora.* are automatically ignored. + exclude: [ + 'device', 'wandb.*', 'tensorboard.*', 'logging.*', + 'dataset.num_workers', 'eval.num_workers', 'special.*', + 'metrics.visqol.bin', 'metrics.fad.bin', + 'execute_only', 'execute_best', 'generate.every', + 'optim.eager_sync', 'profiler.*', 'deadlock.*', + 'efficient_attention_backend', 'num_threads', 'mp_start_method', + ] + use_rendezvous: false + # for grids, always run from a clean repo, allowing reliable runs and storing + # the exact commit. Your repo must be absolutely pristine clean. + # Local `dora run` are not impacted for easier debugging. + git_save: true diff --git a/config/dset/audio/audiocaps_16khz.yaml b/config/dset/audio/audiocaps_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14f5d6a4fcbf4426b7987d4427ca2d98d17d6c5b --- /dev/null +++ b/config/dset/audio/audiocaps_16khz.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +# AudioCaps dataset +datasource: + max_sample_rate: 16000 + max_channels: 1 + + train: null # only evaluation set + valid: null # only evaluation set + evaluate: egs/audiocaps/audiocaps_16khz + generate: egs/audiocaps/audiocaps_16khz # identical to evaluate diff --git a/config/dset/audio/default.yaml b/config/dset/audio/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80be23e999c6366cc89ebcf55af6b958c0e45158 --- /dev/null +++ b/config/dset/audio/default.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +datasource: + max_sample_rate: ??? + max_channels: ??? + + train: ??? + valid: ??? + evaluate: ??? 
+ generate: null diff --git a/config/dset/audio/example.yaml b/config/dset/audio/example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d559d6d79a1cc05a82bb09f267c446258ef9ca55 --- /dev/null +++ b/config/dset/audio/example.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +datasource: + max_sample_rate: 44100 + max_channels: 2 + + train: egs/example + valid: egs/example + evaluate: egs/example + generate: egs/example diff --git a/config/dset/audio/musiccaps_32khz.yaml b/config/dset/audio/musiccaps_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d4eea0f7a521a47b9f673fecab075c5223d2b07 --- /dev/null +++ b/config/dset/audio/musiccaps_32khz.yaml @@ -0,0 +1,12 @@ +# @package __global__ + +# total samples obtained from MusicCaps = 5469 +# (out of 5521 due to AudioSet corrupted samples) +datasource: + max_sample_rate: 32000 + max_channels: 2 + + train: null # only evaluation set + valid: null # only evaluation set + evaluate: egs/musiccaps/musiccaps_32khz + generate: egs/musiccaps/musiccaps_32khz # identical to evaluate diff --git a/config/dset/default.yaml b/config/dset/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5d730130e090b38a42984a8a87e1eea01cbf031 --- /dev/null +++ b/config/dset/default.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft +# Please don't update this file directly. Instead use distinct configuration files +# to override the below configuration. +datasource: + train: ??? + valid: ??? + evaluate: ??? + generate: ??? diff --git a/config/dset/internal/music_10k_32khz.yaml b/config/dset/internal/music_10k_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..036628abfeaa89279790547bbb5b3ee9dd69cea3 --- /dev/null +++ b/config/dset/internal/music_10k_32khz.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +# high quality music dataset with no artist overlap between splits +datasource: + max_sample_rate: 32000 + max_channels: 1 + + train: egs/music/music_10k_32khz/train + valid: egs/music/music_10k_32khz/valid + evaluate: egs/music/music_10k_32khz/test + generate: egs/music/music_10k_32khz/test # identical to evaluate diff --git a/config/dset/internal/music_400k_32khz.yaml b/config/dset/internal/music_400k_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7786880ab9c0464a0423d906c18d62bdf7194463 --- /dev/null +++ b/config/dset/internal/music_400k_32khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +datasource: + max_sample_rate: 32000 + max_channels: 1 + + train: egs/music/music_400k_32khz/train + valid: egs/music/music_400k_32khz/valid + evaluate: egs/music/music_400k_32khz/test + generate: egs/music/music_400k_32khz/test # identical to evaluate diff --git a/config/dset/internal/sounds_16khz.yaml b/config/dset/internal/sounds_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f3401a1b44ce300e22f3f64ef9c54d5c013c153 --- /dev/null +++ b/config/dset/internal/sounds_16khz.yaml @@ -0,0 +1,12 @@ +# @package __global__ + +# environmental sounds dataset compiling all datasets +# with applied filters on tags +datasource: + max_sample_rate: 16000 + max_channels: 1 + + train: egs/sound/sounds_16khz/train + valid: egs/sound/sounds_16khz/valid + evaluate: egs/sound/sounds_16khz/test + generate: egs/sound/sounds_16khz/test # identical to evaluate diff --git a/config/model/encodec/default.yaml b/config/model/encodec/default.yaml new file 
mode 100644 index 0000000000000000000000000000000000000000..ec62c6c8ef9a686890bdca8b8f27a2f1c232205d --- /dev/null +++ b/config/model/encodec/default.yaml @@ -0,0 +1,54 @@ +# @package __global__ + +compression_model: encodec + +encodec: + autoencoder: seanet + quantizer: rvq + sample_rate: ${sample_rate} + channels: ${channels} + causal: false + renormalize: false + +seanet: + dimension: 128 + channels: ${channels} + causal: ${encodec.causal} + n_filters: 32 + n_residual_layers: 1 + ratios: [8, 5, 4, 2] + activation: ELU + activation_params: {"alpha": 1.} + norm: weight_norm + norm_params: {} + kernel_size: 7 + residual_kernel_size: 3 + last_kernel_size: 7 + dilation_base: 2 + pad_mode: constant + true_skip: true + compress: 2 + lstm: 2 + disable_norm_outer_blocks: 0 + # Specific encoder or decoder params. + # You can also override any param for the encoder or decoder only + # by using Hydra `+param=` syntax, i.e.` + # `+seanet.decoder.n_filters=64`. + decoder: + trim_right_ratio: 1.0 + final_activation: null + final_activation_params: null + encoder: {} + +rvq: + n_q: 8 + q_dropout: false + bins: 1024 + decay: 0.99 + kmeans_init: true + kmeans_iters: 50 + threshold_ema_dead_code: 2 + orthogonal_reg_weight: 0.0 + orthogonal_reg_active_codes_only: false + +no_quant: {} diff --git a/config/model/encodec/encodec_base_causal.yaml b/config/model/encodec/encodec_base_causal.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ca555bcdc69433f172915400bb71c3b63e68681 --- /dev/null +++ b/config/model/encodec/encodec_base_causal.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +defaults: + - encodec/default + +encodec: + causal: true + +rvq: + n_q: 32 + q_dropout: true diff --git a/config/model/encodec/encodec_large_nq4_s320.yaml b/config/model/encodec/encodec_large_nq4_s320.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f2d77590afd8a81185358c705a6e42853e257c3 --- /dev/null +++ b/config/model/encodec/encodec_large_nq4_s320.yaml @@ -0,0 +1,13 @@ +# @package __global__ + +defaults: + - encodec/default + +seanet: + # default ratios are [8, 5, 4, 2] + n_filters: 64 + +rvq: + bins: 2048 + n_q: 4 + q_dropout: false diff --git a/config/model/encodec/encodec_large_nq4_s640.yaml b/config/model/encodec/encodec_large_nq4_s640.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fcb7e87f4f700554164b0a58e9927b2f96a2c5a --- /dev/null +++ b/config/model/encodec/encodec_large_nq4_s640.yaml @@ -0,0 +1,13 @@ +# @package __global__ + +defaults: + - encodec/default + +seanet: + ratios: [8, 5, 4, 4] + n_filters: 64 + +rvq: + bins: 2048 + n_q: 4 + q_dropout: false diff --git a/config/model/lm/audiogen_lm.yaml b/config/model/lm/audiogen_lm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..696f74620af193c12208ce66fdb93a37f8ea9d80 --- /dev/null +++ b/config/model/lm/audiogen_lm.yaml @@ -0,0 +1,36 @@ +# @package __global__ + +defaults: + - lm/default + - override /conditioner: text2sound + - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly + +lm_model: transformer_lm + +codebooks_pattern: + modeling: delay + delay: + delays: [0, 1, 2, 3] + flatten_first: 0 + empty_initial: 0 + unroll: + flattening: [0, 1, 2, 3] + delays: [0, 0, 0, 0] + music_lm: + group_by: 2 + valle: + delays: [0, 0, 0] + +transformer_lm: + n_q: 4 + card: 2048 + memory_efficient: true + bias_proj: false + bias_ff: false + bias_attn: false + norm_first: true + layer_scale: null + weight_init: 
gaussian + depthwise_init: current + zero_bias_init: true + attention_as_float32: false diff --git a/config/model/lm/default.yaml b/config/model/lm/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d256ad14ef69d25d62c19b73599937c8546e79b --- /dev/null +++ b/config/model/lm/default.yaml @@ -0,0 +1,47 @@ +# @package __global__ +defaults: + - _self_ + - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly + +lm_model: transformer_lm + +codebooks_pattern: + modeling: parallel + +transformer_lm: + dim: 512 + num_heads: 8 + num_layers: 8 + hidden_scale: 4 + n_q: 8 # number of streams to model + card: 1024 + dropout: 0. + emb_lr: null + activation: gelu + norm_first: false # use pre-norm instead of post-norm + bias_ff: true # use bias for the feedforward + bias_attn: true # use bias for the attention + bias_proj: true # use bias for the output projections + past_context: null + causal: true + custom: false # use custom MHA implementation + memory_efficient: false # use flash attention + attention_as_float32: false # use float32 for the attention part, + # recommended at the moment when memory_efficient is True. + layer_scale: null + positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope). + xpos: false # apply xpos decay (rope only). + checkpointing: none # layer checkpointing method, can be none, torch, xformers_default. + # torch is the slowest but uses the least memory, + # xformers_default is somewhere in between. + weight_init: null # weight initialization (null, gaussian or uniform) + depthwise_init: null # perform depthwise initialization (null, current, global) + zero_bias_init: false # initialize bias to zero if bias in linears and + # if a weight_init method is used. + norm: layer_norm # normalization method to use in transformer. + cross_attention: false + qk_layer_norm: false + qk_layer_norm_cross: false + attention_dropout: null + kv_repeat: 1 + two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not... 
diff --git a/config/model/lm/model_scale/base.yaml b/config/model/lm/model_scale/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3da88d2305e4c380435de1a3eecfe311ecfc82f9 --- /dev/null +++ b/config/model/lm/model_scale/base.yaml @@ -0,0 +1,3 @@ +# @package __global__ + +# overrides nothing because default is already transformer base (~ 60M params) diff --git a/config/model/lm/model_scale/large.yaml b/config/model/lm/model_scale/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d355bfb93618003ac8994bc093eb7bc96ac60114 --- /dev/null +++ b/config/model/lm/model_scale/large.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +# gpt2 inspired, even bigger (~3.3B params) +transformer_lm: + dim: 2048 + num_heads: 32 + num_layers: 48 diff --git a/config/model/lm/model_scale/medium.yaml b/config/model/lm/model_scale/medium.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c825d1ff6c3b8cc9ae4959a898e14b40409d95e8 --- /dev/null +++ b/config/model/lm/model_scale/medium.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +# gpt2 like (~1.5B params) +transformer_lm: + dim: 1536 + num_heads: 24 + num_layers: 48 diff --git a/config/model/lm/model_scale/small.yaml b/config/model/lm/model_scale/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..88d89cb5ac1b183fb3a9092834cea83aa16c70a8 --- /dev/null +++ b/config/model/lm/model_scale/small.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +# 300M Param. + +transformer_lm: + dim: 1024 + num_heads: 16 + num_layers: 24 diff --git a/config/model/lm/model_scale/xsmall.yaml b/config/model/lm/model_scale/xsmall.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e98d4370d4fe7497f12aeb58f092a88797d1afa1 --- /dev/null +++ b/config/model/lm/model_scale/xsmall.yaml @@ -0,0 +1,8 @@ +# @package _global_ +# just used for debugging or when we just want to populate the cache +# and do not care about training. + +transformer_lm: + dim: 64 + num_heads: 2 + num_layers: 2 diff --git a/config/model/lm/musicgen_lm.yaml b/config/model/lm/musicgen_lm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5bc87a628789a34e381e2aa8ba5ef6ed780669d7 --- /dev/null +++ b/config/model/lm/musicgen_lm.yaml @@ -0,0 +1,36 @@ +# @package __global__ + +defaults: + - lm/default + - override /conditioner: text2music + - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly + +lm_model: transformer_lm + +codebooks_pattern: + modeling: delay + delay: + delays: [0, 1, 2, 3] + flatten_first: 0 + empty_initial: 0 + unroll: + flattening: [0, 1, 2, 3] + delays: [0, 0, 0, 0] + music_lm: + group_by: 2 + valle: + delays: [0, 0, 0] + +transformer_lm: + n_q: 4 + card: 2048 + memory_efficient: true + bias_proj: false + bias_ff: false + bias_attn: false + norm_first: true + layer_scale: null + weight_init: gaussian + depthwise_init: current + zero_bias_init: true + attention_as_float32: false diff --git a/config/model/none.yaml b/config/model/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d4169f468d462c794ee6ed25017c3d78ae45d06 --- /dev/null +++ b/config/model/none.yaml @@ -0,0 +1,4 @@ +# @package __global__ + +# This file exist so that model is recognized as a config group +# by Hydra, and Dora. A bit weird we might need a better fix someday. 
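The `model_scale` group above is the intended knob for changing transformer size; both `audiogen_lm` and `musicgen_lm` note that this group should be preferred over setting `transformer_lm` keys directly. As a hedged illustration of that workflow, reusing pieces of the debug invocation from the Makefile earlier in this diff (the dora directory is a placeholder and the command is a sketch, not a documented recipe):

```shell
# Illustrative only: pick the transformer scale through the Hydra group instead
# of editing transformer_lm.dim / num_heads / num_layers by hand. The dora dir
# is a placeholder; the compression checkpoint sig comes from the Makefile's
# compression debug target.
AUDIOCRAFT_DORA_DIR="/tmp/audiocraft_debug" python3 -m dora -v run --clear device=cpu \
    solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
    model/lm/model_scale=small
```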
diff --git a/config/model/score/basic.yaml b/config/model/score/basic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75fbc3783942602beaddaa38d0aca977aeee2dda --- /dev/null +++ b/config/model/score/basic.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +diffusion_unet: + hidden: 48 + depth: 4 + res_blocks: 1 + norm_groups: 4 + kernel: 8 + stride: 4 + growth: 4 + max_channels: 10_000 + dropout: 0. + emb_all_layers: true + bilstm: false + codec_dim: null + transformer: false + cross_attention: false \ No newline at end of file diff --git a/config/solver/audiogen/audiogen_base_16khz.yaml b/config/solver/audiogen/audiogen_base_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd6aee785c74db19ce9d6f488e68e6eeb471c026 --- /dev/null +++ b/config/solver/audiogen/audiogen_base_16khz.yaml @@ -0,0 +1,70 @@ +# @package __global__ + +# This is the training loop solver +# for the base AudioGen model (text-to-sound) +# on monophonic audio sampled at 16 kHz +# using a similar EnCodec+LM setup to MusicGen +defaults: + - audiogen/default + - /model: lm/audiogen_lm + - override /dset: audio/default + - _self_ + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 16khz +# with a total stride of 320 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //reference/bd44a852/checkpoint.th + +channels: 1 +sample_rate: 16000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128) + num_workers: 10 + segment_duration: 10 + min_segment_ratio: 1.0 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + external_metadata_source: null + # sample mixing augmentation at train time + train: + batch_size: 256 # matching AudioGen paper setup + aug_p: 0.5 # perform audio mixing 50% of the time + mix_p: 0.5 # proportion of batch items mixed together + # important: note that this will reduce the + # actual batch size used at train time + # which will be equal to mix_p * batch_size + mix_snr_low: -5 + mix_snr_high: 5 + mix_min_overlap: 0.5 + +generate: + lm: + use_sampling: true + top_k: 250 + top_p: 0.0 + +optim: + epochs: 100 + optimizer: adamw + lr: 5e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: inverse_sqrt + inverse_sqrt: + warmup: 3000 + warmup_init_lr: 0.0 diff --git a/config/solver/audiogen/debug.yaml b/config/solver/audiogen/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fbda8281c6d552d9445e04fee498641a26549aa5 --- /dev/null +++ b/config/solver/audiogen/debug.yaml @@ -0,0 +1,52 @@ +# @package __global__ + +# This is a minimal debugging configuration +# for MusicGen training solver +defaults: + - audiogen/default + - /model: lm/audiogen_lm + - override /model/lm/model_scale: xsmall + - override /dset: audio/example + - _self_ + +autocast: false +compression_model_checkpoint: null + +codebooks_pattern: + modeling: parallel + +channels: 1 +sample_rate: 16000 + +deadlock: + use: false # deadlock detection + +dataset: + batch_size: 4 + segment_duration: 5 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + audio: + strategy: peak + lm: + use_sampling: false + top_k: 0 + top_p: 0.0 + +checkpoint: + 
save_every: 0 + keep_last: 0 + +optim: + epochs: 2 + updates_per_epoch: 10 + optimizer: adamw + lr: 1e-4 + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: null diff --git a/config/solver/audiogen/default.yaml b/config/solver/audiogen/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afee63c65e0dd7350e3e89d2133bbca221d17631 --- /dev/null +++ b/config/solver/audiogen/default.yaml @@ -0,0 +1,40 @@ +# @package __global__ + +defaults: + - /solver/musicgen/default + - _self_ + - /solver/audiogen/evaluation: none + - override /dset: audio/default + +# See config/solver/musicgen/default.yaml for a list of possible values. +# We only keep the most important here. + +autocast: true +autocast_dtype: float16 + +solver: audiogen +sample_rate: ??? +channels: ??? +compression_model_checkpoint: ??? + +tokens: + padding_with_special_token: false + +dataset: + batch_size: 128 + segment_duration: 10 + min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence. + +optim: + epochs: 100 + updates_per_epoch: 2000 + lr: 1e-4 + optimizer: adamw + max_norm: 1.0 + adam: + betas: [0.9, 0.95] + weight_decay: 0.1 + eps: 1e-8 + +schedule: + lr_scheduler: null diff --git a/config/solver/audiogen/evaluation/none.yaml b/config/solver/audiogen/evaluation/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e739995ed6488700527529862a7a24f1afdcc7a --- /dev/null +++ b/config/solver/audiogen/evaluation/none.yaml @@ -0,0 +1,5 @@ +# @package __global__ + +dataset: + evaluate: + num_samples: 10000 diff --git a/config/solver/audiogen/evaluation/objective_eval.yaml b/config/solver/audiogen/evaluation/objective_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32fcc10033f3c3ff317216fe2876c65c6834e59b --- /dev/null +++ b/config/solver/audiogen/evaluation/objective_eval.yaml @@ -0,0 +1,29 @@ +# @package __global__ + +# Setup for execute only on audiocaps for audio generation +# evaluation with objective metrics +# execute_only=evaluate + +dataset: + max_audio_duration: null + # ensure the proper values are broadcasted here for evaluate + evaluate: + min_audio_duration: 1. # some metrics requires a minimum audio length + max_audio_duration: null # all samples from audiocaps should be ~10s + num_samples: null + segment_duration: null + generate: + min_audio_duration: 1. + max_audio_duration: null + num_samples: 500 + +evaluate: + metrics: + fad: true + kld: true + text_consistency: true + +metrics: + kld: + passt: + pretrained_length: 10 # similarly to reported results in AudioGen paper diff --git a/config/solver/compression/debug.yaml b/config/solver/compression/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..54dac175278d4ff509b0e44905d6b6195441f2c6 --- /dev/null +++ b/config/solver/compression/debug.yaml @@ -0,0 +1,55 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_base_causal + - override /dset: audio/example + - _self_ + +channels: 1 +sample_rate: 16000 + +# debug config uses just L1 +losses: + adv: 0. + feat: 0. + l1: 1. + mel: 0. + msspec: 0. +# no balancer +balancer: + balance_grads: false + ema_decay: 1. + total_norm: 1. 
+ per_batch_item: false +# no adversaries +adversarial: + adversaries: [] + adv_loss: hinge + feat_loss: l1 + +# faster model for local dev +seanet: + dimension: 16 + n_filters: 4 + +# very small dataset +dataset: + batch_size: 8 + num_workers: 10 + num_samples: 100 + segment_duration: 1 + evaluate: + batch_size: 32 + generate: + batch_size: 1 + num_samples: 5 + segment_duration: 10 + +# limited training +evaluate: + every: 5 +generate: + every: 5 +optim: + epochs: 50 diff --git a/config/solver/compression/default.yaml b/config/solver/compression/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41c812ba9ff8afe7ee10302ad5b9f05b745877d9 --- /dev/null +++ b/config/solver/compression/default.yaml @@ -0,0 +1,160 @@ +# @package __global__ + +defaults: + - ../default + - override /dset: audio/default + - _self_ + +solver: compression +sample_rate: ??? +channels: ??? + +# loss balancing +losses: + adv: 4. + feat: 4. + l1: 0.1 + mel: 0. + msspec: 2. + sisnr: 0. +balancer: + balance_grads: true + ema_decay: 0.999 + per_batch_item: true + total_norm: 1. + +adversarial: + every: 1 + adversaries: [msstftd] + adv_loss: hinge + feat_loss: l1 + +# losses hyperparameters +l1: {} +l2: {} +mrstft: + factor_sc: .5 + factor_mag: .5 + normalized: false +mel: + sample_rate: ${sample_rate} + n_fft: 1024 + hop_length: 256 + win_length: 1024 + n_mels: 64 + f_min: 64 + f_max: null + normalized: false + floor_level: 1e-5 +sisnr: + sample_rate: ${sample_rate} + segment: 5. +msspec: + sample_rate: ${sample_rate} + range_start: 6 + range_end: 11 + n_mels: 64 + f_min: 64 + f_max: null + normalized: true + alphas: false + floor_level: 1e-5 + +# metrics +metrics: + visqol: + mode: audio + bin: null # path to visqol install + model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 + +# adversaries hyperparameters +msstftd: + in_channels: 1 + out_channels: 1 + filters: 32 + norm: weight_norm + n_ffts: [1024, 2048, 512, 256, 128] + hop_lengths: [256, 512, 128, 64, 32] + win_lengths: [1024, 2048, 512, 256, 128] + activation: LeakyReLU + activation_params: {negative_slope: 0.3} +msd: + in_channels: 1 + out_channels: 1 + scale_norms: [spectral_norm, weight_norm, weight_norm] + kernel_sizes: [5, 3] + filters: 16 + max_filters: 1024 + downsample_scales: [4, 4, 4, 4] + inner_kernel_sizes: null + groups: [4, 4, 4, 4] + strides: null + paddings: null + activation: LeakyReLU + activation_params: {negative_slope: 0.3} +mpd: + in_channels: 1 + out_channels: 1 + periods: [2, 3, 5, 7, 11] + n_layers: 5 + kernel_size: 5 + stride: 3 + filters: 8 + filter_scales: 4 + max_filters: 1024 + activation: LeakyReLU + activation_params: {negative_slope: 0.3} + norm: weight_norm + +# data hyperparameters +dataset: + batch_size: 64 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 500000 + valid: + num_samples: 10000 + evaluate: + batch_size: 32 + num_samples: 10000 + generate: + batch_size: 32 + num_samples: 50 + segment_duration: 10 + +# solver hyperparameters +evaluate: + every: 25 + num_workers: 5 + metrics: + visqol: false + sisnr: true +generate: + every: 25 + num_workers: 5 + audio: + sample_rate: ${sample_rate} + +# checkpointing schedule +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + +# optimization hyperparameters +optim: + epochs: 200 + updates_per_epoch: 2000 + lr: 3e-4 + max_norm: 0. + optimizer: adam + adam: + betas: [0.5, 0.9] + weight_decay: 0. 
+ ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used diff --git a/config/solver/compression/encodec_audiogen_16khz.yaml b/config/solver/compression/encodec_audiogen_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..654deaa01ba9cace3f7144cc91921791c081b32a --- /dev/null +++ b/config/solver/compression/encodec_audiogen_16khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_large_nq4_s320 + - override /dset: audio/default + - _self_ + +channels: 1 +sample_rate: 16000 diff --git a/config/solver/compression/encodec_base_24khz.yaml b/config/solver/compression/encodec_base_24khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..018ad1cd61af84b616ad3088f055e8eaa36729eb --- /dev/null +++ b/config/solver/compression/encodec_base_24khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_base_causal + - override /dset: audio/default + - _self_ + +channels: 1 +sample_rate: 24000 diff --git a/config/solver/compression/encodec_musicgen_32khz.yaml b/config/solver/compression/encodec_musicgen_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eca4b90fb221372dace164fe59bb15822207a980 --- /dev/null +++ b/config/solver/compression/encodec_musicgen_32khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_large_nq4_s640 + - override /dset: audio/default + - _self_ + +channels: 1 +sample_rate: 32000 diff --git a/config/solver/default.yaml b/config/solver/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7452ea1e415516dceaaae86d692cbb8c811bd57 --- /dev/null +++ b/config/solver/default.yaml @@ -0,0 +1,108 @@ +# @package __global__ + +# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft +# Please don't update this file directly. Instead use distinct configuration files +# to override the below configuration. +solver: ??? + +fsdp: + use: false # should we use FSDP. + param_dtype: float16 # equivalent to autocast_dtype for FSDP. + reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability. + buffer_dtype: float32 # dtype used for buffers, we don't have much buffers, so let's leave it. + sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard. + # full_shard will use less memory but slower ?? + per_block: true # If True, uses nested FSDP. + +profiler: + enabled: false + +deadlock: + use: false + timeout: 600 + +dataset: + batch_size: ??? + num_workers: 10 + segment_duration: null + num_samples: null + return_info: false + shuffle: false + sample_on_duration: true + sample_on_weight: true + min_segment_ratio: 0.5 + train: + num_samples: null + shuffle: true + shuffle_seed: 0 # if you want to sample the data differently. 
+ permutation_on_files: false + valid: + num_samples: null + evaluate: + num_samples: null + generate: + num_samples: null + return_info: true + +checkpoint: + save_last: true + save_every: null + keep_last: null + keep_every_states: null + +generate: + every: null + path: 'samples' + audio: + format: 'mp3' + strategy: 'clip' + sample_rate: null + lm: + use_sampling: false + temp: 1.0 + top_k: 0 + top_p: 0.0 +evaluate: + every: null + num_workers: 5 + truncate_audio: null + fixed_generation_duration: null # in secs + metrics: + base: true # run default evaluation (e.g. like train/valid stage) + +optim: + epochs: ??? + updates_per_epoch: null + lr: ??? + optimizer: ??? + adam: + betas: [0.9, 0.999] + weight_decay: 0. + ema: + use: false # whether to use EMA or not + updates: ${optim.updates_per_epoch} # frequency of updates of the EMA + device: cpu # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + +schedule: + lr_scheduler: null + step: + step_size: null + gamma: null + exponential: + lr_decay: null + cosine: + warmup: null + lr_min_ratio: 0.0 + cycle_length: 1.0 + polynomial_decay: + warmup: null + zero_lr_warmup_steps: 0 + end_lr: 0.0 + power: 1 + inverse_sqrt: + warmup: null + warmup_init_lr: 0.0 + linear_warmup: + warmup: null + warmup_init_lr: 0.0 diff --git a/config/solver/diffusion/debug.yaml b/config/solver/diffusion/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc27c53486f7215a080d167032972402b90f5c77 --- /dev/null +++ b/config/solver/diffusion/debug.yaml @@ -0,0 +1,106 @@ +# @package __global__ + +defaults: + - /solver/default + - /model: score/basic + - override /dset: audio/default + - _self_ + +solver: diffusion + +sample_rate: 16000 +channels: 1 +compression_model_checkpoint: //sig/5091833e +n_q: 2 # number of codebooks to keep + +dataset: + batch_size: 8 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 100 + valid: + num_samples: 100 + evaluate: + batch_size: 8 + num_samples: 10 + generate: + batch_size: 8 + num_samples: 10 + segment_duration: 10 + +loss: + kind: mse + norm_power: 0. + +valid: + every: 1 + +evaluate: + every: 5 + num_workers: 5 + metrics: + visqol: false + sisnr: false + rvm: true + +generate: + every: 5 + num_workers: 5 + audio: + sample_rate: ${sample_rate} + +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + + +optim: + epochs: 50 + updates_per_epoch: 2000 + lr: 2e-4 + max_norm: 0 + optimizer: adam + adam: + betas: [0.9, 0.999] + weight_decay: 0. + ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + +processor: + name: multi_band_processor + use: false + n_bands: 8 + num_samples: 10_000 + power_std: 1. + +resampling: + use: false + target_sr: 16000 + +filter: + use: false + n_bands: 4 + idx_band: 0 + cutoffs: null + +schedule: + repartition: "power" + variable_step_batch: true + beta_t0: 1.0e-5 + beta_t1: 2.9e-2 + beta_exp: 7.5 + num_steps: 1000 + variance: 'beta' + clip: 5. + rescale: 1. 
+ n_bands: null + noise_scale: 1.0 + +metrics: + num_stage: 4 diff --git a/config/solver/diffusion/default.yaml b/config/solver/diffusion/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3793d4d08d912db575c022a6803a8909c2b25273 --- /dev/null +++ b/config/solver/diffusion/default.yaml @@ -0,0 +1,107 @@ +# @package __global__ + +defaults: + - /solver/default + - /model: score/basic + - override /dset: audio/default + - _self_ + +solver: diffusion + +sample_rate: ??? +channels: ??? +compression_model_checkpoint: ??? +n_q: ??? # number of codebooks to keep + + +dataset: + batch_size: 128 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 500000 + valid: + num_samples: 10000 + evaluate: + batch_size: 16 + num_samples: 10000 + generate: + batch_size: 32 + num_samples: 50 + segment_duration: 10 + audio: + sample_rate: ${sample_rate} + +loss: + kind: mse + norm_power: 0. + +valid: + every: 1 + +evaluate: + every: 20 + num_workers: 5 + metrics: + visqol: false + sisnr: false + rvm: true + +generate: + every: 25 + num_workers: 5 + +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + + +optim: + epochs: 20000 + updates_per_epoch: 2000 + lr: 2e-4 + max_norm: 0 + optimizer: adam + adam: + betas: [0.9, 0.999] + weight_decay: 0. + ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + +processor: + name: multi_band_processor + use: false + n_bands: 8 + num_samples: 10_000 + power_std: 1. + +resampling: + use: false + target_sr: 16000 + +filter: + use: false + n_bands: 4 + idx_band: 0 + cutoffs: null + +schedule: + repartition: "power" + variable_step_batch: true + beta_t0: 1.0e-5 + beta_t1: 2.9e-2 + beta_exp: 7.5 + num_steps: 1000 + variance: 'beta' + clip: 5. + rescale: 1. 
+ n_bands: null + noise_scale: 1.0 + +metrics: + num_stage: 4 diff --git a/config/solver/diffusion/encodec_24khz.yaml b/config/solver/diffusion/encodec_24khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..774e88f43d54980daef0c68d11717ddb7a214db1 --- /dev/null +++ b/config/solver/diffusion/encodec_24khz.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +defaults: + - diffusion/default + - _self_ + + +sample_rate: 24000 +channels: 1 +compression_model_checkpoint: //pretrained/facebook/encodec_24khz +n_q: 4 # num quantizers, 3kbps diff --git a/config/solver/musicgen/debug.yaml b/config/solver/musicgen/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec658f9d2fb0262cc8eab19d0cf333963c646a98 --- /dev/null +++ b/config/solver/musicgen/debug.yaml @@ -0,0 +1,55 @@ +# @package __global__ + +# This is a minimal debugging configuration +# for MusicGen training solver +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /model/lm/model_scale: xsmall + - override /dset: audio/example + - _self_ + +autocast: false +compression_model_checkpoint: //pretrained/debug_compression_model +transformer_lm: + n_q: 4 + card: 400 + +codebooks_pattern: + modeling: parallel + +channels: 1 +sample_rate: 32000 + +deadlock: + use: false # deadlock detection + +dataset: + batch_size: 4 + segment_duration: 5 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + audio: + strategy: peak + lm: + use_sampling: false + top_k: 0 + top_p: 0.0 + +checkpoint: + save_every: 0 + keep_last: 0 + +optim: + epochs: 2 + updates_per_epoch: 10 + optimizer: adamw + lr: 1e-4 + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: null diff --git a/config/solver/musicgen/default.yaml b/config/solver/musicgen/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59e011376fb2b909fe599bc86bf0ef4029ce5d6e --- /dev/null +++ b/config/solver/musicgen/default.yaml @@ -0,0 +1,119 @@ +# @package __global__ + +defaults: + - /solver/default + - /conditioner: none + - _self_ + - /solver/musicgen/evaluation: none + - override /dset: audio/default + +autocast: true +autocast_dtype: float16 + +solver: musicgen +sample_rate: ??? +channels: ??? +compression_model_checkpoint: ??? + +tokens: + padding_with_special_token: false + +cache: + path: + write: false + write_shard: 0 + write_num_shards: 1 + + +dataset: + batch_size: 128 + num_workers: 10 + segment_duration: 30 + min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence. 
+ return_info: true + train: + num_samples: 1000000 # need a randomly large number here for AudioDataset + valid: + num_samples: 10000 + generate: + num_samples: 50 + +metrics: + fad: + use_gt: false + model: tf + tf: + bin: null # path to local frechet_audio_distance code + model_path: //reference/fad/vggish_model.ckpt + kld: + use_gt: false + model: passt + passt: + pretrained_length: 20 + text_consistency: + use_gt: false + model: clap + clap: + model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt + model_arch: 'HTSAT-base' + enable_fusion: false + chroma_cosine: + use_gt: false + model: chroma_base + chroma_base: + sample_rate: ${sample_rate} + n_chroma: 12 + radix2_exp: 14 + argmax: true + +generate: + every: 25 + num_workers: 5 + path: samples + audio: + format: wav + strategy: loudness + sample_rate: ${sample_rate} + loudness_headroom_db: 14 + lm: + prompted_samples: true + unprompted_samples: true + gen_gt_samples: false + prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4 + gen_duration: null # if not set, will use dataset.generate.segment_duration + remove_prompts: false + # generation params + use_sampling: false + temp: 1.0 + top_k: 0 + top_p: 0.0 +evaluate: + every: 25 + num_workers: 5 + metrics: + base: false + fad: false + kld: false + text_consistency: false + chroma_cosine: false + +checkpoint: + save_last: true + save_every: 50 + keep_last: 10 + keep_every_states: null + +optim: + epochs: 200 + updates_per_epoch: 2000 + lr: 1e-4 + optimizer: adamw + max_norm: 1.0 + eager_sync: true + adam: + betas: [0.9, 0.95] + weight_decay: 0.1 + eps: 1e-8 + +schedule: + lr_scheduler: null diff --git a/config/solver/musicgen/evaluation/none.yaml b/config/solver/musicgen/evaluation/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e739995ed6488700527529862a7a24f1afdcc7a --- /dev/null +++ b/config/solver/musicgen/evaluation/none.yaml @@ -0,0 +1,5 @@ +# @package __global__ + +dataset: + evaluate: + num_samples: 10000 diff --git a/config/solver/musicgen/evaluation/objective_eval.yaml b/config/solver/musicgen/evaluation/objective_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4881e9d86cddf36b306a75fb498253e1e12ec5be --- /dev/null +++ b/config/solver/musicgen/evaluation/objective_eval.yaml @@ -0,0 +1,24 @@ +# @package __global__ + +# Setup for execute only on musiccaps for audio generation +# evaluation with objective metrics +# execute_only=evaluate + +dataset: + max_audio_duration: null + # ensure the proper values are broadcasted here for evaluate + evaluate: + min_audio_duration: 1. # some metrics requires a minimum audio length + max_audio_duration: null # all samples from musiccaps should be < 20s + num_samples: null + segment_duration: null + generate: + min_audio_duration: 1. 
+ max_audio_duration: null + num_samples: 500 + +evaluate: + metrics: + fad: true + kld: true + text_consistency: true diff --git a/config/solver/musicgen/musicgen_base_32khz.yaml b/config/solver/musicgen/musicgen_base_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b32c9c898a70718f91af862caa79f5553a5107e1 --- /dev/null +++ b/config/solver/musicgen/musicgen_base_32khz.yaml @@ -0,0 +1,55 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/default + - _self_ + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 192 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + lm: + use_sampling: true + top_k: 250 + top_p: 0.0 + +optim: + epochs: 500 + optimizer: dadam + lr: 1 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 diff --git a/config/solver/musicgen/musicgen_melody_32khz.yaml b/config/solver/musicgen/musicgen_melody_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ad3e0aeeb9583887d6e8ecd6d32a3dc69e102ed --- /dev/null +++ b/config/solver/musicgen/musicgen_melody_32khz.yaml @@ -0,0 +1,56 @@ +# @package __global__ + +# This is the training loop solver +# for the melody MusicGen model (text+chroma to music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /conditioner: chroma2music + - override /dset: audio/default + - _self_ + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 192 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + lm: + use_sampling: true + top_k: 250 + top_p: 0.0 + +optim: + epochs: 500 + optimizer: dadam + lr: 1 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 diff --git a/config/teams/default.yaml b/config/teams/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..407066df1e154208af2823a6e46d16df381c5d42 --- /dev/null +++ b/config/teams/default.yaml @@ -0,0 +1,12 @@ +default: + dora_dir: /tmp/audiocraft_${oc.env:USER} + partitions: + global: debug + team: debug + reference_dir: /tmp +darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc. 
+ dora_dir: /tmp/audiocraft_${oc.env:USER} + partitions: + global: debug + team: debug + reference_dir: /tmp diff --git a/config/teams/labs.yaml b/config/teams/labs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..da350a94bc5758531ced5d9e4332624fe86f3d57 --- /dev/null +++ b/config/teams/labs.yaml @@ -0,0 +1,28 @@ +aws: + dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs + partitions: + global: learnlab + team: learnlab + reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference + dataset_mappers: + "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm" +fair: + dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs + partitions: + global: learnlab + team: learnlab + reference_dir: /large_experiments/audiocraft/reference + dataset_mappers: + "^/datasets01/datasets01": "/datasets01" +darwin: + dora_dir: /tmp/audiocraft_${oc.env:USER} + partitions: + global: debug + team: debug + reference_dir: /tmp +rsc: + dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs + partitions: + global: learn + team: learn + reference_dir: /checkpoint/audiocraft/shared/reference diff --git a/dataset/example/electro_1.json b/dataset/example/electro_1.json new file mode 100644 index 0000000000000000000000000000000000000000..eeffc95038a1e031fad5598f822ddf2538d7f4da --- /dev/null +++ b/dataset/example/electro_1.json @@ -0,0 +1 @@ +{"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]} diff --git a/dataset/example/electro_1.mp3 b/dataset/example/electro_1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..8fa509266df4ee76519b82bfbea247cb0b18bcda Binary files /dev/null and b/dataset/example/electro_1.mp3 differ diff --git a/dataset/example/electro_2.json b/dataset/example/electro_2.json new file mode 100644 index 0000000000000000000000000000000000000000..3ee91c89c1d4b603f3e4d3fcc029618dc110e730 --- /dev/null +++ b/dataset/example/electro_2.json @@ -0,0 +1 @@ +{"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": "", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []} diff --git a/dataset/example/electro_2.mp3 b/dataset/example/electro_2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..01ab323e4322d08546635861959b868c3d7b416b Binary files /dev/null and b/dataset/example/electro_2.mp3 differ diff --git a/docs/AUDIOGEN.md b/docs/AUDIOGEN.md new file mode 100644 index 0000000000000000000000000000000000000000..a0ff481190fb52fe865aa66aaaa10176f7cf995c --- /dev/null +++ b/docs/AUDIOGEN.md @@ -0,0 +1,158 @@ +# AudioGen: Textually-guided audio generation + +AudioCraft provides the code and a model re-implementing AudioGen, a [textually-guided audio generation][audiogen_arxiv] +model that performs text-to-sound generation. + +The provided AudioGen reimplementation follows the LM model architecture introduced in [MusicGen][musicgen_arxiv] +and is a single stage auto-regressive Transformer model trained over a 16kHz +EnCodec tokenizer with 4 codebooks sampled at 50 Hz. 
+This model variant reaches similar audio quality than the original implementation introduced in the AudioGen publication +while providing faster generation speed given the smaller frame rate. + +**Important note:** The provided models are NOT the original models used to report numbers in the +[AudioGen publication][audiogen_arxiv]. Refer to the model card to learn more about architectural changes. + +Listen to samples from the **original AudioGen implementation** in our [sample page][audiogen_samples]. + + +## Model Card + +See [the model card](../model_cards/AUDIOGEN_MODEL_CARD.md). + + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + +AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters). + +## API and usage + +We provide a simple API and 1 pre-trained models for AudioGen: + +`facebook/audiogen-medium`: 1.5B model, text to sound - [🤗 Hub](https://huggingface.co/facebook/audiogen-medium) + +You can play with AudioGen by running the jupyter notebook at [`demos/audiogen_demo.ipynb`](../demos/audiogen_demo.ipynb) locally (if you have a GPU). + +See after a quick example for using the API. + +```python +import torchaudio +from audiocraft.models import AudioGen +from audiocraft.data.audio import audio_write + +model = AudioGen.get_pretrained('facebook/audiogen-medium') +model.set_generation_params(duration=5) # generate 5 seconds. +descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor'] +wav = model.generate(descriptions) # generates 3 samples. + +for idx, one_wav in enumerate(wav): + # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. + audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) +``` + +## Training + +The [AudioGenSolver](../audiocraft/solvers/audiogen.py) implements the AudioGen's training pipeline +used to develop the released model. Note that this may not fully reproduce the results presented in the paper. +Similarly to MusicGen, it defines an autoregressive language modeling task over multiple streams of +discrete tokens extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) +for more details on how to train such model) with dataset-specific changes for environmental sound +processing. + +Note that **we do NOT provide any of the datasets** used for training AudioGen. + +### Example configurations and grids + +We provide configurations to reproduce the released models and our research. +AudioGen solvers configuration are available in [config/solver/audiogen](../config/solver/audiogen). +The base training configuration used for the released models is the following: +[`solver=audiogen/audiogen_base_16khz`](../config/solver/audiogen/audiogen_base_16khz.yaml) + +Please find some example grids to train AudioGen at +[audiocraft/grids/audiogen](../audiocraft/grids/audiogen/). + +```shell +# text-to-sound +dora grid audiogen.audiogen_base_16khz +``` + +### Sound dataset and metadata + +AudioGen's underlying dataset is an AudioDataset augmented with description metadata. +The AudioGen dataset implementation expects the metadata to be available as `.json` files +at the same location as the audio files or through specified external folder. +Learn more in the [datasets section](./DATASETS.md). 
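
For reference, such a sidecar metadata file can be produced with the standard library alone. The snippet below is only a minimal sketch: the audio path is hypothetical, and only the `description` field mirrors the keys shown in the provided example metadata, so treat the exact schema consumed by the SoundDataset as an assumption rather than a definitive specification.

```python
import json
from pathlib import Path

# Hypothetical audio file for illustration; the sidecar JSON must share the
# audio file's name (with a .json extension) and live next to it, or inside
# the folder pointed to by dataset.external_metadata_source.
audio_path = Path("dataset/my_sounds/dog_bark_001.wav")

metadata = {
    # "description" follows the key used in the provided example metadata;
    # any additional fields required by the dataset are not shown here.
    "description": "a dog barking repeatedly in a quiet park",
}

json_path = audio_path.with_suffix(".json")
json_path.parent.mkdir(parents=True, exist_ok=True)
json_path.write_text(json.dumps(metadata))
print(f"Wrote {json_path}")
```

Each audio file referenced by the training manifest is expected to have such a companion description file when text conditioning is used.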
+ +### Evaluation stage + +By default, evaluation stage is also computing the cross-entropy and the perplexity over the +evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run +or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md) +for more details on the requirements for each metric. + +We provide an off-the-shelf configuration to enable running the objective metrics +for audio generation in +[config/solver/audiogen/evaluation/objective_eval](../config/solver/audiogen/evaluation/objective_eval.yaml). + +One can then activate evaluation the following way: +```shell +# using the configuration +dora run solver=audiogen/debug solver/audiogen/evaluation=objective_eval +# specifying each of the fields, e.g. to activate KL computation +dora run solver=audiogen/debug evaluate.metrics.kld=true +``` + +See [an example evaluation grid](../audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py). + +### Generation stage + +The generation stage allows to generate samples conditionally and/or unconditionally and to perform +audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling +from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples +generated and the batch size used are controlled by the `dataset.generate` configuration +while the other generation parameters are defined in `generate.lm`. + +```shell +# control sampling parameters +dora run solver=audiogen/debug generate.lm.gen_duration=5 generate.lm.use_sampling=true generate.lm.top_k=15 +``` + +## More information + +Refer to [MusicGen's instructions](./MUSICGEN.md). + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation + +AudioGen +``` +@article{kreuk2022audiogen, + title={Audiogen: Textually guided audio generation}, + author={Kreuk, Felix and Synnaeve, Gabriel and Polyak, Adam and Singer, Uriel and D{\'e}fossez, Alexandre and Copet, Jade and Parikh, Devi and Taigman, Yaniv and Adi, Yossi}, + journal={arXiv preprint arXiv:2209.15352}, + year={2022} +} +``` + +MusicGen +``` +@article{copet2023simple, + title={Simple and Controllable Music Generation}, + author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, + year={2023}, + journal={arXiv preprint arXiv:2306.05284}, +} +``` + +## License + +See license information in the [model card](../model_cards/AUDIOGEN_MODEL_CARD.md). + +[audiogen_arxiv]: https://arxiv.org/abs/2209.15352 +[musicgen_arxiv]: https://arxiv.org/abs/2306.05284 +[audiogen_samples]: https://felixkreuk.github.io/audiogen/ diff --git a/docs/CONDITIONING.md b/docs/CONDITIONING.md new file mode 100644 index 0000000000000000000000000000000000000000..6e356cb8e9912d3e18fc84598c1acf77c6e7abc5 --- /dev/null +++ b/docs/CONDITIONING.md @@ -0,0 +1,146 @@ +# AudioCraft conditioning modules + +AudioCraft provides a +[modular implementation of conditioning modules](../audiocraft/modules/conditioners.py) +that can be used with the language model to condition the generation. +The codebase was developed in order to easily extend the set of modules +currently supported to easily develop new ways of controlling the generation. 
+ + +## Conditioning methods + +For now, we support 3 main types of conditioning within AudioCraft: +* Text-based conditioning methods +* Waveform-based conditioning methods +* Joint embedding conditioning methods for text and audio projected in a shared latent space. + +The Language Model relies on 2 core components that handle processing information: +* The `ConditionProvider` class, that maps metadata to processed conditions leveraging +all the defined conditioners for the given task. +* The `ConditionFuser` class, that takes preprocessed conditions and properly fuse the +conditioning embedding to the language model inputs following a given fusing strategy. + +Different conditioners (for text, waveform, joint embeddings...) are provided as torch +modules in AudioCraft and are used internally in the language model to process the +conditioning signals and feed them to the language model. + + +## Core concepts + +### Conditioners + +The `BaseConditioner` torch module is the base implementation for all conditioners in audiocraft. + +Each conditioner is expected to implement 2 methods: +* The `tokenize` method that is used as a preprocessing method that contains all processing +that can lead to synchronization points (e.g. BPE tokenization with transfer to the GPU). +The output of the tokenize method will then be used to feed the forward method. +* The `forward` method that takes the output of the tokenize method and contains the core computation +to obtain the conditioning embedding along with a mask indicating valid indices (e.g. padding tokens). + +### ConditionProvider + +The ConditionProvider prepares and provides conditions given a dictionary of conditioners. + +Conditioners are specified as a dictionary of attributes and the corresponding conditioner +providing the processing logic for the given attribute. + +Similarly to the conditioners, the condition provider works in two steps to avoid sychronization points: +* A `tokenize` method that takes a list of conditioning attributes for the batch, +and run all tokenize steps for the set of conditioners. +* A `forward` method that takes the output of the tokenize step and run all the forward steps +for the set of conditioners. + +The list of conditioning attributes is passed as a list of `ConditioningAttributes` +that is presented just below. + +### ConditionFuser + +Once all conditioning signals have been extracted and processed by the `ConditionProvider` +as dense embeddings, they remain to be passed to the language model along with the original +language model inputs. + +The `ConditionFuser` handles specifically the logic to combine the different conditions +to the actual model input, supporting different strategies to combine them. + +One can therefore define different strategies to combine or fuse the condition to the input, in particular: +* Prepending the conditioning signal to the input with the `prepend` strategy, +* Summing the conditioning signal to the input with the `sum` strategy, +* Combining the conditioning relying on a cross-attention mechanism with the `cross` strategy, +* Using input interpolation with the `input_interpolate` strategy. + +### SegmentWithAttributes and ConditioningAttributes: From metadata to conditions + +The `ConditioningAttributes` dataclass is the base class for metadata +containing all attributes used for conditioning the language model. + +It currently supports the following types of attributes: +* Text conditioning attributes: Dictionary of textual attributes used for text-conditioning. 
+* Wav conditioning attributes: Dictionary of waveform attributes used for waveform-based +conditioning such as the chroma conditioning. +* JointEmbed conditioning attributes: Dictionary of text and waveform attributes +that are expected to be represented in a shared latent space. + +These different types of attributes are the attributes that are processed +by the different conditioners. + +`ConditioningAttributes` are extracted from metadata loaded along the audio in the datasets, +provided that the metadata used by the dataset implements the `SegmentWithAttributes` abstraction. + +All metadata-enabled datasets to use for conditioning in AudioCraft inherits +the [`audiocraft.data.info_dataset.InfoAudioDataset`](../audiocraft/data/info_audio_dataset.py) class +and the corresponding metadata inherits and implements the `SegmentWithAttributes` abstraction. +Refer to the [`audiocraft.data.music_dataset.MusicAudioDataset`](../audiocraft/data/music_dataset.py) +class as an example. + + +## Available conditioners + +### Text conditioners + +All text conditioners are expected to inherit from the `TextConditioner` class. + +AudioCraft currently provides two text conditioners: +* The `LUTConditioner` that relies on look-up-table of embeddings learned at train time, +and relying on either no tokenizer or a spacy tokenizer. This conditioner is particularly +useful for simple experiments and categorical labels. +* The `T5Conditioner` that relies on a +[pre-trained T5 model](https://huggingface.co/docs/transformers/model_doc/t5) +frozen or fine-tuned at train time to extract the text embeddings. + +### Waveform conditioners + +All waveform conditioners are expected to inherit from the `WaveformConditioner` class and +consists of conditioning method that takes a waveform as input. The waveform conditioner +must implement the logic to extract the embedding from the waveform and define the downsampling +factor from the waveform to the resulting embedding. + +The `ChromaStemConditioner` conditioner is a waveform conditioner for the chroma features +conditioning used by MusicGen. It takes a given waveform, extract relevant stems for melody +(namely all non drums and bass stems) using a +[pre-trained Demucs model](https://github.com/facebookresearch/demucs) +and then extract the chromagram bins from the remaining mix of stems. + +### Joint embeddings conditioners + +We finally provide support for conditioning based on joint text and audio embeddings through +the `JointEmbeddingConditioner` class and the `CLAPEmbeddingConditioner` that implements such +a conditioning method relying on a [pretrained CLAP model](https://github.com/LAION-AI/CLAP). + +## Classifier Free Guidance + +We provide a Classifier Free Guidance implementation in AudioCraft. With the classifier free +guidance dropout, all attributes are dropped with the same probability. + +## Attribute Dropout + +We further provide an attribute dropout strategy. Unlike the classifier free guidance dropout, +the attribute dropout drops given attributes with a defined probability, allowing the model +not to expect all conditioning signals to be provided at once. + +## Faster computation of conditions + +Conditioners that require some heavy computation on the waveform can be cached, in particular +the `ChromaStemConditioner` or `CLAPEmbeddingConditioner`. You just need to provide the +`cache_path` parameter to them. We recommend running dummy jobs for filling up the cache quickly. 
+An example is provied in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py). \ No newline at end of file diff --git a/docs/DATASETS.md b/docs/DATASETS.md new file mode 100644 index 0000000000000000000000000000000000000000..b0890c03cf732450eb498559638c6b45d50e40c3 --- /dev/null +++ b/docs/DATASETS.md @@ -0,0 +1,82 @@ +# AudioCraft datasets + +Our dataset manifest files consist in 1-json-per-line files, potentially gzipped, +as `data.jsons` or `data.jsons.gz` files. This JSON contains the path to the audio +file and associated metadata. The manifest files are then provided in the configuration, +as `datasource` sub-configuration. A datasource contains the pointers to the paths of +the manifest files for each AudioCraft stage (or split) along with additional information +(eg. maximum sample rate to use against this dataset). All the datasources are under the +`dset` group config, with a dedicated configuration file for each dataset. + +## Getting started + +### Example + +See the provided example in the directory that provides a manifest to use the example dataset +provided under the [dataset folder](../dataset/example). + +The manifest files are stored in the [egs folder](../egs/example). + +```shell +egs/ + example/data.json.gz +``` + +A datasource is defined in the configuration folder, in the dset group config for this dataset +at [config/dset/audio/example](../config/dset/audio/example.yaml): + +```shell +# @package __global__ + +datasource: + max_sample_rate: 44100 + max_channels: 2 + + train: egs/example + valid: egs/example + evaluate: egs/example + generate: egs/example +``` + +For proper dataset, one should create manifest for each of the splits and specify the correct path +to the given manifest in the datasource for each split. + +Then, using a dataset through the configuration can be done pointing to the +corresponding dataset configuration: +```shell +dset= # should match the yaml file name + +# for example +dset=audio/example +``` + +### Creating manifest files + +Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use +the following command to create new manifest files from a given folder containing audio files: + +```shell +python -m audiocraft.data.audio_dataset egs/my_dataset/my_dataset_split/data.jsonl.gz + +# For example to generate the manifest for dset=audio/example +# note: we don't use any split and we don't compress the jsonl file for this dummy example +python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl + +# More info with: python -m audiocraft.data.audio_dataset --help +``` + +## Additional information + +### MusicDataset and metadata + +The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects +the additional metadata to be stored in a JSON file that has the same path as the corresponding +audio file, but with a `.json` extension. + +### SoundDataset and metadata + +The SoundDataset is an AudioDataset with descriptions metadata. Similarly to the MusicDataset, +the SoundDataset expects the additional metadata to be stored in a JSON file that has the same +path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset +supports an additional parameter pointing to an extra folder `external_metadata_source` containing +all the JSON metadata files given they have the same filename as the audio file. 
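
Since the manifests are plain 1-json-per-line files (optionally gzipped), they are easy to inspect programmatically. The sketch below relies only on what is stated above, namely that each line is a JSON object describing one audio file; key names are printed as-is because the exact schema depends on the dataset, and the `egs/example/data.jsonl` path assumes the manifest created by the example command shown earlier.

```python
import gzip
import json
from pathlib import Path

def iter_manifest(manifest_path: str):
    """Yield one dict per line of a 1-json-per-line manifest (optionally gzipped)."""
    path = Path(manifest_path)
    opener = gzip.open if path.suffix == ".gz" else open
    with opener(path, "rt") as fh:
        for line in fh:
            line = line.strip()
            if line:
                yield json.loads(line)

# Peek at the first few entries of the example manifest.
for i, entry in enumerate(iter_manifest("egs/example/data.jsonl")):
    # Each entry describes one audio file (path plus metadata); we print it
    # without assuming specific key names.
    print(entry)
    if i >= 2:
        break
```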
diff --git a/docs/ENCODEC.md b/docs/ENCODEC.md new file mode 100644 index 0000000000000000000000000000000000000000..efc2bcc7ec50190b907c887b920b70fd799c6953 --- /dev/null +++ b/docs/ENCODEC.md @@ -0,0 +1,179 @@ +# EnCodec: High Fidelity Neural Audio Compression + +AudioCraft provides the training code for EnCodec, a state-of-the-art deep learning +based audio codec supporting both mono stereo audio, presented in the +[High Fidelity Neural Audio Compression][arxiv] paper. +Check out our [sample page][encodec_samples]. + +## Original EnCodec models + +The EnCodec models presented in High Fidelity Neural Audio Compression can be accessed +and used with the [EnCodec repository](https://github.com/facebookresearch/encodec). + +**Note**: We do not guarantee compatibility between the AudioCraft and EnCodec codebases +and released checkpoints at this stage. + + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + + +## Training + +The [CompressionSolver](../audiocraft/solvers/compression.py) implements the audio reconstruction +task to train an EnCodec model. Specifically, it trains an encoder-decoder with a quantization +bottleneck - a SEANet encoder-decoder with Residual Vector Quantization bottleneck for EnCodec - +using a combination of objective and perceptual losses in the forms of discriminators. + +The default configuration matches a causal EnCodec training with at a single bandwidth. + +### Example configuration and grids + +We provide sample configuration and grids for training EnCodec models. + +The compression configuration are defined in +[config/solver/compression](../config/solver/compression). + +The example grids are available at +[audiocraft/grids/compression](../audiocraft/grids/compression). + +```shell +# base causal encodec on monophonic audio sampled at 24 khz +dora grid compression.encodec_base_24khz +# encodec model used for MusicGen on monophonic audio sampled at 32 khz +dora grid compression.encodec_musicgen_32khz +``` + +### Training and valid stages + +The model is trained using a combination of objective and perceptual losses. +More specifically, EnCodec is trained with the MS-STFT discriminator along with +objective losses through the use of a loss balancer to effectively weight +the different losses, in an intuitive manner. + +### Evaluation stage + +Evaluations metrics for audio generation: +* SI-SNR: Scale-Invariant Signal-to-Noise Ratio. +* ViSQOL: Virtual Speech Quality Objective Listener. + +Note: Path to the ViSQOL binary (compiled with bazel) needs to be provided in +order to run the ViSQOL metric on the reference and degraded signals. +The metric is disabled by default. +Please refer to the [metrics documentation](../METRICS.md) to learn more. + +### Generation stage + +The generation stage consists in generating the reconstructed audio from samples +with the current model. The number of samples generated and the batch size used are +controlled by the `dataset.generate` configuration. The output path and audio formats +are defined in the generate stage configuration. 
+ +```shell +# generate samples every 5 epoch +dora run solver=compression/encodec_base_24khz generate.every=5 +# run with a different dset +dora run solver=compression/encodec_base_24khz generate.path= +# limit the number of samples or use a different batch size +dora grid solver=compression/encodec_base_24khz dataset.generate.num_samples=10 dataset.generate.batch_size=4 +``` + +### Playing with the model + +Once you have a model trained, it is possible to get the entire solver, or just +the trained model with the following functions: + +```python +from audiocraft.solvers import CompressionSolver + +# If you trained a custom model with signature SIG. +model = CompressionSolver.model_from_checkpoint('//sig/SIG') +# If you want to get one of the pretrained models with the `//pretrained/` prefix. +model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz') +# Or load from a custom checkpoint path +model = CompressionSolver.model_from_checkpoint('/my_checkpoints/foo/bar/checkpoint.th') + + +# If you only want to use a pretrained model, you can also directly get it +# from the CompressionModel base model class. +from audiocraft.models import CompressionModel + +# Here do not put the `//pretrained/` prefix! +model = CompressionModel.get_pretrained('facebook/encodec_32khz') +model = CompressionModel.get_pretrained('dac_44khz') + +# Finally, you can also retrieve the full Solver object, with its dataloader etc. +from audiocraft import train +from pathlib import Path +import logging +import os +import sys + +# uncomment the following line if you want some detailed logs when loading a Solver. +logging.basicConfig(stream=sys.stderr, level=logging.INFO) +# You must always run the following function from the root directory. +os.chdir(Path(train.__file__).parent.parent) + + +# You can also get the full solver (only for your own experiments). +# You can provide some overrides to the parameters to make things more convenient. +solver = train.get_solver_from_sig('SIG', {'device': 'cpu', 'dataset': {'batch_size': 8}}) +solver.model +solver.dataloaders +``` + +### Importing / Exporting models + +At the moment we do not have a definitive workflow for exporting EnCodec models, for +instance to Hugging Face (HF). We are working on supporting automatic convertion between +AudioCraft and Hugging Face implementations. + +We still have some support for fine tuning an EnCodec model coming from HF in AudioCraft, +using for instance `continue_from=//pretrained/facebook/encodec_32k`. + +An AudioCraft checkpoint can be exported in a more compact format (excluding the optimizer etc.) +using `audiocraft.utils.export.export_encodec`. For instance, you could run + +```python +from audiocraft.utils import export +from audiocraft import train +xp = train.main.get_xp_from_sig('SIG') +export.export_encodec( + xp.folder / 'checkpoint.th', + '/checkpoints/my_audio_lm/compression_state_dict.bin') + + +from audiocraft.models import CompressionModel +model = CompressionModel.get_pretrained('/checkpoints/my_audio_lm/compression_state_dict.bin') + +from audiocraft.solvers import CompressionSolver +# The two are strictly equivalent, but this function supports also loading from non already exported models. +model = CompressionSolver.model_from_checkpoint('//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin') +``` + +We will see then how to use this model as a tokenizer for MusicGen/Audio gen in the +[MusicGen documentation](./MUSICGEN.md). 
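
Once loaded, the compression model can be exercised directly to check that it round-trips audio into discrete tokens and back, which is the role it plays as a tokenizer for MusicGen/AudioGen. The sketch below uses random audio and assumes the `encode`/`decode` interface of `CompressionModel` (discrete codes plus an optional rescaling factor); double-check the class definition for the exact shapes and return values.

```python
import torch
from audiocraft.models import CompressionModel

# Load a pretrained tokenizer (no //pretrained/ prefix here, as noted above).
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
model.eval()

# Two seconds of random mono audio, shaped [batch, channels, time];
# a real use case would load a waveform resampled to model.sample_rate.
wav = torch.randn(1, 1, 2 * model.sample_rate)

with torch.no_grad():
    # Assumed interface: encode returns the discrete codes (one stream per
    # codebook) and an optional scale used to renormalize the audio.
    codes, scale = model.encode(wav)
    reconstructed = model.decode(codes, scale)

print(codes.shape, reconstructed.shape)
```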
+ +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation +``` +@article{defossez2022highfi, + title={High Fidelity Neural Audio Compression}, + author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, + journal={arXiv preprint arXiv:2210.13438}, + year={2022} +} +``` + + +## License + +See license information in the [README](../README.md). + +[arxiv]: https://arxiv.org/abs/2210.13438 +[encodec_samples]: https://ai.honu.io/papers/encodec/samples.html diff --git a/docs/MBD.md b/docs/MBD.md new file mode 100644 index 0000000000000000000000000000000000000000..4288a89dfd2bb99a42ebe7c3da3eba39a7acc227 --- /dev/null +++ b/docs/MBD.md @@ -0,0 +1,117 @@ +# MultiBand Diffusion + +AudioCraft provides the code and models for MultiBand Diffusion, [From Discrete Tokens to High Fidelity Audio using MultiBand Diffusion][arxiv]. +MultiBand diffusion is a collection of 4 models that can decode tokens from +EnCodec tokenizer into waveform audio. You can listen to some examples on the sample page. + + + Open In Colab + +
+ + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + + +## Usage + +We offer a number of way to use MultiBand Diffusion: +1. The MusicGen demo includes a toggle to try diffusion decoder. You can use the demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py), or through the [MusicGen Colab](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing). +2. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). + +## API + +We provide a simple API and pre-trained models for MusicGen and for EnCodec at 24 khz for 3 bitrates (1.5 kbps, 3 kbps and 6 kbps). + +See after a quick example for using MultiBandDiffusion with the MusicGen API: + +```python +import torchaudio +from audiocraft.models import MusicGen, MultiBandDiffusion +from audiocraft.data.audio import audio_write + +model = MusicGen.get_pretrained('facebook/musicgen-melody') +mbd = MultiBandDiffusion.get_mbd_musicgen() +model.set_generation_params(duration=8) # generate 8 seconds. +wav, tokens = model.generate_unconditional(4, return_tokens=True) # generates 4 unconditional audio samples and keep the tokens for MBD generation +descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] +wav_diffusion = mbd.tokens_to_wav(tokens) +wav, tokens = model.generate(descriptions, return_tokens=True) # generates 3 samples and keep the tokens. +wav_diffusion = mbd.tokens_to_wav(tokens) +melody, sr = torchaudio.load('./assets/bach.mp3') +# Generates using the melody from the given audio and the provided descriptions, returns audio and audio tokens. +wav, tokens = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr, return_tokens=True) +wav_diffusion = mbd.tokens_to_wav(tokens) + +for idx, one_wav in enumerate(wav): + # Will save under {idx}.wav and {idx}_diffusion.wav, with loudness normalization at -14 db LUFS for comparing the methods. + audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) + audio_write(f'{idx}_diffusion', wav_diffusion[idx].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) +``` + +For the compression task (and to compare with [EnCodec](https://github.com/facebookresearch/encodec)): + +```python +import torch +from audiocraft.models import MultiBandDiffusion +from encodec import EncodecModel +from audiocraft.data.audio import audio_read, audio_write + +bandwidth = 3.0 # 1.5, 3.0, 6.0 +mbd = MultiBandDiffusion.get_mbd_24khz(bw=bandwidth) +encodec = EncodecModel.encodec_model_24khz() + +somepath = '' +wav, sr = audio_read(somepath) +with torch.no_grad(): + compressed_encodec = encodec(wav) + compressed_diffusion = mbd.regenerate(wav, sample_rate=sr) + +audio_write('sample_encodec', compressed_encodec.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True) +audio_write('sample_diffusion', compressed_diffusion.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True) +``` + + +## Training + +The [DiffusionSolver](../audiocraft/solvers/diffusion.py) implements our diffusion training pipeline. +It generates waveform audio conditioned on the embeddings extracted from a pre-trained EnCodec model +(see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model). + +Note that **we do NOT provide any of the datasets** used for training our diffusion models. 
+We provide a dummy dataset containing just a few examples for illustrative purposes. + +### Example configurations and grids + +One can train diffusion models as described in the paper by using this [dora grid](../audiocraft/grids/diffusion/4_bands_base_32khz.py). +```shell +# 4 bands MBD trainning +dora grid diffusion.4_bands_base_32khz +``` + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation + +``` +@article{sanroman2023fromdi, + title={From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion}, + author={San Roman, Robin and Adi, Yossi and Deleforge, Antoine and Serizel, Romain and Synnaeve, Gabriel and Défossez, Alexandre}, + journal={arXiv preprint arXiv:}, + year={2023} +} +``` + + +## License + +See license information in the [README](../README.md). + + +[arxiv]: https://dl.fbaipublicfiles.com/encodec/Diffusion/paper.pdf +[mbd_samples]: https://ai.honu.io/papers/mbd/ diff --git a/docs/METRICS.md b/docs/METRICS.md new file mode 100644 index 0000000000000000000000000000000000000000..e2ae9a184cbccb8bfefb4ce77afa5ddab743a051 --- /dev/null +++ b/docs/METRICS.md @@ -0,0 +1,127 @@ +# AudioCraft objective metrics + +In addition to training losses, AudioCraft provides a set of objective metrics +for audio synthesis and audio generation. As these metrics may require +extra dependencies and can be costly to train, they are often disabled by default. +This section provides guidance for setting up and using these metrics in +the AudioCraft training pipelines. + +## Available metrics + +### Audio synthesis quality metrics + +#### SI-SNR + +We provide an implementation of the Scale-Invariant Signal-to-Noise Ratio in PyTorch. +No specific requirement is needed for this metric. Please activate the metric at the +evaluation stage with the appropriate flag: + +```shell +dora run <...> evaluate.metrics.sisnr=true +``` + +#### ViSQOL + +We provide a Python wrapper around the ViSQOL [official implementation](https://github.com/google/visqol) +to conveniently run ViSQOL within the training pipelines. + +One must specify the path to the ViSQOL installation through the configuration in order +to enable ViSQOL computations in AudioCraft: + +```shell +# the first parameter is used to activate visqol computation while the second specify +# the path to visqol's library to be used by our python wrapper +dora run <...> evaluate.metrics.visqol=true metrics.visqol.bin= +``` + +See an example grid: [Compression with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py) + +To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the +instructions available in the [open source repository](https://github.com/google/visqol). + +### Audio generation metrics + +#### Frechet Audio Distance + +Similarly to ViSQOL, we use a Python wrapper around the Frechet Audio Distance +[official implementation](https://github.com/google-research/google-research/tree/master/frechet_audio_distance) +in TensorFlow. + +Note that we had to make several changes to the actual code in order to make it work. +Please refer to the [FrechetAudioDistanceMetric](../audiocraft/metrics/fad.py) class documentation +for more details. We do not plan to provide further support in obtaining a working setup for the +Frechet Audio Distance at this stage. 
+ +```shell +# the first parameter is used to activate FAD metric computation while the second specify +# the path to FAD library to be used by our python wrapper +dora run <...> evaluate.metrics.fad=true metrics.fad.bin= +``` + +See an example grid: [Evaluation with FAD](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py) + +#### Kullback-Leibler Divergence + +We provide a PyTorch implementation of the Kullback-Leibler Divergence computed over the probabilities +of the labels obtained by a state-of-the-art audio classifier. We provide our implementation of the KLD +using the [PaSST classifier](https://github.com/kkoutini/PaSST). + +In order to use the KLD metric over PaSST, you must install the PaSST library as an extra dependency: +```shell +pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt' +``` + +Then similarly, you can use the metric activating the corresponding flag: + +```shell +# one could extend the kld metric with additional audio classifier models that can then be picked through the configuration +dora run <...> evaluate.metrics.kld=true metrics.kld.model=passt +``` + +#### Text consistency + +We provide a text-consistency metric, similarly to the MuLan Cycle Consistency from +[MusicLM](https://arxiv.org/pdf/2301.11325.pdf) or the CLAP score used in +[Make-An-Audio](https://arxiv.org/pdf/2301.12661v1.pdf). +More specifically, we provide a PyTorch implementation of a Text consistency metric +relying on a pre-trained [Contrastive Language-Audio Pretraining (CLAP)](https://github.com/LAION-AI/CLAP). + +Please install the CLAP library as an extra dependency prior to using the metric: +```shell +pip install laion_clap +``` + +Then similarly, you can use the metric activating the corresponding flag: + +```shell +# one could extend the text consistency metric with additional audio classifier models that can then be picked through the configuration +dora run ... evaluate.metrics.text_consistency=true metrics.text_consistency.model=clap +``` + +Note that the text consistency metric based on CLAP will require the CLAP checkpoint to be +provided in the configuration. + +#### Chroma cosine similarity + +Finally, as introduced in MusicGen, we provide a Chroma Cosine Similarity metric in PyTorch. +No specific requirement is needed for this metric. Please activate the metric at the +evaluation stage with the appropriate flag: + +```shell +dora run ... evaluate.metrics.chroma_cosine=true +``` + +#### Comparing against reconstructed audio + +For all the above audio generation metrics, we offer the option to compute the metric on the reconstructed audio +fed in EnCodec instead of the generated sample using the flag `.use_gt=true`. 
+
+## Example usage
+
+You will find examples of configurations for the different metrics introduced above in:
+* The [musicgen's default solver](../config/solver/musicgen/default.yaml) for all audio generation metrics
+* The [compression's default solver](../config/solver/compression/default.yaml) for all audio synthesis metrics
+
+Similarly, we provide different examples in our grids:
+* [Evaluation with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py)
+* [Evaluation with FAD and others](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py)
diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md
new file mode 100644
index 0000000000000000000000000000000000000000..661144b56e830185244dfe3d2637503bcf6d4c2b
--- /dev/null
+++ b/docs/MUSICGEN.md
@@ -0,0 +1,373 @@
+# MusicGen: Simple and Controllable Music Generation
+
+AudioCraft provides the code and models for MusicGen, [a simple and controllable model for music generation](https://arxiv.org/abs/2306.05284).
+MusicGen is a single-stage auto-regressive Transformer model trained over a 32kHz
+EnCodec tokenizer with 4 codebooks sampled at 50 Hz.
+Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require
+a self-supervised semantic representation, and it generates all 4 codebooks in one pass. By introducing
+a small delay between the codebooks, we show we can predict them in parallel, thus having only 50 auto-regressive
+steps per second of audio.
+Check out our [sample page](https://ai.honu.io/papers/musicgen/) or test the available demo!
+
+Open In Colab
+
+Open in Hugging Face
+
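+To build some intuition for the codebook delay pattern mentioned above, here is a small, self-contained
+sketch. This is not AudioCraft's codebook-pattern implementation; the toy sizes and the padding value are
+assumptions made for illustration:
+
+```python
+import torch
+
+K, T = 4, 6                      # 4 codebooks, 6 frames of EnCodec tokens (toy sizes)
+codes = torch.arange(K * T).reshape(K, T)
+PAD = -1                         # placeholder for steps where a codebook has no token yet
+
+# Codebook k is delayed by k steps, so at every decoding step the model emits one
+# token per codebook, each belonging to a different frame.
+delayed = torch.full((K, T + K - 1), PAD)
+for k in range(K):
+    delayed[k, k:k + T] = codes[k]
+print(delayed)
+
+# Decoding therefore takes T + K - 1 steps instead of K * T, which is how MusicGen
+# gets roughly 50 auto-regressive steps per second with 4 codebooks at 50 Hz.
+```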
+ +We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset +of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. + +## Model Card + +See [the model card](../model_cards/MUSICGEN_MODEL_CARD.md). + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + +AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters). + +## Usage + +We offer a number of way to interact with MusicGen: + +1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co/spaces/facebook/MusicGen) + (huge thanks to all the HF team for their support). +2. You can run the extended demo on a Colab: + [colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing) +3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py). +4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). +5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) + which is regularly updated with contributions from @camenduru and the community. + +## API + +We provide a simple API and 4 pre-trained models. The pre trained models are: + +- `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small) +- `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium) +- `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody) +- `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large) + +We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model. +In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller +GPUs will be able to generate short sequences, or longer sequences with the `facebook/musicgen-small` model. + +See after a quick example for using the API. + +```python +import torchaudio +from audiocraft.models import MusicGen +from audiocraft.data.audio import audio_write + +model = MusicGen.get_pretrained('facebook/musicgen-melody') +model.set_generation_params(duration=8) # generate 8 seconds. +wav = model.generate_unconditional(4) # generates 4 unconditional audio samples +descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] +wav = model.generate(descriptions) # generates 3 samples. + +melody, sr = torchaudio.load('./assets/bach.mp3') +# generates using the melody from the given audio and the provided descriptions. +wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr) + +for idx, one_wav in enumerate(wav): + # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. + audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) + +``` + +## 🤗 Transformers Usage + +MusicGen is available in the 🤗 Transformers library from version 4.31.0 onwards, requiring minimal dependencies +and additional packages. Steps to get started: + +1. 
First install the 🤗 [Transformers library](https://github.com/huggingface/transformers) from main: + +```shell +pip install git+https://github.com/huggingface/transformers.git + +``` + +2. Run the following Python code to generate text-conditional audio samples: + +```py +from transformers import AutoProcessor, MusicgenForConditionalGeneration + + +processor = AutoProcessor.from_pretrained("facebook/musicgen-small") +model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + +inputs = processor( + text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + padding=True, + return_tensors="pt", +) + +audio_values = model.generate(**inputs, max_new_tokens=256) + +``` + +3. Listen to the audio samples either in an ipynb notebook: + +```py +from IPython.display import Audio + +sampling_rate = model.config.audio_encoder.sampling_rate +Audio(audio_values[0].numpy(), rate=sampling_rate) + +``` + +Or save them as a `.wav` file using a third-party library, e.g. `scipy`: + +```py +import scipy + +sampling_rate = model.config.audio_encoder.sampling_rate +scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy()) + +``` + +For more details on using the MusicGen model for inference using the 🤗 Transformers library, refer to the +[MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on +[Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb). + +## Training + +The [MusicGenSolver](../audiocraft/solvers/musicgen.py) implements MusicGen's training pipeline. +It defines an autoregressive language modeling task over multiple streams of discrete tokens +extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) +for more details on how to train such model). + +Note that **we do NOT provide any of the datasets** used for training MusicGen. +We provide a dummy dataset containing just a few examples for illustrative purposes. + +Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section. + +### Example configurations and grids + +We provide configurations to reproduce the released models and our research. +MusicGen solvers configuration are available in [config/solver/musicgen](../config/solver/musicgen), +in particular: + +* MusicGen base model for text-to-music: + [`solver=musicgen/musicgen_base_32khz`](../config/solver/musicgen/musicgen_base_32khz.yaml) +* MusicGen model with chromagram-conditioning support: + [`solver=musicgen/musicgen_melody_32khz`](../config/solver/musicgen/musicgen_melody_32khz.yaml) + +We provide 3 different scales, e.g. `model/lm/model_scale=small` (300M), or `medium` (1.5B), and `large` (3.3B). + +Please find some example grids to train MusicGen at +[audiocraft/grids/musicgen](../audiocraft/grids/musicgen/). + +```shell +# text-to-music +dora grid musicgen.musicgen_base_32khz --dry_run --init +# melody-guided music generation +dora grid musicgen.musicgen_melody_base_32khz --dry_run --init +# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is setup. + +``` + +### Music dataset and metadata + +MusicGen's underlying dataset is an AudioDataset augmented with music-specific metadata. +The MusicGen dataset implementation expects the metadata to be available as `.json` files +at the same location as the audio files. Learn more in the [datasets section](./DATASETS.md). 
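+
+As a rough illustration, a side-car metadata file could be written next to each audio file along these
+lines. The field names below are purely illustrative assumptions; refer to the [datasets section](./DATASETS.md)
+for the actual schema expected by the dataset implementation:
+
+```python
+import json
+from pathlib import Path
+
+audio_path = Path('egs/my_dataset/track_000.wav')       # hypothetical audio file
+metadata = {                                            # illustrative fields only
+    'title': 'Example track',
+    'artist': 'Unknown artist',
+    'description': 'Upbeat electronic track with a driving bassline.',
+}
+audio_path.with_suffix('.json').write_text(json.dumps(metadata, indent=2))
+```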
+
+### Audio tokenizers
+
+We support a number of audio tokenizers: either pretrained EnCodec models, [DAC](https://github.com/descriptinc/descript-audio-codec), or your own models.
+The tokenizer is controlled with the setting `compression_model_checkpoint`.
+For instance,
+
+```bash
+# Using the 32kHz EnCodec trained on music
+dora run solver=musicgen/debug \
+    compression_model_checkpoint=//pretrained/facebook/encodec_32khz \
+    transformer_lm.n_q=4 transformer_lm.card=2048
+
+# Using DAC
+dora run solver=musicgen/debug \
+    compression_model_checkpoint=//pretrained/dac_44khz \
+    transformer_lm.n_q=9 transformer_lm.card=1024 \
+    'codebooks_pattern.delay.delays=[0,1,2,3,4,5,6,7,8]'
+
+# Using your own model after export (see ENCODEC.md)
+dora run solver=musicgen/debug \
+    compression_model_checkpoint=//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin \
+    transformer_lm.n_q=... transformer_lm.card=...
+
+# Using your own model from its training checkpoint.
+dora run solver=musicgen/debug \
+    compression_model_checkpoint=//sig/SIG \ # where SIG is the Dora signature of the EnCodec XP.
+    transformer_lm.n_q=... transformer_lm.card=...
+
+```
+
+__Warning:__ you are responsible for setting the proper values for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the `codebooks_pattern` to match `n_q`, as shown in the example for using DAC.
+
+### Fine tuning existing models
+
+You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular
+
+```bash
+# Using pretrained MusicGen model.
+dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//pretrained/facebook/musicgen-medium conditioner=text2music
+
+# Using another model you already trained with a Dora signature SIG.
+dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//sig/SIG conditioner=text2music
+
+# Or manually providing a path
+dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=/checkpoints/my_other_xp/checkpoint.th
+
+```
+
+__Warning:__ You are responsible for selecting the other parameters accordingly, in a way that makes them compatible
+with the model you are fine-tuning. Configuration is NOT automatically inherited from the model you continue from. In particular make sure to select the proper `conditioner` and `model/lm/model_scale`.
+
+__Warning:__ We currently do not support fine-tuning a model with slightly different layers. If you decide
+to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`.
+If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict
+`{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix.
+
+### Caching of EnCodec tokens
+
+It is possible to precompute the EnCodec tokens and other metadata.
+An example of generating and using this cache is provided in the [musicgen.musicgen_base_cached_32khz grid](../audiocraft/grids/musicgen/musicgen_base_cached_32khz.py).
+
+### Evaluation stage
+
+By default, the evaluation stage also computes the cross-entropy and the perplexity over the
+evaluation dataset. Indeed, the objective metrics used for evaluation can be costly to run
+or require some extra dependencies.
Please refer to the [metrics documentation](./METRICS.md) +for more details on the requirements for each metric. + +We provide an off-the-shelf configuration to enable running the objective metrics +for audio generation in +[config/solver/musicgen/evaluation/objective_eval](../config/solver/musicgen/evaluation/objective_eval.yaml). + +One can then activate evaluation the following way: + +```shell +# using the configuration +dora run solver=musicgen/debug solver/musicgen/evaluation=objective_eval +# specifying each of the fields, e.g. to activate KL computation +dora run solver=musicgen/debug evaluate.metrics.kld=true + +``` + +See [an example evaluation grid](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py). + +### Generation stage + +The generation stage allows to generate samples conditionally and/or unconditionally and to perform +audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling +from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples +generated and the batch size used are controlled by the `dataset.generate` configuration +while the other generation parameters are defined in `generate.lm`. + +```shell +# control sampling parameters +dora run solver=musicgen/debug generate.lm.gen_duration=10 generate.lm.use_sampling=true generate.lm.top_k=15 + +``` + +#### Listening to samples + +Note that generation happens automatically every 25 epochs. You can easily access and +compare samples between models (as long as they are trained) on the same dataset using the +MOS tool. For that first `pip install Flask gunicorn`. Then + +```sh +gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - + +``` + +And access the tool at [https://127.0.0.1:8895](https://127.0.0.1:8895). + +### Playing with the model + +Once you have launched some experiments, you can easily get access +to the Solver with the latest trained model using the following snippet. + +```python +from audiocraft.solvers.musicgen import MusicGen + +solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8) +solver.model +solver.dataloaders + +``` + +### Importing / Exporting models + +We do not support currently loading a model from the Hugging Face implementation or exporting to it. +If you want to export your model in a way that is compatible with `audiocraft.models.MusicGen` +API, you can run: + +```python +from audiocraft.utils import export +from audiocraft import train +xp = train.main.get_xp_from_sig('SIG_OF_LM') +export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin') +# You also need to bundle the EnCodec model you used !! +## Case 1) you trained your own +xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC') +export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin') +## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix. +## This will actually not dump the actual model, simply a pointer to the right model to download. +export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin') + +``` + +Now you can load your custom model with: + +```python +import audiocraft.models +musicgen = audiocraft.models.MusicGen.get_pretrained('/checkpoints/my_audio_lm/') + +``` + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). 
+ +## FAQ + +#### I need help on Windows + +@FurkanGozukara made a complete tutorial for [AudioCraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4) + +#### I need help for running the demo on Colab + +Check [@camenduru tutorial on YouTube](https://www.youtube.com/watch?v=EGfxuTy9Eeo). + +#### What are top-k, top-p, temperature and classifier-free guidance? + +Check out [@FurkanGozukara tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt). + +#### Should I use FSDP or autocast ? + +The two are mutually exclusive (because FSDP does autocast on its own). +You can use autocast up to 1.5B (medium), if you have enough RAM on your GPU. +FSDP makes everything more complex but will free up some memory for the actual +activations by sharding the optimizer state. + +## Citation + +```json +@article{copet2023simple, + title={Simple and Controllable Music Generation}, + author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, + year={2023}, + journal={arXiv preprint arXiv:2306.05284}, +} + +``` + +## License + +See license information in the [model card](../model_cards/MUSICGEN_MODEL_CARD.md). diff --git a/docs/TRAINING.md b/docs/TRAINING.md new file mode 100644 index 0000000000000000000000000000000000000000..148de295f2ddfed2e4e893576bf31e1485038b8e --- /dev/null +++ b/docs/TRAINING.md @@ -0,0 +1,312 @@ +# AudioCraft training pipelines + +AudioCraft training pipelines are built on top of PyTorch as our core deep learning library +and [Flashy](https://github.com/facebookresearch/flashy) as our training pipeline design library, +and [Dora](https://github.com/facebookresearch/dora) as our experiment manager. +AudioCraft training pipelines are designed to be research and experiment-friendly. + + +## Environment setup + +For the base installation, follow the instructions from the [README.md](../README.md). +Below are some additional instructions for setting up environment to train new models. + +### Team and cluster configuration + +In order to support multiple teams and clusters, AudioCraft uses an environment configuration. +The team configuration allows to specify cluster-specific configurations (e.g. SLURM configuration), +or convenient mapping of paths between the supported environments. + +Each team can have a yaml file under the [configuration folder](../config). To select a team set the +`AUDIOCRAFT_TEAM` environment variable to a valid team name (e.g. `labs` or `default`): +```shell +conda env config vars set AUDIOCRAFT_TEAM=default +``` + +Alternatively, you can add it to your `.bashrc`: +```shell +export AUDIOCRAFT_TEAM=default +``` + +If not defined, the environment will default to the `default` team. + +The cluster is automatically detected, but it is also possible to override it by setting +the `AUDIOCRAFT_CLUSTER` environment variable. + +Based on this team and cluster, the environment is then configured with: +* The dora experiment outputs directory. +* The available slurm partitions: categorized by global and team. +* A shared reference directory: In order to facilitate sharing research models while remaining +agnostic to the used compute cluster, we created the `//reference` symbol that can be used in +YAML config to point to a defined reference folder containing shared checkpoints +(e.g. baselines, models for evaluation...). 
+
+**Important:** The default output dir for trained models and checkpoints is under `/tmp/`. This is suitable
+only for quick testing. If you are doing anything serious you MUST edit the file `default.yaml` and
+properly set the `dora_dir` entries.
+
+#### Overriding environment configurations
+
+You can set the following environment variables to bypass the team's environment configuration:
+* `AUDIOCRAFT_CONFIG`: absolute path to a team config yaml file.
+* `AUDIOCRAFT_DORA_DIR`: absolute path to a custom dora directory.
+* `AUDIOCRAFT_REFERENCE_DIR`: absolute path to the shared reference directory.
+
+## Training pipelines
+
+Each task supported in AudioCraft has its own training pipeline and dedicated solver.
+Learn more about solvers and key designs around the AudioCraft training pipelines below.
+Please refer to the documentation of each task and model for specific information on a given task.
+
+
+### Solvers
+
+The core training component in AudioCraft is the solver. A solver holds the definition
+of how to solve a given task: it implements the training pipeline logic, combining the datasets,
+model, optimization criterion and components, and the full training loop. We refer the reader
+to [Flashy](https://github.com/facebookresearch/flashy) for core principles around solvers.
+
+AudioCraft provides an initial solver, the `StandardSolver`, which is used as the base implementation
+for downstream solvers. This standard solver provides a nice base management of logging,
+checkpoints loading/saving, xp restoration, etc. on top of the base Flashy implementation.
+In AudioCraft, we made the assumption that all tasks follow the same set of stages:
+train, valid, evaluate and generate, each relying on a dedicated dataset.
+
+Each solver is responsible for defining the task to solve and the associated stages
+of the training loop in order to leave the full ownership of the training pipeline
+to the researchers. This includes loading the datasets, building the model and
+optimisation components, registering them and defining the execution of each stage.
+To create a new solver for a given task, one should extend the StandardSolver
+and define each stage of the training loop. One can further customise their own solver
+starting from scratch instead of inheriting from the standard solver.
+
+```python
+from abc import abstractmethod
+import typing as tp
+
+import omegaconf
+import torch
+
+from . import base
+from .. import optim
+
+
+class MyNewSolver(base.StandardSolver):
+
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        # one can add custom attributes to the solver
+        self.criterion = torch.nn.L1Loss()
+
+    def best_metric(self):
+        # here optionally specify which metric to use to keep track of best state
+        return 'loss'
+
+    def build_model(self):
+        # here you can instantiate your models and optimization related objects
+        # this method will be called by the StandardSolver init method
+        self.model = ...
+        # the self.cfg attribute contains the raw configuration
+        self.optimizer = optim.build_optimizer(self.model.parameters(), self.cfg.optim)
+        # don't forget to register the states you'd like to include in your checkpoints!
+        self.register_stateful('model', 'optimizer')
+        # keep the model best state based on the best value achieved at validation for the given best_metric
+        self.register_best('model')
+        # if you want to add EMA around the model
+        self.register_ema('model')
+
+    def build_dataloaders(self):
+        # here you can instantiate your dataloaders
+        # this method will be called by the StandardSolver init method
+        self.dataloaders = ...
+
+    ...
+ + # For both train and valid stages, the StandardSolver relies on + # a share common_train_valid implementation that is in charge of + # accessing the appropriate loader, iterate over the data up to + # the specified number of updates_per_epoch, run the ``run_step`` + # function that you need to implement to specify the behavior + # and finally update the EMA and collect the metrics properly. + @abstractmethod + def run_step(self, idx: int, batch: tp.Any, metrics: dict): + """Perform one training or valid step on a given batch. + """ + ... # provide your implementation of the solver over a batch + + def train(self): + """Train stage. + """ + return self.common_train_valid('train') + + def valid(self): + """Valid stage. + """ + return self.common_train_valid('valid') + + @abstractmethod + def evaluate(self): + """Evaluate stage. + """ + ... # provide your implementation here! + + @abstractmethod + def generate(self): + """Generate stage. + """ + ... # provide your implementation here! +``` + +### About Epochs + +AudioCraft Solvers uses the concept of Epoch. One epoch doesn't necessarily mean one pass over the entire +dataset, but instead represent the smallest amount of computation that we want to work with before checkpointing. +Typically, we find that having an Epoch time around 30min is ideal both in terms of safety (checkpointing often enough) +and getting updates often enough. One Epoch is at least a `train` stage that lasts for `optim.updates_per_epoch` (2000 by default), +and a `valid` stage. You can control how long the valid stage takes with `dataset.valid.num_samples`. +Other stages (`evaluate`, `generate`) will only happen every X epochs, as given by `evaluate.every` and `generate.every`). + + +### Models + +In AudioCraft, a model is a container object that wraps one or more torch modules together +with potential processing logic to use in a solver. For example, a model would wrap an encoder module, +a quantisation bottleneck module, a decoder and some tensor processing logic. Each of the previous components +can be considered as a small « model unit » on its own but the container model is a practical component +to manipulate and train a set of modules together. + +### Datasets + +See the [dedicated documentation on datasets](./DATASETS.md). + +### Metrics + +See the [dedicated documentation on metrics](./METRICS.md). + +### Conditioners + +AudioCraft language models can be conditioned in various ways and the codebase offers a modular implementation +of different conditioners that can be potentially combined together. +Learn more in the [dedicated documentation on conditioning](./CONDITIONING.md). + +### Configuration + +AudioCraft's configuration is defined in yaml files and the framework relies on +[hydra](https://hydra.cc/docs/intro/) and [omegaconf](https://omegaconf.readthedocs.io/) to parse +and manipulate the configuration through Dora. + +##### :warning: Important considerations around configurations + +Our configuration management relies on Hydra and the concept of group configs to structure +and compose configurations. Updating the root default configuration files will then have +an impact on all solvers and tasks. +**One should never change the default configuration files. Instead they should use Hydra config groups in order to store custom configuration.** +Once this configuration is created and used for running experiments, you should not edit it anymore. 
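+
+As a side note, if you want to inspect the configuration an existing experiment resolved to, you can do it
+from Python instead of editing yaml files. A small sketch (it assumes the Dora XP object exposes its
+composed configuration as `cfg`; see the Dora documentation):
+
+```python
+from audiocraft import train
+
+# Retrieve an experiment by its Dora signature, e.g. as printed by `dora info`.
+xp = train.main.get_xp_from_sig('SIG')
+print(xp.folder)                        # folder holding checkpoints and logs for this XP
+print(xp.cfg.optim.updates_per_epoch)   # a resolved configuration value (2000 by default)
+```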
+
+Note that as we are using Dora as our experiment manager, all our experiment tracking is based on
+signatures computed from deltas between configurations.
+**One must therefore ensure backward compatibility of the configuration at all times.**
+See [Dora's README](https://github.com/facebookresearch/dora) and the
+[section below introducing Dora](#running-experiments-with-dora).
+
+##### Configuration structure
+
+The configuration is organized in config groups:
+* `conditioner`: default values for conditioning modules.
+* `dset`: contains all data source related information (paths to manifest files
+and metadata for a given dataset).
+* `model`: contains configuration for each model defined in AudioCraft and configurations
+for different variants of models.
+* `solver`: contains the default configuration for each solver as well as configuration
+for each solver task, combining all the above components.
+* `teams`: contains the cluster configuration per team. See environment setup for more details.
+
+The `config.yaml` file is the main configuration that composes the above groups
+and contains default configuration for AudioCraft.
+
+##### Solver's core configuration structure
+
+The core configuration structure shared across solvers is available in `solvers/default.yaml`.
+
+##### Other configuration modules
+
+The AudioCraft configuration contains the different setups we used for our research and publications.
+
+## Running experiments with Dora
+
+### Launching jobs
+
+Try launching jobs for different tasks locally with dora run:
+
+```shell
+# run compression task with lightweight encodec
+dora run solver=compression/debug
+```
+
+Most of the time, the jobs are launched through dora grids, for example:
+
+```shell
+# run compression task through debug grid
+dora grid compression.debug
+```
+
+Learn more about running experiments with Dora below.
+
+### A small introduction to Dora
+
+[Dora](https://github.com/facebookresearch/dora) is the experiment manager tool used in AudioCraft.
+Check out the README to learn how Dora works. Here is a quick summary of what to know:
+* An XP is a unique set of hyper-parameters with a given signature. The signature is a hash
+of those hyper-parameters. We always refer to an XP with its signature, e.g. 9357e12e. We will see
+later that one can retrieve the hyper-params and re-run it in a single command.
+* In fact, the hash is defined as a delta between the base config and the one obtained
+with the config overrides you passed from the command line. This means you must never change
+the `conf/**.yaml` files directly, except for editing things like paths. Changing the default values
+in the config files means the XP signature won't reflect that change, and wrong checkpoints might be reused.
+I know, this is annoying, but the reason is that otherwise, any change to the config file would mean
+that all XPs ran so far would see their signature change.
+
+#### Dora commands
+
+```shell
+dora info -f 81de367c  # this will show the hyper-parameters used by a specific XP.
+                       # Be careful, some overrides might be present twice, and the rightmost one
+                       # will give you the right value for it.
+
+dora run -d -f 81de367c  # run an XP with the hyper-parameters from XP 81de367c.
+                         # `-d` is for distributed, it will use all available GPUs.
+
+dora run -d -f 81de367c dataset.batch_size=32  # start from the config of XP 81de367c but change some hyper-params.
+                                               # This will give you a new XP with a new signature (e.g. 3fe9c332).
+ +dora info -f SIG -t # will tail the log (if the XP has scheduled). +# if you need to access the logs of the process for rank > 0, in particular because a crash didn't happen in the main +# process, then use `dora info -f SIG` to get the main log name (finished into something like `/5037674_0_0_log.out`) +# and worker K can accessed as `/5037674_0_{K}_log.out`. +# This is only for scheduled jobs, for local distributed runs with `-d`, then you should go into the XP folder, +# and look for `worker_{K}.log` logs. +``` + +An XP runs from a specific folder based on its signature, under the +`//experiments/audiocraft/outputs/` folder. +You can safely interrupt a training and resume it, it will reuse any existing checkpoint, +as it will reuse the same folder. If you made some change to the code and need to ignore +a previous checkpoint you can use `dora run --clear [RUN ARGS]`. + +If you have a Slurm cluster, you can also use the dora grid command, e.g. + +```shell +# run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py` +dora grid my_grid_folder.my_grid_name +# Run the following will simply display the grid and also initialized the Dora experiments database. +# You can then simply refer to a config using its signature (e.g. as `dora run -f SIG`). +dora grid my_grid_folder.my_grid_name --dry_run --init +``` + +Please refer to the [Dora documentation](https://github.com/facebookresearch/dora) for more information. + + +#### Clearing up past experiments + +```shell +# This will cancel all the XPs and delete their folder and checkpoints. +# It will then reschedule them starting from scratch. +dora grid my_grid_folder.my_grid_name --clear +# The following will delete the folder and checkpoint for a single XP, +# and then run it afresh. +dora run [-f BASE_SIG] [ARGS] --clear +``` diff --git a/model_cards/AUDIOGEN_MODEL_CARD.md b/model_cards/AUDIOGEN_MODEL_CARD.md new file mode 100644 index 0000000000000000000000000000000000000000..5dcd23d8276d8f474043976672ea249d8b2a9dd1 --- /dev/null +++ b/model_cards/AUDIOGEN_MODEL_CARD.md @@ -0,0 +1,79 @@ +# AudioGen Model Card + +## Model details +**Organization developing the model:** The FAIR team of Meta AI. + +**Model date:** This version of AudioGen was trained between July 2023 and August 2023. + +**Model version:** This is version 2 of the model, not to be confused with the original AudioGen model published in ["AudioGen: Textually Guided Audio Generation"][audiogen]. +In this version (v2), AudioGen was trained on the same data, but with some other differences: +1. This model was trained on 10 seconds (vs. 5 seconds in v1). +2. The discrete representation used under the hood is extracted using a retrained EnCodec model on the environmental sound data, following the EnCodec setup detailed in the ["Simple and Controllable Music Generation" paper][musicgen]. +3. No audio mixing augmentations. + +**Model type:** AudioGen consists of an EnCodec model for audio tokenization, and an auto-regressive language model based on the transformer architecture for audio modeling. The released model has 1.5B parameters. + +**Paper or resource for more information:** More information can be found in the paper [AudioGen: Textually Guided Audio Generation](https://arxiv.org/abs/2209.15352). + +**Citation details:** See [AudioGen paper][audiogen] + +**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0. 
+ +**Where to send questions or comments about the model:** Questions and comments about AudioGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. + +## Intended use +**Primary intended use:** The primary use of AudioGen is research on AI-based audio generation, including: +- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science +- Generation of sound guided by text to understand current abilities of generative AI models by machine learning amateurs + +**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. + +**Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate audio pieces that create hostile or alienating environments for people. This includes generating audio that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. + +## Metrics + +**Models performance measures:** We used the following objective measure to evaluate the model on a standard audio benchmark: +- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish) +- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST) + +Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes: +- Overall quality of the audio samples; +- Text relevance to the provided text input; + +More details on performance measures and human studies can be found in the paper. + +**Decision thresholds:** Not applicable. + +## Evaluation datasets + +The model was evaluated on the [AudioCaps benchmark](https://audiocaps.github.io/). + +## Training datasets + +The model was trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects). + +## Evaluation results + +Below are the objective metrics obtained with the released model on AudioCaps (consisting of 10-second long samples). Note that the model differs from the original AudioGen model introduced in the paper, hence the difference in the metrics. + +| Model | Frechet Audio Distance | KLD | Text consistency | +|---|---|---|---| +| facebook/audiogen-medium | 1.77 | 1.58 | 0.30 | + +More information can be found in the paper [AudioGen: Textually Guided Audio Generation][audiogen], in the Experiments section. + +## Limitations and biases + +**Limitations:** +- The model is not able to generate realistic vocals. +- The model has been trained with English descriptions and will not perform as well in other languages. +- It is sometimes difficult to assess what types of text descriptions provide the best generations. 
Prompt engineering may be required to obtain satisfying results. + +**Biases:** The datasets used for training may be lacking of diversity and are not representative of all possible sound events. The generated samples from the model will reflect the biases from the training data. + +**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data. + +**Use cases:** Users must be aware of the biases, limitations and risks of the model. AudioGen is a model developed for artificial intelligence research on audio generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. + +[musicgen]: https://arxiv.org/abs/2306.05284 +[audiogen]: https://arxiv.org/abs/2209.15352 diff --git a/model_cards/MUSICGEN_MODEL_CARD.md b/model_cards/MUSICGEN_MODEL_CARD.md new file mode 100644 index 0000000000000000000000000000000000000000..9543136862fd900f7e94e2fde23784ffe1d5bf52 --- /dev/null +++ b/model_cards/MUSICGEN_MODEL_CARD.md @@ -0,0 +1,90 @@ +# MusicGen Model Card + +## Model details + +**Organization developing the model:** The FAIR team of Meta AI. + +**Model date:** MusicGen was trained between April 2023 and May 2023. + +**Model version:** This is the version 1 of the model. + +**Model type:** MusicGen consists of an EnCodec model for audio tokenization, an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters ; and two variants: a model trained for text-to-music generation task and a model trained for melody-guided music generation. + +**Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv]. + +**Citation details:** See [our paper][arxiv] + +**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0. + +**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. + +## Intended use +**Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including: + +- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science +- Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs + +**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. + +**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 
+ +## Metrics + +**Models performance measures:** We used the following objective measure to evaluate the model on a standard music benchmark: + +- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish) +- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST) +- CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model + +Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes: + +- Overall quality of the music samples; +- Text relevance to the provided text input; +- Adherence to the melody for melody-guided music generation. + +More details on performance measures and human studies can be found in the paper. + +**Decision thresholds:** Not applicable. + +## Evaluation datasets + +The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set. + +## Training datasets + +The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing. + +## Evaluation results + +Below are the objective metrics obtained on MusicCaps with the released model. Note that for the publicly released models, we had all the datasets go through a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs), in order to keep only the instrumental part. This explains the difference in objective metrics with the models used in the paper. + +| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity | +|---|---|---|---|---| +| facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - | +| facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - | +| facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - | +| facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 | + +More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section. + +## Limitations and biases + +**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model. + +**Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). + +**Limitations:** + +- The model is not able to generate realistic vocals. +- The model has been trained with English descriptions and will not perform as well in other languages. +- The model does not perform equally well for all music styles and cultures. +- The model sometimes generates end of songs, collapsing to silence. +- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results. 
+ +**Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive. + +**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data. + +**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. + +[arxiv]: https://arxiv.org/abs/2306.05284 diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..6ab60f2fd7545c803fca221614704a075b8f2188 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] + +[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*] +ignore_missing_imports = True diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/scripts/__pycache__/__init__.cpython-39.pyc b/scripts/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2986430529d858fe7c28a5e588ade57b935a12f3 Binary files /dev/null and b/scripts/__pycache__/__init__.cpython-39.pyc differ diff --git a/scripts/__pycache__/mos.cpython-311.pyc b/scripts/__pycache__/mos.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1734225a1f9017feee5c9ad67b0e4a8d0b01f05 Binary files /dev/null and b/scripts/__pycache__/mos.cpython-311.pyc differ diff --git a/scripts/mos.py b/scripts/mos.py new file mode 100644 index 0000000000000000000000000000000000000000..a711c9ece23e72ed3a07032c7834ef7c56ab4f11 --- /dev/null +++ b/scripts/mos.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" +To run this script, from the root of the repo. 
Make sure to have Flask installed + + FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567 + # or if you have gunicorn + gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - + +""" +from collections import defaultdict +from functools import wraps +from hashlib import sha1 +import json +import math +from pathlib import Path +import random +import typing as tp + +from flask import Flask, redirect, render_template, request, session, url_for + +from audiocraft import train +from audiocraft.utils.samples.manager import get_samples_for_xps + + +SAMPLES_PER_PAGE = 8 +MAX_RATING = 5 +storage = Path(train.main.dora.dir / 'mos_storage') +storage.mkdir(exist_ok=True) +surveys = storage / 'surveys' +surveys.mkdir(exist_ok=True) +magma_root = Path(train.__file__).parent.parent +app = Flask('mos', static_folder=str(magma_root / 'scripts/static'), + template_folder=str(magma_root / 'scripts/templates')) +app.secret_key = b'audiocraft makes the best songs' + + +def normalize_path(path: Path): + """Just to make path a bit nicer, make them relative to the Dora root dir. + """ + path = path.resolve() + dora_dir = train.main.dora.dir.resolve() / 'xps' + return path.relative_to(dora_dir) + + +def get_full_path(normalized_path: Path): + """Revert `normalize_path`. + """ + return train.main.dora.dir.resolve() / 'xps' / normalized_path + + +def get_signature(xps: tp.List[str]): + """Return a signature for a list of XP signatures. + """ + return sha1(json.dumps(xps).encode()).hexdigest()[:10] + + +def ensure_logged(func): + """Ensure user is logged in. + """ + @wraps(func) + def _wrapped(*args, **kwargs): + user = session.get('user') + if user is None: + return redirect(url_for('login', redirect_to=request.url)) + return func(*args, **kwargs) + return _wrapped + + +@app.route('/login', methods=['GET', 'POST']) +def login(): + """Login user if not already, then redirect. + """ + user = session.get('user') + if user is None: + error = None + if request.method == 'POST': + user = request.form['user'] + if not user: + error = 'User cannot be empty' + if user is None or error: + return render_template('login.html', error=error) + assert user + session['user'] = user + redirect_to = request.args.get('redirect_to') + if redirect_to is None: + redirect_to = url_for('index') + return redirect(redirect_to) + + +@app.route('/', methods=['GET', 'POST']) +@ensure_logged +def index(): + """Offer to create a new study. 
+ """ + errors = [] + if request.method == 'POST': + xps_or_grids = [part.strip() for part in request.form['xps'].split()] + xps = set() + for xp_or_grid in xps_or_grids: + xp_path = train.main.dora.dir / 'xps' / xp_or_grid + if xp_path.exists(): + xps.add(xp_or_grid) + continue + grid_path = train.main.dora.dir / 'grids' / xp_or_grid + if grid_path.exists(): + for child in grid_path.iterdir(): + if child.is_symlink(): + xps.add(child.name) + continue + errors.append(f'{xp_or_grid} is neither an XP nor a grid!') + assert xps or errors + blind = 'true' if request.form.get('blind') == 'on' else 'false' + xps = list(xps) + if not errors: + signature = get_signature(xps) + manifest = { + 'xps': xps, + } + survey_path = surveys / signature + survey_path.mkdir(exist_ok=True) + with open(survey_path / 'manifest.json', 'w') as f: + json.dump(manifest, f, indent=2) + return redirect(url_for('survey', blind=blind, signature=signature)) + return render_template('index.html', errors=errors) + + +@app.route('/survey/', methods=['GET', 'POST']) +@ensure_logged +def survey(signature): + success = request.args.get('success', False) + seed = int(request.args.get('seed', 4321)) + blind = request.args.get('blind', 'false') in ['true', 'on', 'True'] + exclude_prompted = request.args.get('exclude_prompted', 'false') in ['true', 'on', 'True'] + exclude_unprompted = request.args.get('exclude_unprompted', 'false') in ['true', 'on', 'True'] + max_epoch = int(request.args.get('max_epoch', '-1')) + survey_path = surveys / signature + assert survey_path.exists(), survey_path + + user = session['user'] + result_folder = survey_path / 'results' + result_folder.mkdir(exist_ok=True) + result_file = result_folder / f'{user}_{seed}.json' + + with open(survey_path / 'manifest.json') as f: + manifest = json.load(f) + + xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']] + names, ref_name = train.main.get_names(xps) + + samples_kwargs = { + 'exclude_prompted': exclude_prompted, + 'exclude_unprompted': exclude_unprompted, + 'max_epoch': max_epoch, + } + matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs) # fetch latest epoch + models_by_id = { + id: [{ + 'xp': xps[idx], + 'xp_name': names[idx], + 'model_id': f'{xps[idx].sig}-{sample.id}', + 'sample': sample, + 'is_prompted': sample.prompt is not None, + 'errors': [], + } for idx, sample in enumerate(samples)] + for id, samples in matched_samples.items() + } + experiments = [ + {'xp': xp, 'name': names[idx], 'epoch': list(matched_samples.values())[0][idx].epoch} + for idx, xp in enumerate(xps) + ] + + keys = list(matched_samples.keys()) + keys.sort() + rng = random.Random(seed) + rng.shuffle(keys) + model_ids = keys[:SAMPLES_PER_PAGE] + + if blind: + for key in model_ids: + rng.shuffle(models_by_id[key]) + + ok = True + if request.method == 'POST': + all_samples_results = [] + for id in model_ids: + models = models_by_id[id] + result = { + 'id': id, + 'is_prompted': models[0]['is_prompted'], + 'models': {} + } + all_samples_results.append(result) + for model in models: + rating = request.form[model['model_id']] + if rating: + rating = int(rating) + assert rating <= MAX_RATING and rating >= 1 + result['models'][model['xp'].sig] = rating + model['rating'] = rating + else: + ok = False + model['errors'].append('Please rate this model.') + if ok: + result = { + 'results': all_samples_results, + 'seed': seed, + 'user': user, + 'blind': blind, + 'exclude_prompted': exclude_prompted, + 'exclude_unprompted': exclude_unprompted, + } + print(result) + with 
open(result_file, 'w') as f: + json.dump(result, f) + seed = seed + 1 + return redirect(url_for( + 'survey', signature=signature, blind=blind, seed=seed, + exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, + max_epoch=max_epoch, success=True)) + + ratings = list(range(1, MAX_RATING + 1)) + return render_template( + 'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success, + exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch, + experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[], + ref_name=ref_name, already_filled=result_file.exists()) + + +@app.route('/audio/') +def audio(path: str): + full_path = Path('/') / path + assert full_path.suffix in [".mp3", ".wav"] + return full_path.read_bytes(), {'Content-Type': 'audio/mpeg'} + + +def mean(x): + return sum(x) / len(x) + + +def std(x): + m = mean(x) + return math.sqrt(sum((i - m)**2 for i in x) / len(x)) + + +@app.route('/results/') +@ensure_logged +def results(signature): + + survey_path = surveys / signature + assert survey_path.exists(), survey_path + result_folder = survey_path / 'results' + result_folder.mkdir(exist_ok=True) + + # ratings per model, then per user. + ratings_per_model = defaultdict(list) + users = [] + for result_file in result_folder.iterdir(): + if result_file.suffix != '.json': + continue + with open(result_file) as f: + results = json.load(f) + users.append(results['user']) + for result in results['results']: + for sig, rating in result['models'].items(): + ratings_per_model[sig].append(rating) + + fmt = '{:.2f}' + models = [] + for model in sorted(ratings_per_model.keys()): + ratings = ratings_per_model[model] + + models.append({ + 'sig': model, + 'samples': len(ratings), + 'mean_rating': fmt.format(mean(ratings)), + # the value 1.96 was probably chosen to achieve some + # confidence interval assuming gaussianity. + 'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5), + }) + return render_template('results.html', signature=signature, models=models, users=users) diff --git a/scripts/resample_dataset.py b/scripts/resample_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..af5288712b8d2cde2d9814c747275e69f6e970c8 --- /dev/null +++ b/scripts/resample_dataset.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Resampling script. 
+
+"""
+import argparse
+from pathlib import Path
+import shutil
+import typing as tp
+
+import submitit
+import tqdm
+
+from audiocraft.data.audio import audio_read, audio_write
+from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files
+from audiocraft.data.audio_utils import convert_audio
+from audiocraft.environment import AudioCraftEnvironment
+
+
+def read_txt_files(path: tp.Union[str, Path]):
+    # read one path per line from the provided .txt file
+    with open(path) as f:
+        lines = [line.rstrip() for line in f]
+    print(f"Read {len(lines)} lines from .txt")
+    # keep only audio files, dropping metadata side-cars (.json/.txt/.csv)
+    lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']]
+    print(f"Filtered and kept {len(lines)} from .txt")
+    return lines
+
+
+def read_egs_files(path: tp.Union[str, Path]):
+    # an egs is either a data.jsonl[.gz] manifest or a directory containing one
+    path = Path(path)
+    if path.is_dir():
+        if (path / 'data.jsonl').exists():
+            path = path / 'data.jsonl'
+        elif (path / 'data.jsonl.gz').exists():
+            path = path / 'data.jsonl.gz'
+        else:
+            raise ValueError("Don't know where to read metadata from in the dir. "
+                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+    meta = load_audio_meta(path)
+    return [m.path for m in meta]
+
+
+def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None):
+    if task_index is None:
+        env = submitit.JobEnvironment()
+        task_index = env.global_rank
+    shard_index = node_index * args.tasks_per_node + task_index
+
+    if args.files_path is None:
+        lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]
+    else:
+        files_path = Path(args.files_path)
+        if files_path.suffix == '.txt':
+            print(f"Reading file list from .txt file: {args.files_path}")
+            lines = read_txt_files(args.files_path)
+        else:
+            print(f"Reading file list from egs: {args.files_path}")
+            lines = read_egs_files(args.files_path)
+
+    total_files = len(lines)
+    print(
+        f"Total of {total_files} processed with {n_shards} shards. 
" + + f"Current idx = {shard_index} -> {total_files // n_shards} files to process" + ) + for idx, line in tqdm.tqdm(enumerate(lines)): + + # skip if not part of this shard + if idx % n_shards != shard_index: + continue + + path = str(AudioCraftEnvironment.apply_dataset_mappers(line)) + root_path = str(args.root_path) + if not root_path.endswith('/'): + root_path += '/' + assert path.startswith(str(root_path)), \ + f"Mismatch between path and provided root: {path} VS {root_path}" + + try: + metadata_path = Path(path).with_suffix('.json') + out_path = args.out_path / path[len(root_path):] + out_metadata_path = out_path.with_suffix('.json') + out_done_token = out_path.with_suffix('.done') + + # don't reprocess existing files + if out_done_token.exists(): + continue + + print(idx, out_path, path) + mix, sr = audio_read(path) + mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0) + # enforce simple stereo + out_channels = mix_channels + if out_channels > 2: + print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels") + out_channels = 2 + out_sr = args.sample_rate if args.sample_rate is not None else sr + out_wav = convert_audio(mix, sr, out_sr, out_channels) + audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr, + format=args.format, normalize=False, strategy='clip') + if metadata_path.exists(): + shutil.copy(metadata_path, out_metadata_path) + else: + print(f"No metadata found at {str(metadata_path)}") + out_done_token.touch() + except Exception as e: + print(f"Error processing file line: {line}, {e}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Resample dataset with SLURM.") + parser.add_argument( + "--log_root", + type=Path, + default=Path.home() / 'tmp' / 'resample_logs', + ) + parser.add_argument( + "--files_path", + type=Path, + help="List of files to process, either .txt (one file per line) or a jsonl[.gz].", + ) + parser.add_argument( + "--root_path", + type=Path, + required=True, + help="When rewriting paths, this will be the prefix to remove.", + ) + parser.add_argument( + "--out_path", + type=Path, + required=True, + help="When rewriting paths, `root_path` will be replaced by this.", + ) + parser.add_argument("--xp_name", type=str, default="shutterstock") + parser.add_argument( + "--nodes", + type=int, + default=4, + ) + parser.add_argument( + "--tasks_per_node", + type=int, + default=20, + ) + parser.add_argument( + "--cpus_per_task", + type=int, + default=4, + ) + parser.add_argument( + "--memory_gb", + type=int, + help="Memory in GB." 
+ ) + parser.add_argument( + "--format", + type=str, + default="wav", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=32000, + ) + parser.add_argument( + "--channels", + type=int, + ) + parser.add_argument( + "--partition", + default='learnfair', + ) + parser.add_argument("--qos") + parser.add_argument("--account") + parser.add_argument("--timeout", type=int, default=4320) + parser.add_argument('--debug', action='store_true', help='debug mode (local run)') + args = parser.parse_args() + n_shards = args.tasks_per_node * args.nodes + if args.files_path is None: + print("Warning: --files_path not provided, not recommended when processing more than 10k files.") + if args.debug: + print("Debugging mode") + process_dataset(args, n_shards=n_shards, node_index=0, task_index=0) + else: + + log_folder = Path(args.log_root) / args.xp_name / '%j' + print(f"Logging to: {log_folder}") + log_folder.parent.mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=str(log_folder)) + if args.qos: + executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos, slurm_account=args.account) + else: + executor.update_parameters(slurm_partition=args.partition) + executor.update_parameters( + slurm_job_name=args.xp_name, timeout_min=args.timeout, + cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1) + if args.memory_gb: + executor.update_parameters(mem=f'{args.memory_gb}GB') + jobs = [] + with executor.batch(): + for node_index in range(args.nodes): + job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index) + jobs.append(job) + for job in jobs: + print(f"Waiting on job {job.job_id}") + job.results() diff --git a/scripts/static/style.css b/scripts/static/style.css new file mode 100644 index 0000000000000000000000000000000000000000..a0df7c63a0d2dd9a79f33f5d869ca31c9da87e8d --- /dev/null +++ b/scripts/static/style.css @@ -0,0 +1,113 @@ +body { + background-color: #fbfbfb; + margin: 0; +} + +select, input { + font-size: 1em; + max-width: 100%; +} + +.xp_name { + font-family: monospace; +} + +.simple_form { + background-color: #dddddd; + padding: 1em; + margin: 0.5em; +} + +textarea { + margin-top: 0.5em; + margin-bottom: 0.5em; +} + +.rating { + background-color: grey; + padding-top: 5px; + padding-bottom: 5px; + padding-left: 8px; + padding-right: 8px; + margin-right: 2px; + cursor:pointer; +} + +.rating_selected { + background-color: purple; +} + +.content { + font-family: sans-serif; + background-color: #f6f6f6; + padding: 40px; + margin: 0 auto; + max-width: 1000px; +} + +.track label { + padding-top: 10px; + padding-bottom: 10px; +} +.track { + padding: 15px; + margin: 5px; + background-color: #c8c8c8; +} + +.submit-big { + width:400px; + height:30px; + font-size: 20px; +} + +.error { + color: red; +} + +.ratings { + margin-left: 10px; +} + +.important { + font-weight: bold; +} + +.survey { + margin-bottom: 100px; +} + +.success { + color: #25901b; + font-weight: bold; +} +.warning { + color: #8a1f19; + font-weight: bold; +} +.track>section { + display: flex; + align-items: center; +} + +.prompt { + display: flex; + align-items: center; +} + +.track>section>div { + padding-left: 10px; +} + +audio { + max-width: 280px; + max-height: 40px; + margin-left: 10px; + margin-right: 10px; +} + +.special { + font-weight: bold; + color: #2c2c2c; +} + diff --git a/scripts/templates/base.html b/scripts/templates/base.html new file mode 100644 index 
0000000000000000000000000000000000000000..f74668c19ecb83090a8a2d82c026bf417190ec6d --- /dev/null +++ b/scripts/templates/base.html @@ -0,0 +1,16 @@ + + + + {% block head %} + + + AudioCraft — MOS + {% endblock %} + + +
+

AudioCraft — MOS

+ {% block content %}{% endblock %} +
+ + diff --git a/scripts/templates/index.html b/scripts/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..7bd3afe9d933271bb922c1a0a534dd6b86fe67bc --- /dev/null +++ b/scripts/templates/index.html @@ -0,0 +1,28 @@ +{% extends "base.html" %} +{% block content %} + +

+ Welcome {{session['user']}} to the internal MOS assistant for AudioCraft. + You can create custom surveys comparing your models, which you can + evaluate yourself or with the help of your teammates, by simply + sharing a link! +

+ +{% for error in errors %} +

{{error}}

+{% endfor %} +
+
+
+ +
+
+ +
+ + + +{% endblock %} diff --git a/scripts/templates/login.html b/scripts/templates/login.html new file mode 100644 index 0000000000000000000000000000000000000000..dd89ac654bceca14a9dec7d1a7f8206d1425a7a1 --- /dev/null +++ b/scripts/templates/login.html @@ -0,0 +1,20 @@ +{% extends "base.html" %} +{% block content %} + +

+ You must identify yourself first! We use a highly secure protocol + where you just decide your username, and that's it. No password, no encryption, + just pure trust. +

+ +{% if error %} +

{{error}}

+{% endif %} + + + + + +{% endblock %} diff --git a/scripts/templates/results.html b/scripts/templates/results.html new file mode 100644 index 0000000000000000000000000000000000000000..8ddce59f0f617a836db75c8bc9768db7f9f17511 --- /dev/null +++ b/scripts/templates/results.html @@ -0,0 +1,17 @@ +{% extends "base.html" %} +{% block content %} + +
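The login flow described above is deliberately minimal: the user picks a username, it is stored in the Flask session, and protected views (such as the results page, which is decorated with ensure_logged) redirect anonymous visitors back to the login form. A guard like that could be wired roughly as in the sketch below; this is an assumption-laden illustration, and the endpoint names and session handling are not taken from the actual app code in this diff.

# Minimal sketch of a session-based "login" guard, assuming Flask.
# The decorator name mirrors `ensure_logged` used by the results view;
# the `login` endpoint and the session key are illustrative assumptions.
from functools import wraps

from flask import Flask, redirect, request, session, url_for

app = Flask(__name__)
app.secret_key = "change-me"  # required for session cookies


def ensure_logged(view):
    @wraps(view)
    def wrapper(*args, **kwargs):
        if "user" not in session:
            return redirect(url_for("login"))
        return view(*args, **kwargs)
    return wrapper


@app.route("/login", methods=["GET", "POST"])
def login():
    if request.method == "POST":
        # No password: the username typed in the form is trusted as-is.
        session["user"] = request.form["user"]
        return redirect(url_for("index"))
    return "<form method=post><input name=user><button>Go</button></form>"


@app.route("/")
@ensure_logged
def index():
    return f"Welcome {session['user']}"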

Results for survey #{{signature}}

+

Check out the survey page for details on the models.

+

The following users voted: + {% for user in users %} + {{user}} + {% endfor %} + +{% for model in models %} +

{{model['sig']}} ({{model['samples']}} samples)

+

Ratings: {{model['mean_rating']}} ± {{model['std_rating']}}

+ +{% endfor %} + +{% endblock %} diff --git a/scripts/templates/survey.html b/scripts/templates/survey.html new file mode 100644 index 0000000000000000000000000000000000000000..785d1e61b7ac21619416ba70dd4719ff250f3f4b --- /dev/null +++ b/scripts/templates/survey.html @@ -0,0 +1,131 @@ +{% extends "base.html" %} +{% block content %} +
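For reference, the ± {{model['std_rating']}} term shown on the results page is not the raw standard deviation: the results view computes 1.96 * std / sqrt(n), i.e. an approximate 95% confidence interval for the mean rating under a normal approximation (1.96 is the z-value that leaves 2.5% in each tail). A small worked example with made-up ratings:

import math

# Hypothetical MOS ratings for one model (1-5 scale), not from any real survey.
ratings = [4, 5, 3, 4, 4, 5, 3, 4]

n = len(ratings)
mean = sum(ratings) / n
# Population standard deviation, matching the `std` helper in the results view.
std = math.sqrt(sum((r - mean) ** 2 for r in ratings) / n)

# Half-width of the approximate 95% confidence interval for the mean rating.
half_width = 1.96 * std / math.sqrt(n)
print(f"{mean:.2f} +/- {half_width:.2f}")  # 4.00 +/- 0.49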

Survey #{{signature}}

+{% if success %} +

Your ratings have been saved! +You have been moved to the next random seed in case you want +to keep rating more samples.

+{% endif %} +{% if already_filled %} +

You already rated those samples in the past; + filling this form will override your previous ratings. +

+{% endif %} +

Welcome {{session['user']}} to the survey #{{signature}}. +Go to the result page to check the results. Go to the home page to start a new survey. +

+ +{% for error in errors %} +

{{error}}

+{% endfor %} + +{% if not blind %} +

Base config is: {{ref_name}}

+

The following experiments are compared:

+
    + {% for experiment in experiments %} +
  • {{experiment.xp.sig}} ({{experiment.epoch}} epochs): {{experiment.name}}
  • + {% endfor %} +
+{% else %} +

This is a blind experiment; the order of all XPs is shuffled with every sample.

+{% endif %} +

The current random seed is {{seed}}. You can change it with the following form, and also update blind/non-blind. +

+ + + + + + + +
+ +

Samples

+
+
+{% for id in model_ids %} +
+

{{id}}

+ {% for model in models_by_id[id] %} + {% if loop.index == 1 and model.is_prompted %} +
+

Prompt is

+ +

Ground truth is

+ +
+ {% endif %} + {% for err in model['errors'] %} +

{{err}}

+ {% endfor %} +
+ {% if not blind %} +

{{model.xp.sig}}:

+ {% endif %} + +

Rating:

+
+ {% for rating in ratings %} + {{rating}} + {% endfor %} + +
+

+
+ {% endfor %} +
+
+{% endfor %} + + +
+ +{% endblock %} diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..a00890009a88752714357210a73709a83b395849 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,14 @@ +[pep8] +max-line-length = 120 + +[flake8] +max-line-length = 120 + +[coverage:report] +include = audiocraft/* +omit = + audiocraft/environment.py + audiocraft/solvers/* + audiocraft/utils/* + audiocraft/*/loaders.py + audiocraft/*/builders.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..64e7d6fcb1092748f8151f6d3ed1767d3be1b34b --- /dev/null +++ b/setup.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +from setuptools import setup, find_packages + + +NAME = 'audiocraft' +DESCRIPTION = 'Audio generation research library for PyTorch' + +URL = 'https://github.com/facebookresearch/audiocraft' +AUTHOR = 'FAIR Speech & Audio' +EMAIL = 'defossez@meta.com, jadecopet@meta.com' +REQUIRES_PYTHON = '>=3.8.0' + +for line in open('audiocraft/__init__.py'): + line = line.strip() + if '__version__' in line: + context = {} + exec(line, context) + VERSION = context['__version__'] + +HERE = Path(__file__).parent + +try: + with open(HERE / "README.md", encoding='utf-8') as f: + long_description = '\n' + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +REQUIRED = [i.strip() for i in open(HERE / 'requirements.txt') if not i.startswith('#')] + +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + author_email=EMAIL, + long_description=long_description, + long_description_content_type='text/markdown', + author=AUTHOR, + url=URL, + python_requires=REQUIRES_PYTHON, + install_requires=REQUIRED, + extras_require={ + 'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'], + }, + packages=find_packages(), + package_data={'audiocraft': ['py.typed']}, + include_package_data=True, + license='MIT License', + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + 'License :: OSI Approved :: MIT License', + 'Topic :: Multimedia :: Sound/Audio', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/adversarial/__init__.py b/tests/adversarial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/adversarial/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
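A small detail of setup.py above: the version is obtained by scanning audiocraft/__init__.py for the __version__ line and exec-ing only that line, which avoids importing the package (and its heavy dependencies) at install time. A stripped-down sketch of the same pattern, with mypkg as a placeholder package name:

# Sketch of the "read __version__ without importing the package" pattern
# used by setup.py above. `mypkg` is a hypothetical package name.
from pathlib import Path


def read_version(init_file: str = "mypkg/__init__.py") -> str:
    for line in Path(init_file).read_text().splitlines():
        line = line.strip()
        if line.startswith("__version__"):
            context: dict = {}
            exec(line, context)  # executes e.g. __version__ = "1.0.0"
            return context["__version__"]
    raise RuntimeError(f"No __version__ found in {init_file}")

With the dev extra declared above, a local development install would typically look like pip install -e '.[dev]'.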
diff --git a/tests/adversarial/test_discriminators.py b/tests/adversarial/test_discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..fad89a0ae4534dc7967b6ccda194b9fd1dedbffe --- /dev/null +++ b/tests/adversarial/test_discriminators.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import torch + +from audiocraft.adversarial.discriminators import ( + MultiPeriodDiscriminator, + MultiScaleDiscriminator, + MultiScaleSTFTDiscriminator +) + + +class TestMultiPeriodDiscriminator: + + def test_mpd_discriminator(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + periods = [1, 2, 3] + mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C) + logits, fmaps = mpd(t0) + + assert len(logits) == len(periods) + assert len(fmaps) == len(periods) + assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits]) + assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) + + +class TestMultiScaleDiscriminator: + + def test_msd_discriminator(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + scale_norms = ['weight_norm', 'weight_norm'] + msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C) + logits, fmaps = msd(t0) + + assert len(logits) == len(scale_norms) + assert len(fmaps) == len(scale_norms) + assert all([logit.shape[0] == N and len(logit.shape) == 3 for logit in logits]) + assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) + + +class TestMultiScaleStftDiscriminator: + + def test_msstftd_discriminator(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + n_filters = 4 + n_ffts = [128, 256, 64] + hop_lengths = [32, 64, 16] + win_lengths = [128, 256, 64] + + msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths, + win_lengths=win_lengths, in_channels=C) + logits, fmaps = msstftd(t0) + + assert len(logits) == len(n_ffts) + assert len(fmaps) == len(n_ffts) + assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits]) + assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) diff --git a/tests/adversarial/test_losses.py b/tests/adversarial/test_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..0e30bc3a6dde00003e13c00f15e977e39425063c --- /dev/null +++ b/tests/adversarial/test_losses.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
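All three discriminator tests above exercise the same calling convention: the wrapper returns one list of logits and one parallel list of feature maps, with one entry per period, scale, or STFT resolution, and every tensor keeps the batch dimension. A minimal stand-in that satisfies that contract can make the asserted shapes easier to reason about; this is an illustrative sketch, not the AudioCraft discriminators themselves.

import torch
from torch import nn


class ToyMultiDiscriminator(nn.Module):
    """Illustrative stand-in for the (logits, feature_maps) interface the
    tests above assert against; not the AudioCraft implementation."""

    def __init__(self, in_channels: int = 2, num_subs: int = 3):
        super().__init__()
        self.subs = nn.ModuleList(
            [nn.Conv1d(in_channels, 1, kernel_size=15, stride=4, padding=7)
             for _ in range(num_subs)]
        )

    def forward(self, x: torch.Tensor):
        logits, fmaps = [], []
        for conv in self.subs:
            feat = conv(x)        # [N, 1, T'] intermediate feature map
            fmaps.append([feat])  # one list of feature maps per sub-discriminator
            logits.append(feat)   # per-sub-discriminator logits keep the batch dim
        return logits, fmaps


N, C, T = 2, 2, 4000
logits, fmaps = ToyMultiDiscriminator(in_channels=C)(torch.randn(N, C, T))
assert len(logits) == 3 and all(lg.shape[0] == N for lg in logits)
assert all(f.shape[0] == N for fmap in fmaps for f in fmap)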
+ +import pytest +import random + +import torch + +from audiocraft.adversarial import ( + AdversarialLoss, + get_adv_criterion, + get_real_criterion, + get_fake_criterion, + FeatureMatchingLoss, + MultiScaleDiscriminator, +) + + +class TestAdversarialLoss: + + def test_adversarial_single_multidiscriminator(self): + adv = MultiScaleDiscriminator() + optimizer = torch.optim.Adam( + adv.parameters(), + lr=1e-4, + ) + loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse') + adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake) + + B, C, T = 4, 1, random.randint(1000, 5000) + real = torch.randn(B, C, T) + fake = torch.randn(B, C, T) + + disc_loss = adv_loss.train_adv(fake, real) + assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float) + + loss, loss_feat = adv_loss(fake, real) + assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float) + # we did not specify feature loss + assert loss_feat.item() == 0. + + def test_adversarial_feat_loss(self): + adv = MultiScaleDiscriminator() + optimizer = torch.optim.Adam( + adv.parameters(), + lr=1e-4, + ) + loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse') + feat_loss = FeatureMatchingLoss() + adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake, feat_loss) + + B, C, T = 4, 1, random.randint(1000, 5000) + real = torch.randn(B, C, T) + fake = torch.randn(B, C, T) + + loss, loss_feat = adv_loss(fake, real) + + assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float) + assert isinstance(loss_feat, torch.Tensor) and isinstance(loss.item(), float) + + +class TestGeneratorAdversarialLoss: + + def test_hinge_generator_adv_loss(self): + adv_loss = get_adv_criterion(loss_type='hinge') + + t0 = torch.randn(1, 2, 0) + t1 = torch.FloatTensor([1.0, 2.0, 3.0]) + + assert adv_loss(t0).item() == 0.0 + assert adv_loss(t1).item() == -2.0 + + def test_mse_generator_adv_loss(self): + adv_loss = get_adv_criterion(loss_type='mse') + + t0 = torch.randn(1, 2, 0) + t1 = torch.FloatTensor([1.0, 1.0, 1.0]) + t2 = torch.FloatTensor([2.0, 5.0, 5.0]) + + assert adv_loss(t0).item() == 0.0 + assert adv_loss(t1).item() == 0.0 + assert adv_loss(t2).item() == 11.0 + + +class TestDiscriminatorAdversarialLoss: + + def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor): + disc_loss_real = get_real_criterion(loss_type) + disc_loss_fake = get_fake_criterion(loss_type) + + loss = disc_loss_fake(fake) + disc_loss_real(real) + return loss + + def test_hinge_discriminator_adv_loss(self): + loss_type = 'hinge' + t0 = torch.FloatTensor([0.0, 0.0, 0.0]) + t1 = torch.FloatTensor([1.0, 2.0, 3.0]) + + assert self._disc_loss(loss_type, t0, t0).item() == 2.0 + assert self._disc_loss(loss_type, t1, t1).item() == 3.0 + + def test_mse_discriminator_adv_loss(self): + loss_type = 'mse' + + t0 = torch.FloatTensor([0.0, 0.0, 0.0]) + t1 = torch.FloatTensor([1.0, 1.0, 1.0]) + + assert self._disc_loss(loss_type, t0, t0).item() == 1.0 + assert self._disc_loss(loss_type, t1, t0).item() == 2.0 + + +class TestFeatureMatchingLoss: + + def test_features_matching_loss_base(self): + ft_matching_loss = FeatureMatchingLoss() + length = random.randrange(1, 100_000) + t1 = torch.randn(1, 2, length) + + loss = ft_matching_loss([t1], [t1]) + assert isinstance(loss, torch.Tensor) + assert loss.item() == 0.0 + + def test_features_matching_loss_raises_exception(self): + ft_matching_loss = FeatureMatchingLoss() + 
length = random.randrange(1, 100_000) + t1 = torch.randn(1, 2, length) + t2 = torch.randn(1, 2, length + 1) + + with pytest.raises(AssertionError): + ft_matching_loss([], []) + + with pytest.raises(AssertionError): + ft_matching_loss([t1], [t1, t1]) + + with pytest.raises(AssertionError): + ft_matching_loss([t1], [t2]) + + def test_features_matching_loss_output(self): + loss_nonorm = FeatureMatchingLoss(normalize=False) + loss_layer_normed = FeatureMatchingLoss(normalize=True) + + length = random.randrange(1, 100_000) + t1 = torch.randn(1, 2, length) + t2 = torch.randn(1, 2, length) + + assert loss_nonorm([t1, t2], [t1, t2]).item() == 0.0 + assert loss_layer_normed([t1, t2], [t1, t2]).item() == 0.0 + + t3 = torch.FloatTensor([1.0, 2.0, 3.0]) + t4 = torch.FloatTensor([2.0, 10.0, 3.0]) + + assert loss_nonorm([t3], [t4]).item() == 3.0 + assert loss_nonorm([t3, t3], [t4, t4]).item() == 6.0 + + assert loss_layer_normed([t3], [t4]).item() == 3.0 + assert loss_layer_normed([t3, t3], [t4, t4]).item() == 3.0 diff --git a/tests/common_utils/__init__.py b/tests/common_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74ffcfef96fec35c99b2a1a053a61f44f7a8bbe9 --- /dev/null +++ b/tests/common_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .temp_utils import TempDirMixin +from .wav_utils import get_batch_white_noise, get_white_noise, save_wav diff --git a/tests/common_utils/temp_utils.py b/tests/common_utils/temp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d896836799edcf1fee271409b390b3b6e4127 --- /dev/null +++ b/tests/common_utils/temp_utils.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import tempfile + + +class TempDirMixin: + """Mixin to provide easy access to temp dir. + """ + + temp_dir_ = None + + @classmethod + def get_base_temp_dir(cls): + # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory. + # this is handy for debugging. + key = "AUDIOCRAFT_TEST_DIR" + if key in os.environ: + return os.environ[key] + if cls.temp_dir_ is None: + cls.temp_dir_ = tempfile.TemporaryDirectory() + return cls.temp_dir_.name + + @classmethod + def tearDownClass(cls): + if cls.temp_dir_ is not None: + try: + cls.temp_dir_.cleanup() + cls.temp_dir_ = None + except PermissionError: + # On Windows there is a know issue with `shutil.rmtree`, + # which fails intermittently. + # https://github.com/python/cpython/issues/74168 + # Following the above thread, we ignore it. 
+ pass + super().tearDownClass() + + @property + def id(self): + return self.__class__.__name__ + + def get_temp_path(self, *paths): + temp_dir = os.path.join(self.get_base_temp_dir(), self.id) + path = os.path.join(temp_dir, *paths) + os.makedirs(os.path.dirname(path), exist_ok=True) + return path + + def get_temp_dir(self, *paths): + temp_dir = os.path.join(self.get_base_temp_dir(), self.id) + path = os.path.join(temp_dir, *paths) + os.makedirs(path, exist_ok=True) + return path diff --git a/tests/common_utils/wav_utils.py b/tests/common_utils/wav_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a563ee1749a58217ece55c9a08b8d93c0fc386 --- /dev/null +++ b/tests/common_utils/wav_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +import typing as tp + +import torch +import torchaudio + + +def get_white_noise(chs: int = 1, num_frames: int = 1): + wav = torch.randn(chs, num_frames) + return wav + + +def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): + wav = torch.randn(bs, chs, num_frames) + return wav + + +def save_wav(path: str, wav: torch.Tensor, sample_rate: int): + fp = Path(path) + kwargs: tp.Dict[str, tp.Any] = {} + if fp.suffix == '.wav': + kwargs['encoding'] = 'PCM_S' + kwargs['bits_per_sample'] = 16 + elif fp.suffix == '.mp3': + kwargs['compression'] = 320 + torchaudio.save(str(fp), wav, sample_rate, **kwargs) diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/data/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/data/test_audio.py b/tests/data/test_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..40c0d5ed69eff92a766dc6d176e532f0df6c2b5e --- /dev/null +++ b/tests/data/test_audio.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product +import random + +import numpy as np +import torch +import torchaudio + +from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read + +from ..common_utils import TempDirMixin, get_white_noise, save_wav + + +class TestInfo(TempDirMixin): + + def test_info_mp3(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + for sample_rate, ch in product(sample_rates, channels): + wav = get_white_noise(ch, int(sample_rate * duration)) + path = self.get_temp_path('sample_wav.mp3') + save_wav(path, wav, sample_rate) + info = audio_info(path) + assert info.sample_rate == sample_rate + assert info.channels == ch + # we cannot trust torchaudio for num_frames, so we don't check + + def _test_info_format(self, ext: str): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. 
+ for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + wav = get_white_noise(ch, n_frames) + path = self.get_temp_path(f'sample_wav{ext}') + save_wav(path, wav, sample_rate) + info = audio_info(path) + assert info.sample_rate == sample_rate + assert info.channels == ch + assert np.isclose(info.duration, duration, atol=1e-5) + + def test_info_wav(self): + self._test_info_format('.wav') + + def test_info_flac(self): + self._test_info_format('.flac') + + def test_info_ogg(self): + self._test_info_format('.ogg') + + def test_info_m4a(self): + # TODO: generate m4a file programmatically + # self._test_info_format('.m4a') + pass + + +class TestRead(TempDirMixin): + + def test_read_full_wav(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) + path = self.get_temp_path('sample_wav.wav') + save_wav(path, wav, sample_rate) + read_wav, read_sr = audio_read(path) + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[1] == wav.shape[1] + assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04) + + def test_read_partial_wav(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + read_duration = torch.rand(1).item() + for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + read_frames = int(sample_rate * read_duration) + wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) + path = self.get_temp_path('sample_wav.wav') + save_wav(path, wav, sample_rate) + read_wav, read_sr = audio_read(path, 0, read_duration) + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[1] == read_frames + assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04) + + def test_read_seek_time_wav(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + read_duration = 1. + for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) + path = self.get_temp_path('sample_wav.wav') + save_wav(path, wav, sample_rate) + seek_time = torch.rand(1).item() + read_wav, read_sr = audio_read(path, seek_time, read_duration) + seek_frames = int(sample_rate * seek_time) + expected_frames = n_frames - seek_frames + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[1] == expected_frames + assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04) + + def test_read_seek_time_wav_padded(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + read_duration = 1. 
+ for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + read_frames = int(sample_rate * read_duration) + wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) + path = self.get_temp_path('sample_wav.wav') + save_wav(path, wav, sample_rate) + seek_time = torch.rand(1).item() + seek_frames = int(sample_rate * seek_time) + expected_frames = n_frames - seek_frames + read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True) + expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames) + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[1] == read_frames + assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04) + assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav) + + +class TestAvRead(TempDirMixin): + + def test_avread_seek_base(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 2. + for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + wav = get_white_noise(ch, n_frames) + path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav') + save_wav(path, wav, sample_rate) + for _ in range(100): + # seek will always load a full duration segment in the file + seek_time = random.uniform(0.0, 1.0) + seek_duration = random.uniform(0.001, 1.0) + read_wav, read_sr = _av_read(path, seek_time, seek_duration) + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[-1] == int(seek_duration * sample_rate) + + def test_avread_seek_partial(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + wav = get_white_noise(ch, n_frames) + path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav') + save_wav(path, wav, sample_rate) + for _ in range(100): + # seek will always load a partial segment + seek_time = random.uniform(0.5, 1.) + seek_duration = 1. + expected_num_frames = n_frames - int(seek_time * sample_rate) + read_wav, read_sr = _av_read(path, seek_time, seek_duration) + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[-1] == expected_num_frames + + def test_avread_seek_outofbound(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. + for sample_rate, ch in product(sample_rates, channels): + n_frames = int(sample_rate * duration) + wav = get_white_noise(ch, n_frames) + path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav') + save_wav(path, wav, sample_rate) + seek_time = 1.5 + read_wav, read_sr = _av_read(path, seek_time, 1.) 
+ assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[-1] == 0 + + def test_avread_seek_edge(self): + sample_rates = [8000, 16_000] + # some of these values will have + # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1) + n_frames = [1000, 1001, 1002] + channels = [1, 2] + for sample_rate, ch, frames in product(sample_rates, channels, n_frames): + duration = frames / sample_rate + wav = get_white_noise(ch, frames) + path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav') + save_wav(path, wav, sample_rate) + seek_time = (frames - 1) / sample_rate + seek_frames = int(seek_time * sample_rate) + read_wav, read_sr = _av_read(path, seek_time, duration) + assert read_sr == sample_rate + assert read_wav.shape[0] == wav.shape[0] + assert read_wav.shape[-1] == (frames - seek_frames) + + +class TestAudioWrite(TempDirMixin): + + def test_audio_write_wav(self): + torch.manual_seed(1234) + sample_rates = [8000, 16_000] + n_frames = [1000, 1001, 1002] + channels = [1, 2] + strategies = ["peak", "clip", "rms"] + formats = ["wav", "mp3"] + for sample_rate, ch, frames in product(sample_rates, channels, n_frames): + for format_, strategy in product(formats, strategies): + wav = get_white_noise(ch, frames) + path = self.get_temp_path(f'pred_{sample_rate}_{ch}') + audio_write(path, wav, sample_rate, format_, strategy=strategy) + read_wav, read_sr = torchaudio.load(f'{path}.{format_}') + if format_ == "wav": + assert read_wav.shape == wav.shape + + if format_ == "wav" and strategy in ["peak", "rms"]: + rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max() + # for a Gaussian, the typical max scale will be less than ~5x the std. + # The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that. + # For RMS target, rescaling leaves more headroom by default, leading + # to a 20x rescaling typically + atol = (5 if strategy == "peak" else 20) / 2**15 + delta = (rescaled_read_wav - wav).abs().max() + assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol) + formats = ["wav"] # faster unit tests diff --git a/tests/data/test_audio_dataset.py b/tests/data/test_audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b591ea6137f48d0d97fcd1243c5f5d258670a474 --- /dev/null +++ b/tests/data/test_audio_dataset.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from itertools import product +import json +import math +import os +import random +import typing as tp + +import pytest +import torch +from torch.utils.data import DataLoader + +from audiocraft.data.audio_dataset import ( + AudioDataset, + AudioMeta, + _get_audio_meta, + load_audio_meta, + save_audio_meta +) +from audiocraft.data.zip import PathInZip + +from ..common_utils import TempDirMixin, get_white_noise, save_wav + + +class TestAudioMeta(TempDirMixin): + + def test_get_audio_meta(self): + sample_rates = [8000, 16_000] + channels = [1, 2] + duration = 1. 
+ for sample_rate, ch in product(sample_rates, channels): + n_frames = int(duration * sample_rate) + wav = get_white_noise(ch, n_frames) + path = self.get_temp_path('sample.wav') + save_wav(path, wav, sample_rate) + m = _get_audio_meta(path, minimal=True) + assert m.path == path, 'path does not match' + assert m.sample_rate == sample_rate, 'sample rate does not match' + assert m.duration == duration, 'duration does not match' + assert m.amplitude is None + assert m.info_path is None + + def test_save_audio_meta(self): + audio_meta = [ + AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')), + AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json')) + ] + empty_audio_meta = [] + for idx, meta in enumerate([audio_meta, empty_audio_meta]): + path = self.get_temp_path(f'data_{idx}_save.jsonl') + save_audio_meta(path, meta) + with open(path, 'r') as f: + lines = f.readlines() + read_meta = [AudioMeta.from_dict(json.loads(line)) for line in lines] + assert len(read_meta) == len(meta) + for m, read_m in zip(meta, read_meta): + assert m == read_m + + def test_load_audio_meta(self): + try: + import dora + except ImportError: + dora = None # type: ignore + + audio_meta = [ + AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')), + AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json')) + ] + empty_meta = [] + for idx, meta in enumerate([audio_meta, empty_meta]): + path = self.get_temp_path(f'data_{idx}_load.jsonl') + with open(path, 'w') as f: + for m in meta: + json_str = json.dumps(m.to_dict()) + '\n' + f.write(json_str) + read_meta = load_audio_meta(path) + assert len(read_meta) == len(meta) + for m, read_m in zip(meta, read_meta): + if dora: + m.path = dora.git_save.to_absolute_path(m.path) + assert m == read_m, f'original={m}, read={read_m}' + + +class TestAudioDataset(TempDirMixin): + + def _create_audio_files(self, + root_name: str, + num_examples: int, + durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.), + sample_rate: int = 16_000, + channels: int = 1): + root_dir = self.get_temp_dir(root_name) + for i in range(num_examples): + if isinstance(durations, float): + duration = durations + elif isinstance(durations, tuple) and len(durations) == 1: + duration = durations[0] + elif isinstance(durations, tuple) and len(durations) == 2: + duration = random.uniform(durations[0], durations[1]) + else: + assert False + n_frames = int(duration * sample_rate) + wav = get_white_noise(channels, n_frames) + path = os.path.join(root_dir, f'example_{i}.wav') + save_wav(path, wav, sample_rate) + return root_dir + + def _create_audio_dataset(self, + root_name: str, + total_num_examples: int, + durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.), + sample_rate: int = 16_000, + channels: int = 1, + segment_duration: tp.Optional[float] = None, + num_examples: int = 10, + shuffle: bool = True, + return_info: bool = False): + root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels) + dataset = AudioDataset.from_path(root_dir, + minimal_meta=True, + segment_duration=segment_duration, + num_samples=num_examples, + sample_rate=sample_rate, + channels=channels, + shuffle=shuffle, + return_info=return_info) + return dataset + + def test_dataset_full(self): + total_examples = 10 + min_duration, max_duration = 1., 4. 
+ sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=(min_duration, max_duration), + sample_rate=sample_rate, channels=channels, segment_duration=None) + assert len(dataset) == total_examples + assert dataset.sample_rate == sample_rate + assert dataset.channels == channels + for idx in range(len(dataset)): + sample = dataset[idx] + assert sample.shape[0] == channels + assert sample.shape[1] <= int(max_duration * sample_rate) + assert sample.shape[1] >= int(min_duration * sample_rate) + + def test_dataset_segment(self): + total_examples = 10 + num_samples = 20 + min_duration, max_duration = 1., 4. + segment_duration = 1. + sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples) + assert len(dataset) == num_samples + assert dataset.sample_rate == sample_rate + assert dataset.channels == channels + for idx in range(len(dataset)): + sample = dataset[idx] + assert sample.shape[0] == channels + assert sample.shape[1] == int(segment_duration * sample_rate) + + def test_dataset_equal_audio_and_segment_durations(self): + total_examples = 1 + num_samples = 2 + audio_duration = 1. + segment_duration = 1. + sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=audio_duration, sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples) + assert len(dataset) == num_samples + assert dataset.sample_rate == sample_rate + assert dataset.channels == channels + for idx in range(len(dataset)): + sample = dataset[idx] + assert sample.shape[0] == channels + assert sample.shape[1] == int(segment_duration * sample_rate) + # the random seek_time adds variability on audio read + sample_1 = dataset[0] + sample_2 = dataset[1] + assert not torch.allclose(sample_1, sample_2) + + def test_dataset_samples(self): + total_examples = 1 + num_samples = 2 + audio_duration = 1. + segment_duration = 1. + sample_rate = 16_000 + channels = 1 + + create_dataset = partial( + self._create_audio_dataset, + 'dset', total_examples, durations=audio_duration, sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples, + ) + + dataset = create_dataset(shuffle=True) + # when shuffle = True, we have different inputs for the same index across epoch + sample_1 = dataset[0] + sample_2 = dataset[0] + assert not torch.allclose(sample_1, sample_2) + + dataset_noshuffle = create_dataset(shuffle=False) + # when shuffle = False, we have same inputs for the same index across epoch + sample_1 = dataset_noshuffle[0] + sample_2 = dataset_noshuffle[0] + assert torch.allclose(sample_1, sample_2) + + def test_dataset_return_info(self): + total_examples = 10 + num_samples = 20 + min_duration, max_duration = 1., 4. + segment_duration = 1. 
+ sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True) + assert len(dataset) == num_samples + assert dataset.sample_rate == sample_rate + assert dataset.channels == channels + for idx in range(len(dataset)): + sample, segment_info = dataset[idx] + assert sample.shape[0] == channels + assert sample.shape[1] == int(segment_duration * sample_rate) + assert segment_info.sample_rate == sample_rate + assert segment_info.total_frames == int(segment_duration * sample_rate) + assert segment_info.n_frames <= int(segment_duration * sample_rate) + assert segment_info.seek_time >= 0 + + def test_dataset_return_info_no_segment_duration(self): + total_examples = 10 + num_samples = 20 + min_duration, max_duration = 1., 4. + segment_duration = None + sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True) + assert len(dataset) == total_examples + assert dataset.sample_rate == sample_rate + assert dataset.channels == channels + for idx in range(len(dataset)): + sample, segment_info = dataset[idx] + assert sample.shape[0] == channels + assert sample.shape[1] == segment_info.total_frames + assert segment_info.sample_rate == sample_rate + assert segment_info.n_frames <= segment_info.total_frames + + def test_dataset_collate_fn(self): + total_examples = 10 + num_samples = 20 + min_duration, max_duration = 1., 4. + segment_duration = 1. + sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False) + batch_size = 4 + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=0 + ) + for idx, batch in enumerate(dataloader): + assert batch.shape[0] == batch_size + + @pytest.mark.parametrize("segment_duration", [1.0, None]) + def test_dataset_with_meta_collate_fn(self, segment_duration): + total_examples = 10 + num_samples = 20 + min_duration, max_duration = 1., 4. + segment_duration = 1. + sample_rate = 16_000 + channels = 1 + dataset = self._create_audio_dataset( + 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, + channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True) + batch_size = 4 + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=dataset.collater, + num_workers=0 + ) + for idx, batch in enumerate(dataloader): + wav, infos = batch + assert wav.shape[0] == batch_size + assert len(infos) == batch_size + + @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [ + [1, True, True, 0.5, 0.5, 0.0], + [1, False, True, 0.25, 0.5, 0.25], + [1, True, False, 0.666, 0.333, 0.0], + [1, False, False, 0.333, 0.333, 0.333], + [None, False, False, 0.333, 0.333, 0.333]]) + def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist): + random.seed(1234) + rng = torch.Generator() + rng.manual_seed(1234) + + def _get_histogram(dataset, repetitions=20_000): + counts = {file_meta.path: 0. 
for file_meta in meta} + for _ in range(repetitions): + file_meta = dataset.sample_file(0, rng) + counts[file_meta.path] += 1 + return {name: count / repetitions for name, count in counts.items()} + + meta = [ + AudioMeta(path='a', duration=5, sample_rate=1, weight=2), + AudioMeta(path='b', duration=10, sample_rate=1, weight=None), + AudioMeta(path='c', duration=5, sample_rate=1, weight=0), + ] + dataset = AudioDataset( + meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight, + sample_on_duration=sample_on_duration) + hist = _get_histogram(dataset) + assert math.isclose(hist['a'], a_hist, abs_tol=0.01) + assert math.isclose(hist['b'], b_hist, abs_tol=0.01) + assert math.isclose(hist['c'], c_hist, abs_tol=0.01) + + def test_meta_duration_filter_all(self): + meta = [ + AudioMeta(path='a', duration=5, sample_rate=1, weight=2), + AudioMeta(path='b', duration=10, sample_rate=1, weight=None), + AudioMeta(path='c', duration=5, sample_rate=1, weight=0), + ] + try: + AudioDataset(meta, segment_duration=11, min_segment_ratio=1) + assert False + except AssertionError: + assert True + + def test_meta_duration_filter_long(self): + meta = [ + AudioMeta(path='a', duration=5, sample_rate=1, weight=2), + AudioMeta(path='b', duration=10, sample_rate=1, weight=None), + AudioMeta(path='c', duration=5, sample_rate=1, weight=0), + ] + dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7) + assert len(dataset) == 2 diff --git a/tests/data/test_audio_utils.py b/tests/data/test_audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0480671bb17281d61ce02bce6373a5ccec89fece --- /dev/null +++ b/tests/data/test_audio_utils.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import julius +import torch +import pytest + +from audiocraft.data.audio_utils import ( + _clip_wav, + convert_audio_channels, + convert_audio, + normalize_audio +) +from ..common_utils import get_batch_white_noise + + +class TestConvertAudioChannels: + + def test_convert_audio_channels_downmix(self): + b, c, t = 2, 3, 100 + audio = get_batch_white_noise(b, c, t) + mixed = convert_audio_channels(audio, channels=2) + assert list(mixed.shape) == [b, 2, t] + + def test_convert_audio_channels_nochange(self): + b, c, t = 2, 3, 100 + audio = get_batch_white_noise(b, c, t) + mixed = convert_audio_channels(audio, channels=c) + assert list(mixed.shape) == list(audio.shape) + + def test_convert_audio_channels_upmix(self): + b, c, t = 2, 1, 100 + audio = get_batch_white_noise(b, c, t) + mixed = convert_audio_channels(audio, channels=3) + assert list(mixed.shape) == [b, 3, t] + + def test_convert_audio_channels_upmix_error(self): + b, c, t = 2, 2, 100 + audio = get_batch_white_noise(b, c, t) + with pytest.raises(ValueError): + convert_audio_channels(audio, channels=3) + + +class TestConvertAudio: + + def test_convert_audio_channels_downmix(self): + b, c, dur = 2, 3, 4. + sr = 128 + audio = get_batch_white_noise(b, c, int(sr * dur)) + out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2) + assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]] + + def test_convert_audio_channels_upmix(self): + b, c, dur = 2, 1, 4. 
+ sr = 128 + audio = get_batch_white_noise(b, c, int(sr * dur)) + out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3) + assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]] + + def test_convert_audio_upsample(self): + b, c, dur = 2, 1, 4. + sr = 2 + new_sr = 3 + audio = get_batch_white_noise(b, c, int(sr * dur)) + out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) + out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) + assert torch.allclose(out, out_j) + + def test_convert_audio_resample(self): + b, c, dur = 2, 1, 4. + sr = 3 + new_sr = 2 + audio = get_batch_white_noise(b, c, int(sr * dur)) + out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) + out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) + assert torch.allclose(out, out_j) + + +class TestNormalizeAudio: + + def test_clip_wav(self): + b, c, dur = 2, 1, 4. + sr = 3 + audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) + _clip_wav(audio) + assert audio.abs().max() <= 1 + + def test_normalize_audio_clip(self): + b, c, dur = 2, 1, 4. + sr = 3 + audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) + norm_audio = normalize_audio(audio, strategy='clip') + assert norm_audio.abs().max() <= 1 + + def test_normalize_audio_rms(self): + b, c, dur = 2, 1, 4. + sr = 3 + audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) + norm_audio = normalize_audio(audio, strategy='rms') + assert norm_audio.abs().max() <= 1 + + def test_normalize_audio_peak(self): + b, c, dur = 2, 1, 4. + sr = 3 + audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) + norm_audio = normalize_audio(audio, strategy='peak') + assert norm_audio.abs().max() <= 1 diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/losses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..b6681e12c453dea5aeba738ab252d1923b7e0941 --- /dev/null +++ b/tests/losses/test_losses.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
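The tests above pin down the convert_audio contract from audiocraft.data.audio_utils: channel count and sample rate are converted in a single call, and the resampling path matches julius.resample_frac. A short usage sketch under those assumptions; the 16 kHz stereo input and the target settings are made up for illustration.

import torch

from audiocraft.data.audio_utils import convert_audio

# Hypothetical batch: 2 stereo clips of 1 second at 16 kHz.
wav = torch.randn(2, 2, 16_000)

# Down-mix to mono and upsample to 32 kHz in one call, using the keyword
# arguments exercised by the tests above.
out = convert_audio(wav, from_rate=16_000, to_rate=32_000, to_channels=1)
print(out.shape)  # expected: torch.Size([2, 1, 32000])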
+ +import random + +import torch + +from audiocraft.losses import ( + MelSpectrogramL1Loss, + MultiScaleMelSpectrogramLoss, + MRSTFTLoss, + SISNR, + STFTLoss, +) + + +def test_mel_l1_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050) + loss = mel_l1(t1, t2) + loss_same = mel_l1(t1, t1) + + assert isinstance(loss, torch.Tensor) + assert isinstance(loss_same, torch.Tensor) + assert loss_same.item() == 0.0 + + +def test_msspec_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050) + loss = msspec(t1, t2) + loss_same = msspec(t1, t1) + + assert isinstance(loss, torch.Tensor) + assert isinstance(loss_same, torch.Tensor) + assert loss_same.item() == 0.0 + + +def test_mrstft_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + mrstft = MRSTFTLoss() + loss = mrstft(t1, t2) + + assert isinstance(loss, torch.Tensor) + + +def test_sisnr_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + sisnr = SISNR() + loss = sisnr(t1, t2) + + assert isinstance(loss, torch.Tensor) + + +def test_stft_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + mrstft = STFTLoss() + loss = mrstft(t1, t2) + + assert isinstance(loss, torch.Tensor) diff --git a/tests/models/test_audiogen.py b/tests/models/test_audiogen.py new file mode 100644 index 0000000000000000000000000000000000000000..3850af066cedd5ea38bd9aead9634d6aaf938218 --- /dev/null +++ b/tests/models/test_audiogen.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from audiocraft.models import AudioGen + + +class TestAudioGenModel: + def get_audiogen(self): + ag = AudioGen.get_pretrained(name='debug', device='cpu') + ag.set_generation_params(duration=2.0, extend_stride=2.) + return ag + + def test_base(self): + ag = self.get_audiogen() + assert ag.frame_rate == 25 + assert ag.sample_rate == 16000 + assert ag.audio_channels == 1 + + def test_generate_continuation(self): + ag = self.get_audiogen() + prompt = torch.randn(3, 1, 16000) + wav = ag.generate_continuation(prompt, 16000) + assert list(wav.shape) == [3, 1, 32000] + + prompt = torch.randn(2, 1, 16000) + wav = ag.generate_continuation( + prompt, 16000, ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 32000] + + prompt = torch.randn(2, 1, 16000) + with pytest.raises(AssertionError): + wav = ag.generate_continuation( + prompt, 16000, ['youpi', 'lapin dort', 'one too many']) + + def test_generate(self): + ag = self.get_audiogen() + wav = ag.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 32000] + + def test_generate_long(self): + ag = self.get_audiogen() + ag.max_duration = 3. + ag.set_generation_params(duration=4., extend_stride=2.) 
+ wav = ag.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 16000 * 4] diff --git a/tests/models/test_encodec_model.py b/tests/models/test_encodec_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9c1db3f69a45f02451b71da95f44356811acbb --- /dev/null +++ b/tests/models/test_encodec_model.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import numpy as np +import torch + +from audiocraft.models import EncodecModel +from audiocraft.modules import SEANetEncoder, SEANetDecoder +from audiocraft.quantization import DummyQuantizer + + +class TestEncodecModel: + + def _create_encodec_model(self, + sample_rate: int, + channels: int, + dim: int = 5, + n_filters: int = 3, + n_residual_layers: int = 1, + ratios: list = [5, 4, 3, 2], + **kwargs): + frame_rate = np.prod(ratios) + encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters, + n_residual_layers=n_residual_layers, ratios=ratios) + decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters, + n_residual_layers=n_residual_layers, ratios=ratios) + quantizer = DummyQuantizer() + model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, + sample_rate=sample_rate, channels=channels, **kwargs) + return model + + def test_model(self): + random.seed(1234) + sample_rate = 24_000 + channels = 1 + model = self._create_encodec_model(sample_rate, channels) + for _ in range(10): + length = random.randrange(1, 10_000) + x = torch.randn(2, channels, length) + res = model(x) + assert res.x.shape == x.shape + + def test_model_renorm(self): + random.seed(1234) + sample_rate = 24_000 + channels = 1 + model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False) + model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True) + + for _ in range(10): + length = random.randrange(1, 10_000) + x = torch.randn(2, channels, length) + codes, scales = model_nonorm.encode(x) + codes, scales = model_renorm.encode(x) + assert scales is not None diff --git a/tests/models/test_multibanddiffusion.py b/tests/models/test_multibanddiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2702a3cb5fe402bf96911dbc992d2749cb18a4c0 --- /dev/null +++ b/tests/models/test_multibanddiffusion.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
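In the EnCodec tests above, the SEANet ratios=[5, 4, 3, 2] set the encoder's total stride, so each latent frame covers the product of the ratios in samples, and res.x.shape == x.shape checks that decoding maps those frames back to a waveform of the original length. A back-of-the-envelope helper for that arithmetic (an illustration of the assumed hop size, not library code):

import math


def latent_frames(num_samples: int, ratios=(5, 4, 3, 2)) -> int:
    """Number of encoder output frames for `num_samples` raw samples,
    assuming the total stride is the product of the SEANet ratios."""
    hop = math.prod(ratios)  # 5 * 4 * 3 * 2 = 120 samples per frame
    return math.ceil(num_samples / hop)


# E.g. a 1-second clip at 24 kHz maps to 200 latent frames (a 200 Hz frame rate).
assert latent_frames(24_000) == 200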
+ +import random + +import numpy as np +import torch +from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess +from audiocraft.models import EncodecModel, DiffusionUnet +from audiocraft.modules import SEANetEncoder, SEANetDecoder +from audiocraft.modules.diffusion_schedule import NoiseSchedule +from audiocraft.quantization import DummyQuantizer + + +class TestMBD: + + def _create_mbd(self, + sample_rate: int, + channels: int, + n_filters: int = 3, + n_residual_layers: int = 1, + ratios: list = [5, 4, 3, 2], + num_steps: int = 1000, + codec_dim: int = 128, + **kwargs): + frame_rate = np.prod(ratios) + encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters, + n_residual_layers=n_residual_layers, ratios=ratios) + decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters, + n_residual_layers=n_residual_layers, ratios=ratios) + quantizer = DummyQuantizer() + compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, + sample_rate=sample_rate, channels=channels, **kwargs) + diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim) + schedule = NoiseSchedule(device='cpu', num_steps=num_steps) + DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule) + mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model) + return mbd + + def test_model(self): + random.seed(1234) + sample_rate = 24_000 + channels = 1 + codec_dim = 128 + mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim) + for _ in range(10): + length = random.randrange(1, 10_000) + x = torch.randn(2, channels, length) + res = mbd.regenerate(x, sample_rate) + assert res.shape == x.shape diff --git a/tests/models/test_musicgen.py b/tests/models/test_musicgen.py new file mode 100644 index 0000000000000000000000000000000000000000..65618a9e2ef5bb382694b50b23dd50958d590d4e --- /dev/null +++ b/tests/models/test_musicgen.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from audiocraft.models import MusicGen + + +class TestMusicGenModel: + def get_musicgen(self): + mg = MusicGen.get_pretrained(name='debug', device='cpu') + mg.set_generation_params(duration=2.0, extend_stride=2.) + return mg + + def test_base(self): + mg = self.get_musicgen() + assert mg.frame_rate == 25 + assert mg.sample_rate == 32000 + assert mg.audio_channels == 1 + + def test_generate_unconditional(self): + mg = self.get_musicgen() + wav = mg.generate_unconditional(3) + assert list(wav.shape) == [3, 1, 64000] + + def test_generate_continuation(self): + mg = self.get_musicgen() + prompt = torch.randn(3, 1, 32000) + wav = mg.generate_continuation(prompt, 32000) + assert list(wav.shape) == [3, 1, 64000] + + prompt = torch.randn(2, 1, 32000) + wav = mg.generate_continuation( + prompt, 32000, ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 64000] + + prompt = torch.randn(2, 1, 32000) + with pytest.raises(AssertionError): + wav = mg.generate_continuation( + prompt, 32000, ['youpi', 'lapin dort', 'one too many']) + + def test_generate(self): + mg = self.get_musicgen() + wav = mg.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 64000] + + def test_generate_long(self): + mg = self.get_musicgen() + mg.max_duration = 3. 
+ mg.set_generation_params(duration=4., extend_stride=2.) + wav = mg.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 32000 * 4] diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/modules/test_activations.py b/tests/modules/test_activations.py new file mode 100644 index 0000000000000000000000000000000000000000..24e30d4cd87683430488bfa442e098b34229a5ee --- /dev/null +++ b/tests/modules/test_activations.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + +from audiocraft.modules.activations import CustomGLU + + +class TestActivations: + def test_custom_glu_calculation(self): + + activation = CustomGLU(nn.Identity()) + + initial_shape = (4, 8, 8) + + part_a = torch.ones(initial_shape) * 2 + part_b = torch.ones(initial_shape) * -1 + input = torch.cat((part_a, part_b), dim=-1) + + output = activation(input) + + # ensure all dimensions match initial shape + assert output.shape == initial_shape + # ensure the gating was calculated correctly a * f(b) + assert torch.all(output == -2).item() diff --git a/tests/modules/test_codebooks_patterns.py b/tests/modules/test_codebooks_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..b658f4779a369f9ec8dde692a61b7f0fe3485724 --- /dev/null +++ b/tests/modules/test_codebooks_patterns.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
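+# Tests for the codebook interleaving pattern providers (parallel, delayed, unrolled) and for building and reverting pattern sequences and logits.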
+ +import pytest +import torch + +from audiocraft.modules.codebooks_patterns import ( + DelayedPatternProvider, + ParallelPatternProvider, + Pattern, + UnrolledPatternProvider, +) + + +class TestParallelPatternProvider: + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [0, 1, 16, 100]) + def test_get_pattern(self, n_q: int, timesteps: int): + provider = ParallelPatternProvider(n_q) + pattern = provider.get_pattern(timesteps) + # + 1 to account for 1st step + assert len(pattern.layout) == timesteps + 1 + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [8, 16, 100]) + def test_pattern_content(self, n_q: int, timesteps: int): + provider = ParallelPatternProvider(n_q) + pattern = provider.get_pattern(timesteps) + for s, v in enumerate(pattern.layout): + for i, code in enumerate(v): + assert i == code.q + assert code.t == s - 1 # account for the 1st empty step + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [8, 16, 100]) + def test_pattern_max_delay(self, n_q: int, timesteps: int): + provider = ParallelPatternProvider(n_q) + pattern = provider.get_pattern(timesteps) + assert pattern.max_delay == 0 + assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay + + +class TestDelayedPatternProvider: + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [0, 1, 16, 100]) + def test_get_pattern(self, n_q: int, timesteps: int): + delays = [ + list(range(n_q)), + [0] + [1] * (n_q - 1), + [0] + [4] * (n_q - 1), + ] + for delay in delays: + provider = DelayedPatternProvider(n_q, delay) + pattern = provider.get_pattern(timesteps) + # + 1 to account for 1st step + assert len(pattern.layout) == timesteps + max(delay) + 1 + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [8, 16, 100]) + def test_pattern_content(self, n_q: int, timesteps: int): + provider = DelayedPatternProvider(n_q) + pattern = provider.get_pattern(timesteps) + for s, v in enumerate(pattern.layout): + for i, code in enumerate(v): + assert i == code.q + assert code.t == max(0, s - code.q - 1) + + @pytest.mark.parametrize("timesteps", [8, 16, 100]) + @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]]) + def test_pattern_max_delay(self, timesteps: int, delay: list): + provider = DelayedPatternProvider(len(delay), delay) + pattern = provider.get_pattern(timesteps) + assert pattern.max_delay == max(delay) + assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay + + +class TestUnrolledPatternProvider: + + @pytest.mark.parametrize("timesteps", [0, 1, 16]) + @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]]) + @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]]) + def test_get_pattern(self, timesteps: int, flattening: list, delays: list): + n_q = len(flattening) + max_delay = max(delays) + provider = UnrolledPatternProvider(n_q, flattening, delays) + pattern = provider.get_pattern(timesteps) + assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay + + @pytest.mark.parametrize("timesteps", [0, 1, 16]) + @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]]) + @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]]) + def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list): + n_q = len(flattening) + max_delay = max(delays) + provider = UnrolledPatternProvider(n_q, flattening, delays) + pattern = 
provider.get_pattern(timesteps) + assert pattern.max_delay == max_delay + + +class TestPattern: + + def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): + """Reference method to build the sequence from the pattern without using fancy scatter.""" + bs, n_q, T = z.shape + z = z.cpu().numpy() + assert n_q == pattern.n_q + assert T <= pattern.timesteps + inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy() + inp[:] = special_token + for s, v in enumerate(pattern.layout): + for (t, q) in v: + if t < T: + inp[:, q, s] = z[:, q, t] + return torch.from_numpy(inp) + + def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): + """Reference method to revert the sequence from the pattern without using fancy scatter.""" + z = z.cpu().numpy() + bs, n_q, S = z.shape + assert pattern.n_q == n_q + inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy() + inp[:] = special_token + for s, v in enumerate(pattern.layout): + for (t, q) in v: + if t < pattern.timesteps: + inp[:, q, t] = z[:, q, s] + return torch.from_numpy(inp) + + def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float): + """Reference method to revert the logits from the pattern without using fancy scatter.""" + z = z.cpu().numpy() + bs, card, n_q, S = z.shape + assert pattern.n_q == n_q + ref_layout = pattern.layout + inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy() + inp[:] = special_token + for s, v in enumerate(ref_layout[1:]): + if s < S: + for (t, q) in v: + if t < pattern.timesteps: + inp[:, :, q, t] = z[:, :, q, s] + return torch.from_numpy(inp) + + def _get_pattern_providers(self, n_q: int): + pattern_provider_1 = ParallelPatternProvider(n_q) + pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q))) + pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1)) + pattern_provider_4 = UnrolledPatternProvider( + n_q, flattening=list(range(n_q)), delays=[0] * n_q + ) + pattern_provider_5 = UnrolledPatternProvider( + n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q + ) + pattern_provider_6 = UnrolledPatternProvider( + n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1) + ) + return [ + pattern_provider_1, + pattern_provider_2, + pattern_provider_3, + pattern_provider_4, + pattern_provider_5, + pattern_provider_6, + ] + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [16, 72]) + def test_build_pattern_sequence(self, n_q: int, timesteps: int): + bs = 2 + card = 256 + special_token = card + + pattern_providers = self._get_pattern_providers(n_q) + for pattern_provider in pattern_providers: + pattern = pattern_provider.get_pattern(timesteps) + # we can correctly build the sequence from the pattern + z = torch.randint(0, card, (bs, n_q, timesteps)) + ref_res = self.ref_build_pattern_sequence(z, pattern, special_token) + res, indexes, mask = pattern.build_pattern_sequence(z, special_token) + assert (res == ref_res).float().mean() == 1.0 + + # expected assertion fails on the number of timesteps + invalid_timesteps = [timesteps + 1] + if pattern.num_sequence_steps != pattern.timesteps: + invalid_timesteps.append(pattern.num_sequence_steps) + for i_timesteps in invalid_timesteps: + z2 = torch.randint(0, card, (bs, n_q, i_timesteps)) + with pytest.raises(AssertionError): + pattern.build_pattern_sequence(z2, special_token) + + # 
expected assertion fails on the number of codebooks + invalid_qs = [0, n_q - 1, n_q + 1] + for i_q in invalid_qs: + z3 = torch.randint(0, card, (bs, i_q, timesteps)) + with pytest.raises(AssertionError): + pattern.build_pattern_sequence(z3, special_token) + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [16, 72]) + def test_revert_pattern_sequence(self, n_q: int, timesteps: int): + bs = 2 + card = 256 + special_token = card + + pattern_providers = self._get_pattern_providers(n_q) + for pattern_provider in pattern_providers: + pattern = pattern_provider.get_pattern(timesteps) + # this works assuming previous tests are successful + z = torch.randint(0, card, (bs, n_q, timesteps)) + s = self.ref_build_pattern_sequence(z, pattern, special_token) + ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token) + # ensure our reference script retrieve the original sequence + assert z.shape == ref_out.shape + assert (z == ref_out).float().mean() == 1.0 + # now we can test the scatter version + out, indexes, mask = pattern.revert_pattern_sequence(s, special_token) + assert out.shape == ref_out.shape + assert (out == ref_out).float().mean() == 1.0 + + @pytest.mark.parametrize("n_q", [1, 4, 32]) + @pytest.mark.parametrize("timesteps", [16, 72]) + @pytest.mark.parametrize("card", [1, 2, 256, 1024]) + def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int): + bs = 2 + special_token = card + logits_special_token = float('nan') + + pattern_providers = self._get_pattern_providers(n_q) + for pattern_provider in pattern_providers: + pattern = pattern_provider.get_pattern(timesteps) + # this works assuming previous tests are successful + z = torch.randint(0, card, (bs, n_q, timesteps)) + s = self.ref_build_pattern_sequence(z, pattern, special_token) + logits = torch.randn((bs, card, n_q, s.shape[-1])) + ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token) + # ensure our reference script retrieve the original sequence + assert ref_out.shape == torch.Size([bs, card, n_q, timesteps]) + # now we can test the scatter version + out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token) + assert out.shape == ref_out.shape + assert (out == ref_out).float().mean() == 1.0 diff --git a/tests/modules/test_conv.py b/tests/modules/test_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..28fbc4f1a0ebaf41b56947b767958ae696e75eec --- /dev/null +++ b/tests/modules/test_conv.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product +import math +import random + +import pytest +import torch +from torch import nn + +from audiocraft.modules import ( + NormConv1d, + NormConvTranspose1d, + StreamableConv1d, + StreamableConvTranspose1d, + pad1d, + unpad1d, +) + + +def test_get_extra_padding_for_conv1d(): + # TODO: Implement me! + pass + + +def test_pad1d_zeros(): + x = torch.randn(1, 1, 20) + + xp1 = pad1d(x, (0, 5), mode='constant', value=0.) + assert xp1.shape[-1] == 25 + xp2 = pad1d(x, (5, 5), mode='constant', value=0.) + assert xp2.shape[-1] == 30 + xp3 = pad1d(x, (0, 0), mode='constant', value=0.) + assert xp3.shape[-1] == 20 + xp4 = pad1d(x, (10, 30), mode='constant', value=0.) + assert xp4.shape[-1] == 60 + + with pytest.raises(AssertionError): + pad1d(x, (-1, 0), mode='constant', value=0.) 
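+ # Negative padding amounts are invalid and are expected to raise an AssertionError.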
+ + with pytest.raises(AssertionError): + pad1d(x, (0, -1), mode='constant', value=0.) + + with pytest.raises(AssertionError): + pad1d(x, (-1, -1), mode='constant', value=0.) + + +def test_pad1d_reflect(): + x = torch.randn(1, 1, 20) + + xp1 = pad1d(x, (0, 5), mode='reflect', value=0.) + assert xp1.shape[-1] == 25 + xp2 = pad1d(x, (5, 5), mode='reflect', value=0.) + assert xp2.shape[-1] == 30 + xp3 = pad1d(x, (0, 0), mode='reflect', value=0.) + assert xp3.shape[-1] == 20 + xp4 = pad1d(x, (10, 30), mode='reflect', value=0.) + assert xp4.shape[-1] == 60 + + with pytest.raises(AssertionError): + pad1d(x, (-1, 0), mode='reflect', value=0.) + + with pytest.raises(AssertionError): + pad1d(x, (0, -1), mode='reflect', value=0.) + + with pytest.raises(AssertionError): + pad1d(x, (-1, -1), mode='reflect', value=0.) + + +def test_unpad1d(): + x = torch.randn(1, 1, 20) + + u1 = unpad1d(x, (5, 5)) + assert u1.shape[-1] == 10 + u2 = unpad1d(x, (0, 5)) + assert u2.shape[-1] == 15 + u3 = unpad1d(x, (5, 0)) + assert u3.shape[-1] == 15 + u4 = unpad1d(x, (0, 0)) + assert u4.shape[-1] == x.shape[-1] + + with pytest.raises(AssertionError): + unpad1d(x, (-1, 0)) + + with pytest.raises(AssertionError): + unpad1d(x, (0, -1)) + + with pytest.raises(AssertionError): + unpad1d(x, (-1, -1)) + + +class TestNormConv1d: + + def test_norm_conv1d_modules(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + C_out, kernel_size, stride = 1, 4, 1 + expected_out_length = int((T - kernel_size) / stride + 1) + wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm') + gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm') + nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none') + + assert isinstance(wn_conv.norm, nn.Identity) + assert isinstance(wn_conv.conv, nn.Conv1d) + + assert isinstance(gn_conv.norm, nn.GroupNorm) + assert isinstance(gn_conv.conv, nn.Conv1d) + + assert isinstance(nn_conv.norm, nn.Identity) + assert isinstance(nn_conv.conv, nn.Conv1d) + + for conv_layer in [wn_conv, gn_conv, nn_conv]: + out = conv_layer(t0) + assert isinstance(out, torch.Tensor) + assert list(out.shape) == [N, C_out, expected_out_length] + + +class TestNormConvTranspose1d: + + def test_normalizations(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + C_out, kernel_size, stride = 1, 4, 1 + expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1 + + wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm') + gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm') + nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none') + + assert isinstance(wn_convtr.norm, nn.Identity) + assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d) + + assert isinstance(gn_convtr.norm, nn.GroupNorm) + assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d) + + assert isinstance(nn_convtr.norm, nn.Identity) + assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d) + + for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]: + out = convtr_layer(t0) + assert isinstance(out, torch.Tensor) + assert list(out.shape) == [N, C_out, expected_out_length] + + +class TestStreamableConv1d: + + def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation): + # StreamableConv1d internally pads to make sure that the last window is full + padding_total = (kernel_size - 1) * dilation - (stride - 1) + n_frames = (length - kernel_size + 
padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length // stride + + def test_streamable_conv1d(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + C_out = 1 + + # conv params are [(kernel_size, stride, dilation)] + conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)] + for causal, (kernel_size, stride, dilation) in product([False, True], conv_params): + expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation) + sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal) + out = sconv(t0) + assert isinstance(out, torch.Tensor) + print(list(out.shape), [N, C_out, expected_out_length]) + assert list(out.shape) == [N, C_out, expected_out_length] + + +class TestStreamableConvTranspose1d: + + def get_streamable_convtr1d_output_length(self, length, kernel_size, stride): + padding_total = (kernel_size - stride) + return (length - 1) * stride - padding_total + (kernel_size - 1) + 1 + + def test_streamable_convtr1d(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + C_out = 1 + + with pytest.raises(AssertionError): + StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5) + StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.) + StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2) + + # causal params are [(causal, trim_right)] + causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)] + # conv params are [(kernel_size, stride)] + conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)] + for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params): + expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride) + sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, + causal=causal, trim_right_ratio=trim_right_ratio) + out = sconvtr(t0) + assert isinstance(out, torch.Tensor) + assert list(out.shape) == [N, C_out, expected_out_length] diff --git a/tests/modules/test_lstm.py b/tests/modules/test_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..1248964c8191e19f27661f0974bef9cc967eb015 --- /dev/null +++ b/tests/modules/test_lstm.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +import torch + +from audiocraft.modules.lstm import StreamableLSTM + + +class TestStreamableLSTM: + + def test_lstm(self): + B, C, T = 4, 2, random.randint(1, 100) + + lstm = StreamableLSTM(C, 3, skip=False) + x = torch.randn(B, C, T) + y = lstm(x) + + print(y.shape) + assert y.shape == torch.Size([B, C, T]) + + def test_lstm_skip(self): + B, C, T = 4, 2, random.randint(1, 100) + + lstm = StreamableLSTM(C, 3, skip=True) + x = torch.randn(B, C, T) + y = lstm(x) + + assert y.shape == torch.Size([B, C, T]) diff --git a/tests/modules/test_rope.py b/tests/modules/test_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..067c6f067acbf27fb0fef5c2b812c22474c4fcd0 --- /dev/null +++ b/tests/modules/test_rope.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from audiocraft.modules.rope import RotaryEmbedding +from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend + + +def test_rope(): + set_efficient_attention_backend('xformers') + B, T, H, C = 8, 75, 16, 128 + + rope = RotaryEmbedding(dim=C) + xq = torch.rand((B, T, H, C)) + xk = torch.rand((B, T, H, C)) + xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) + + assert list(xq_out.shape) == [B, T, H, C] + assert list(xk_out.shape) == [B, T, H, C] + + +def test_rope_io_dtypes(): + set_efficient_attention_backend('xformers') + B, T, H, C = 8, 75, 16, 128 + + rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32) + rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64) + + # Test bfloat16 inputs w/ both 32 and 64 precision rope. + xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16) + xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16) + xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16) + assert xq_out.dtype == torch.bfloat16 + xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16) + assert xq_out.dtype == torch.bfloat16 + + # Test float32 inputs w/ both 32 and 64 precision rope. + xq_32 = torch.rand((B, T, H, C)).to(torch.float32) + xk_32 = torch.rand((B, T, H, C)).to(torch.float32) + xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32) + assert xq_out.dtype == torch.float32 + xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32) + assert xq_out.dtype == torch.float32 + + +def test_transformer_with_rope(): + set_efficient_attention_backend('xformers') + torch.manual_seed(1234) + for pos in ['rope', 'sin_rope']: + tr = StreamingTransformer( + 16, 4, 2, custom=True, dropout=0., layer_scale=0.1, + positional_embedding=pos) + tr.eval() + steps = 12 + x = torch.randn(3, steps, 16) + + out = tr(x) + assert list(out.shape) == list(x.shape) + + +@torch.no_grad() +def test_rope_streaming(): + set_efficient_attention_backend('xformers') + torch.manual_seed(1234) + tr = StreamingTransformer( + 16, 4, 2, causal=True, dropout=0., + custom=True, positional_embedding='rope') + tr.eval() + steps = 12 + x = torch.randn(3, steps, 16) + + ref = tr(x) + + with tr.streaming(): + outs = [] + frame_sizes = [1] * steps + + for frame_size in frame_sizes: + frame = x[:, :frame_size] + x = x[:, frame_size:] + outs.append(tr(frame)) + + out = torch.cat(outs, dim=1) + assert list(out.shape) == [3, steps, 16] + delta = torch.norm(out - ref) / torch.norm(out) + assert delta < 1e-6, delta + + +@torch.no_grad() +def test_rope_streaming_past_context(): + set_efficient_attention_backend('xformers') + torch.manual_seed(1234) + + for context in [None, 10]: + tr = StreamingTransformer( + 16, 4, 1 if context else 2, + causal=True, past_context=context, custom=True, + dropout=0., positional_embedding='rope') + tr.eval() + + steps = 20 + x = torch.randn(3, steps, 16) + ref = tr(x) + + with tr.streaming(): + outs = [] + frame_sizes = [1] * steps + + for frame_size in frame_sizes: + frame = x[:, :frame_size] + x = x[:, frame_size:] + outs.append(tr(frame)) + + out = torch.cat(outs, dim=1) + assert list(out.shape) == [3, steps, 16] + delta = torch.norm(out - ref) / torch.norm(out) + assert delta < 1e-6, delta + + +def test_rope_memory_efficient(): + set_efficient_attention_backend('xformers') + torch.manual_seed(1234) + tr = StreamingTransformer( + 16, 4, 2, custom=True, dropout=0., layer_scale=0.1, + positional_embedding='rope') + tr_mem_efficient = StreamingTransformer( + 
16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1, + positional_embedding='rope') + tr_mem_efficient.load_state_dict(tr.state_dict()) + tr.eval() + steps = 12 + x = torch.randn(3, steps, 16) + + with torch.no_grad(): + y = tr(x) + y2 = tr_mem_efficient(x) + # Check at float precision b/c this is the rope default. + assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm() + + +def test_rope_with_xpos(): + set_efficient_attention_backend('xformers') + B, T, H, C = 8, 75, 16, 128 + + rope = RotaryEmbedding(dim=C, xpos=True) + xq = torch.rand((B, T, H, C)) + xk = torch.rand((B, T, H, C)) + xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) + + assert list(xq_out.shape) == [B, T, H, C] + assert list(xk_out.shape) == [B, T, H, C] + + +def test_positional_scale(): + set_efficient_attention_backend('xformers') + B, T, H, C = 8, 75, 16, 128 + + rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0) + xq = torch.rand((B, T, H, C)) + xk = torch.rand((B, T, H, C)) + xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) + + assert torch.allclose(xq, xq_out) + assert torch.allclose(xk, xk_out) diff --git a/tests/modules/test_seanet.py b/tests/modules/test_seanet.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c51b340a2f94fb2828b14daf83d5fad645073d --- /dev/null +++ b/tests/modules/test_seanet.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import pytest +import torch + +from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock +from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d + + +class TestSEANetModel: + + def test_base(self): + encoder = SEANetEncoder() + decoder = SEANetDecoder() + + x = torch.randn(1, 1, 24000) + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + def test_causal(self): + encoder = SEANetEncoder(causal=True) + decoder = SEANetDecoder(causal=True) + x = torch.randn(1, 1, 24000) + + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + def test_conv_skip_connection(self): + encoder = SEANetEncoder(true_skip=False) + decoder = SEANetDecoder(true_skip=False) + + x = torch.randn(1, 1, 24000) + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + def test_seanet_encoder_decoder_final_act(self): + encoder = SEANetEncoder(true_skip=False) + decoder = SEANetDecoder(true_skip=False, final_activation='Tanh') + + x = torch.randn(1, 1, 24000) + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str): + n_blocks = 0 + for layer in encoder.model: + if isinstance(layer, StreamableConv1d): + n_blocks += 1 + assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm + elif isinstance(layer, SEANetResnetBlock): + for resnet_layer in layer.block: + if isinstance(resnet_layer, StreamableConv1d): + # here we add + 1 to n_blocks as we increment n_blocks just after the block + assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm + + def 
test_encoder_disable_norm(self): + n_residuals = [0, 1, 3] + disable_blocks = [0, 1, 2, 3, 4, 5, 6] + norms = ['weight_norm', 'none'] + for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms): + encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm, + disable_norm_outer_blocks=disable_blocks) + self._check_encoder_blocks_norm(encoder, disable_blocks, norm) + + def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str): + n_blocks = 0 + for layer in decoder.model: + if isinstance(layer, StreamableConv1d): + n_blocks += 1 + assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm + elif isinstance(layer, StreamableConvTranspose1d): + n_blocks += 1 + assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm + elif isinstance(layer, SEANetResnetBlock): + for resnet_layer in layer.block: + if isinstance(resnet_layer, StreamableConv1d): + assert resnet_layer.conv.norm_type == 'none' \ + if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm + + def test_decoder_disable_norm(self): + n_residuals = [0, 1, 3] + disable_blocks = [0, 1, 2, 3, 4, 5, 6] + norms = ['weight_norm', 'none'] + for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms): + decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm, + disable_norm_outer_blocks=disable_blocks) + self._check_decoder_blocks_norm(decoder, disable_blocks, norm) + + def test_disable_norm_raises_exception(self): + # Invalid disable_norm_outer_blocks values raise exceptions + with pytest.raises(AssertionError): + SEANetEncoder(disable_norm_outer_blocks=-1) + + with pytest.raises(AssertionError): + SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7) + + with pytest.raises(AssertionError): + SEANetDecoder(disable_norm_outer_blocks=-1) + + with pytest.raises(AssertionError): + SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7) diff --git a/tests/modules/test_transformer.py b/tests/modules/test_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb79bfd58d535469f9b3c56b8a5fe254db5d8ba --- /dev/null +++ b/tests/modules/test_transformer.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import pytest +import torch + +from audiocraft.modules.transformer import ( + StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend) + + +def test_transformer_causal_streaming(): + torch.manual_seed(1234) + + for context, custom in product([None, 10], [False, True]): + # Test that causality and receptive fields are properly handled. + # looking at the gradients + tr = StreamingTransformer( + 16, 4, 1 if context else 2, + causal=True, past_context=context, custom=custom, + dropout=0.) 
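+ # With past_context set, positions far enough behind the probed step should receive no gradient.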
+ steps = 20 + for k in [0, 10, 15, 19]: + x = torch.randn(4, steps, 16, requires_grad=True) + y = tr(x) + y[:, k].abs().sum().backward() + if k + 1 < steps: + assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm() + assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm() + if context is not None and k > context: + limit = k - context - 1 + assert torch.allclose(x.grad[:, :limit], + torch.tensor(0.)), x.grad[:, :limit].norm() + + # Now check that streaming gives the same result at batch eval. + x = torch.randn(4, steps, 16) + y = tr(x) + ys = [] + with tr.streaming(): + for k in range(steps): + chunk = x[:, k:k + 1, :] + ys.append(tr(chunk)) + y_stream = torch.cat(ys, dim=1) + delta = torch.norm(y_stream - y) / torch.norm(y) + assert delta < 1e-6, delta + + +def test_transformer_vs_pytorch(): + torch.manual_seed(1234) + # Check that in the non causal setting, we get the same result as + # PyTorch Transformer encoder. + for custom in [False, True]: + tr = StreamingTransformer( + 16, 4, 2, + causal=False, custom=custom, dropout=0., positional_scale=0.) + layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True) + tr_ref = torch.nn.TransformerEncoder(layer, 2) + tr.load_state_dict(tr_ref.state_dict()) + + x = torch.randn(4, 20, 16) + y = tr(x) + y2 = tr_ref(x) + delta = torch.norm(y2 - y) / torch.norm(y) + assert delta < 1e-6, delta + + +def test_streaming_api(): + tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.) + tr.eval() + steps = 12 + x = torch.randn(1, steps, 16) + + with torch.no_grad(): + with tr.streaming(): + _ = tr(x[:, :1]) + state = {k: v.clone() for k, v in tr.get_streaming_state().items()} + y = tr(x[:, 1:2]) + tr.set_streaming_state(state) + y2 = tr(x[:, 1:2]) + assert torch.allclose(y, y2), (y - y2).norm() + assert tr.flush() is None + + +def test_memory_efficient(): + for backend in ['torch', 'xformers']: + torch.manual_seed(1234) + set_efficient_attention_backend(backend) + + tr = StreamingTransformer( + 16, 4, 2, custom=True, dropout=0., layer_scale=0.1) + tr_mem_efficient = StreamingTransformer( + 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1) + tr_mem_efficient.load_state_dict(tr.state_dict()) + tr.eval() + steps = 12 + x = torch.randn(3, steps, 16) + + with torch.no_grad(): + y = tr(x) + y2 = tr_mem_efficient(x) + assert torch.allclose(y, y2), ((y - y2).norm(), backend) + + +def test_attention_as_float32(): + torch.manual_seed(1234) + cases = [ + {'custom': True}, + {'custom': False}, + ] + for case in cases: + tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case) + tr_float32 = StreamingTransformer( + 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case) + if not case['custom']: + # we are not using autocast here because it doesn't really + # work as expected on CPU, so we have to manually cast the weights of the MHA. 
+ for layer in tr_float32.layers: + layer.self_attn.mha.to(torch.float32) + tr_float32.load_state_dict(tr.state_dict()) + steps = 12 + x = torch.randn(3, steps, 16, dtype=torch.bfloat16) + + with torch.no_grad(): + y = tr(x) + y2 = tr_float32(x) + assert not torch.allclose(y, y2), (y - y2).norm() + + +@torch.no_grad() +def test_streaming_memory_efficient(): + for backend in ['torch', 'xformers']: + torch.manual_seed(1234) + set_efficient_attention_backend(backend) + tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True) + tr_mem_efficient = StreamingTransformer( + 16, 4, 2, dropout=0., memory_efficient=True, causal=True) + tr.load_state_dict(tr_mem_efficient.state_dict()) + tr.eval() + tr_mem_efficient.eval() + steps = 12 + x = torch.randn(3, steps, 16) + + ref = tr(x) + + with tr_mem_efficient.streaming(): + outs = [] + # frame_sizes = [2] + [1] * (steps - 2) + frame_sizes = [1] * steps + + for frame_size in frame_sizes: + frame = x[:, :frame_size] + x = x[:, frame_size:] + outs.append(tr_mem_efficient(frame)) + + out = torch.cat(outs, dim=1) + delta = torch.norm(out - ref) / torch.norm(out) + assert delta < 1e-6, delta + + +def test_cross_attention(): + torch.manual_seed(1234) + for norm_first in [True, False]: + m = StreamingTransformer( + 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True) + m_cross = StreamingTransformer( + 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True) + m_cross.load_state_dict(m.state_dict(), strict=False) + x = torch.randn(2, 5, 16) + cross_x = torch.randn(2, 3, 16) + y_ref = m(x) + y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x) + # With norm_first, the two should be exactly the same, + # but with norm_first=False, we get 2 normalization in a row + # and the epsilon value leads to a tiny change. + atol = 0. if norm_first else 1e-6 + print((y_ref - y_cross_zero).norm() / y_ref.norm()) + assert torch.allclose(y_ref, y_cross_zero, atol=atol) + + # We now expect a difference even with a generous atol of 1e-2. + y_cross = m_cross(x, cross_attention_src=cross_x) + assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2) + + with pytest.raises(AssertionError): + _ = m_cross(x) + _ = m(x, cross_attention_src=cross_x) + + +def test_cross_attention_compat(): + torch.manual_seed(1234) + num_heads = 2 + dim = num_heads * 64 + with pytest.raises(AssertionError): + StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True) + + cross_attn = StreamingMultiheadAttention( + dim, num_heads, dropout=0, cross_attention=True, custom=True) + ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True) + + # We can load the regular attention state dict + # so we have compat when loading old checkpoints. + cross_attn.load_state_dict(ref_attn.state_dict()) + + queries = torch.randn(3, 7, dim) + keys = torch.randn(3, 9, dim) + values = torch.randn(3, 9, dim) + + y = cross_attn(queries, keys, values)[0] + y_ref = ref_attn(queries, keys, values)[0] + assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm() + + # Now let's check that streaming is working properly. 
+ with cross_attn.streaming(): + ys = [] + for step in range(queries.shape[1]): + ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0]) + y_streaming = torch.cat(ys, dim=1) + assert torch.allclose(y_streaming, y, atol=1e-7) + + +def test_repeat_kv(): + torch.manual_seed(1234) + num_heads = 8 + kv_repeat = 4 + dim = num_heads * 64 + with pytest.raises(AssertionError): + mha = StreamingMultiheadAttention( + dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True) + mha = StreamingMultiheadAttention( + dim, num_heads, causal=True, kv_repeat=kv_repeat) + mha = StreamingMultiheadAttention( + dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True) + x = torch.randn(4, 18, dim) + y = mha(x, x, x)[0] + assert x.shape == y.shape + + +def test_qk_layer_norm(): + torch.manual_seed(1234) + tr = StreamingTransformer( + 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False) + steps = 12 + x = torch.randn(3, steps, 16) + y = tr(x) + + tr = StreamingTransformer( + 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True) + z = torch.randn(3, 21, 16) + y = tr(x, cross_attention_src=z) + assert y.shape == x.shape diff --git a/tests/quantization/test_vq.py b/tests/quantization/test_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..c215099fedacae35c6798fdd9b8420a447aa16bb --- /dev/null +++ b/tests/quantization/test_vq.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from audiocraft.quantization.vq import ResidualVectorQuantizer + + +class TestResidualVectorQuantizer: + + def test_rvq(self): + x = torch.randn(1, 16, 2048) + vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8) + res = vq(x, 1.) + assert res.x.shape == torch.Size([1, 16, 2048]) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree.