diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..14b6901d4c3014cde9c3a17a0383fd5ca79b3804
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+__pycache__/
+.ipynb_checkpoints/
+processed
+outputs
+checkpoints
+trash
+examples*
+.env
\ No newline at end of file
diff --git a/.ipynb_checkpoints/demo_part1-checkpoint.ipynb b/.ipynb_checkpoints/demo_part1-checkpoint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..065c54093ec5b2543f94796354832f66b04af1b2
--- /dev/null
+++ b/.ipynb_checkpoints/demo_part1-checkpoint.ipynb
@@ -0,0 +1,236 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b6ee1ede",
+ "metadata": {},
+ "source": [
+ "## Voice Style Control Demo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7f043ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import se_extractor\n",
+ "from api import BaseSpeakerTTS, ToneColorConverter"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15116b59",
+ "metadata": {},
+ "source": [
+ "### Initialization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aacad912",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ckpt_base = 'checkpoints/base_speakers/EN'\n",
+ "ckpt_converter = 'checkpoints/converter'\n",
+ "device = 'cuda:0'\n",
+ "output_dir = 'outputs'\n",
+ "\n",
+ "base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)\n",
+ "base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')\n",
+ "\n",
+ "tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)\n",
+ "tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')\n",
+ "\n",
+ "os.makedirs(output_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7f67740c",
+ "metadata": {},
+ "source": [
+ "### Obtain Tone Color Embedding"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f8add279",
+ "metadata": {},
+ "source": [
+ "The `source_se` is the tone color embedding of the base speaker. \n",
+ "It is an average of multiple sentences generated by the base speaker. We directly provide the result here but\n",
+ "the readers feel free to extract `source_se` by themselves."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "63ff6273",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4f71fcc3",
+ "metadata": {},
+ "source": [
+ "The `reference_speaker.mp3` below points to the short audio clip of the reference whose voice we want to clone. We provide an example here. If you use your own reference speakers, please **make sure each speaker has a unique filename.** The `se_extractor` will save the `targeted_se` using the filename of the audio and **will not automatically overwrite.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "55105eae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_speaker = 'resources/example_reference.mp3'\n",
+ "target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a40284aa",
+ "metadata": {},
+ "source": [
+ "### Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "73dc1259",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_path = f'{output_dir}/output_en_default.wav'\n",
+ "\n",
+ "# Run the base speaker tts\n",
+ "text = \"This audio is generated by OpenVoice.\"\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "base_speaker_tts.tts(text, src_path, speaker='default', language='English', speed=1.0)\n",
+ "\n",
+ "# Run the tone color converter\n",
+ "encode_message = \"@MyShell\"\n",
+ "tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e3ea28a",
+ "metadata": {},
+ "source": [
+ "**Try with different styles and speed.** The style can be controlled by the `speaker` parameter in the `base_speaker_tts.tts` method. Available choices: friendly, cheerful, excited, sad, angry, terrified, shouting, whispering. Note that the tone color embedding need to be updated. The speed can be controlled by the `speed` parameter. Let's try whispering with speed 0.9."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fd022d38",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "source_se = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)\n",
+ "save_path = f'{output_dir}/output_whispering.wav'\n",
+ "\n",
+ "# Run the base speaker tts\n",
+ "text = \"This audio is generated by OpenVoice with a half-performance model.\"\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "base_speaker_tts.tts(text, src_path, speaker='whispering', language='English', speed=0.9)\n",
+ "\n",
+ "# Run the tone color converter\n",
+ "encode_message = \"@MyShell\"\n",
+ "tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5fcfc70b",
+ "metadata": {},
+ "source": [
+ "**Try with different languages.** OpenVoice can achieve multi-lingual voice cloning by simply replace the base speaker. We provide an example with a Chinese base speaker here and we encourage the readers to try `demo_part2.ipynb` for a detailed demo."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a71d1387",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "ckpt_base = 'checkpoints/base_speakers/ZH'\n",
+ "base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)\n",
+ "base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')\n",
+ "\n",
+ "source_se = torch.load(f'{ckpt_base}/zh_default_se.pth').to(device)\n",
+ "save_path = f'{output_dir}/output_chinese.wav'\n",
+ "\n",
+ "# Run the base speaker tts\n",
+ "text = \"今天天气真好,我们一起出去吃饭吧。\"\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "base_speaker_tts.tts(text, src_path, speaker='default', language='Chinese', speed=1.0)\n",
+ "\n",
+ "# Run the tone color converter\n",
+ "encode_message = \"@MyShell\"\n",
+ "tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8e513094",
+ "metadata": {},
+ "source": [
+ "**Tech for good.** For people who will deploy OpenVoice for public usage: We offer you the option to add watermark to avoid potential misuse. Please see the ToneColorConverter class. **MyShell reserves the ability to detect whether an audio is generated by OpenVoice**, no matter whether the watermark is added or not."
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "9d70c38e1c0b038dbdffdaa4f8bfa1f6767c43760905c87a9fbe7800d18c6c35"
+ },
+ "kernelspec": {
+ "display_name": "Python 3.9.18 ('openvoice')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/.ipynb_checkpoints/requirements-checkpoint.txt b/.ipynb_checkpoints/requirements-checkpoint.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a4139330983dccb0178328577cfc92182cc2c5bc
--- /dev/null
+++ b/.ipynb_checkpoints/requirements-checkpoint.txt
@@ -0,0 +1,14 @@
+librosa==0.9.1
+faster-whisper==0.9.0
+pydub==0.25.1
+wavmark==0.0.2
+numpy==1.22.0
+eng_to_ipa==0.0.2
+inflect==7.0.0
+unidecode==1.3.7
+whisper-timestamped==1.14.2
+openai
+python-dotenv
+pypinyin
+jieba
+cn2an
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1907e667cf169f793972cb66a212584c5f5ba278
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,333 @@
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
diff --git a/README.md b/README.md
index 90dd14dcf760620d1898b74be09995c0668dadb8..4a5e852c7f1f86ade0d317d3970e06bfbf3c1b84 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,100 @@
----
-title: OpenVoice
-emoji: 📚
-colorFrom: indigo
-colorTo: yellow
-sdk: gradio
-sdk_version: 4.12.0
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+
+
+[Paper](https://arxiv.org/abs/2312.01479) |
+[Website](https://research.myshell.ai/open-voice)
+
+
+
+## Introduction
+As we detailed in our [paper](https://arxiv.org/abs/2312.01479) and [website](https://research.myshell.ai/open-voice), the advantages of OpenVoice are three-fold:
+
+**1. Accurate Tone Color Cloning.**
+OpenVoice can accurately clone the reference tone color and generate speech in multiple languages and accents.
+
+**2. Flexible Voice Style Control.**
+OpenVoice enables granular control over voice styles, such as emotion and accent, as well as other style parameters including rhythm, pauses, and intonation.
+
+**3. Zero-shot Cross-lingual Voice Cloning.**
+Neither of the language of the generated speech nor the language of the reference speech needs to be presented in the massive-speaker multi-lingual training dataset.
+
+[Video](https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f)
+
+
+
+
+
+
+
+OpenVoice has been powering the instant voice cloning capability of [myshell.ai](https://app.myshell.ai/explore) since May 2023. Until Nov 2023, the voice cloning model has been used tens of millions of times by users worldwide, and witnessed the explosive user growth on the platform.
+
+## Main Contributors
+
+- [Zengyi Qin](https://www.qinzy.tech) at MIT and MyShell
+- [Wenliang Zhao](https://wl-zhao.github.io) at Tsinghua University
+- [Xumin Yu](https://yuxumin.github.io) at Tsinghua University
+- [Ethan Sun](https://twitter.com/ethan_myshell) at MyShell
+
+## Live Demo
+
+
+
+
+
+
+
+## Disclaimer
+
+This is an open-source implementation that approximates the performance of the internal voice clone technology of [myshell.ai](https://app.myshell.ai/explore). The online version in myshell.ai has better 1) audio quality, 2) voice cloning similarity, 3) speech naturalness and 4) computational efficiency.
+
+## Installation
+Clone this repo, and run
+```
+conda create -n openvoice python=3.9
+conda activate openvoice
+conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
+pip install -r requirements.txt
+```
+Download the checkpoint from [here](https://myshell-public-repo-hosting.s3.amazonaws.com/checkpoints_1226.zip) and extract it to the `checkpoints` folder
+
+## Usage
+
+**1. Flexible Voice Style Control.**
+Please see `demo_part1.ipynb` for an example usage of how OpenVoice enables flexible style control over the cloned voice.
+
+**2. Cross-Lingual Voice Cloning.**
+Please see `demo_part2.ipynb` for an example for languages seen or unseen in the MSML training set.
+
+**3. Advanced Usage.**
+The base speaker model can be replaced with any model (in any language and style) that the user prefer. Please use the `se_extractor.get_se` function as demonstrated in the demo to extract the tone color embedding for the new base speaker.
+
+**4. Tips to Generate Natural Speech.**
+There are many single or multi-speaker TTS methods that can generate natural speech, and are readily available. By simply replacing the base speaker model with the model you prefer, you can push the speech naturalness to a level you desire.
+
+## Roadmap
+
+- [x] Inference code
+- [x] Tone color converter model
+- [x] Multi-style base speaker model
+- [x] Multi-style and multi-lingual demo
+- [x] Base speaker model in other languages
+- [x] EN base speaker model with better naturalness
+
+
+## Citation
+```
+@article{qin2023openvoice,
+ title={OpenVoice: Versatile Instant Voice Cloning},
+ author={Qin, Zengyi and Zhao, Wenliang and Yu, Xumin and Sun, Xin},
+ journal={arXiv preprint arXiv:2312.01479},
+ year={2023}
+}
+```
+
+## License
+This repository is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License, which prohibits commercial usage. **MyShell reserves the ability to detect whether an audio is generated by OpenVoice**, no matter whether the watermark is added or not.
+
+
+## Acknowledgements
+This open-source implementation is based on several open-source projects, [TTS](https://github.com/coqui-ai/TTS), [VITS](https://github.com/jaywalnut310/vits), and [VITS2](https://github.com/daniilrobnikov/vits2). Thanks for their awesome work!
diff --git a/__pycache__/api.cpython-310.pyc b/__pycache__/api.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9654189464cefd1a9df0226e387d68c4fbe401f8
Binary files /dev/null and b/__pycache__/api.cpython-310.pyc differ
diff --git a/__pycache__/attentions.cpython-310.pyc b/__pycache__/attentions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75df14247e5232b9512874a22f3116846ffe5da0
Binary files /dev/null and b/__pycache__/attentions.cpython-310.pyc differ
diff --git a/__pycache__/commons.cpython-310.pyc b/__pycache__/commons.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f5f63adf4c990dd268952854927bf87f2b63c24
Binary files /dev/null and b/__pycache__/commons.cpython-310.pyc differ
diff --git a/__pycache__/mel_processing.cpython-310.pyc b/__pycache__/mel_processing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ba191673c1242b5852b471f66f810989e867803
Binary files /dev/null and b/__pycache__/mel_processing.cpython-310.pyc differ
diff --git a/__pycache__/models.cpython-310.pyc b/__pycache__/models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45a29d1050ca9a9292d8243ec2172a70fc2b27aa
Binary files /dev/null and b/__pycache__/models.cpython-310.pyc differ
diff --git a/__pycache__/modules.cpython-310.pyc b/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46b6050ab4cbc086279b2cf0b8c9015cc7c36e48
Binary files /dev/null and b/__pycache__/modules.cpython-310.pyc differ
diff --git a/__pycache__/se_extractor.cpython-310.pyc b/__pycache__/se_extractor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a519014b6ecfeeb37f79ea1907a5fa286bbf2623
Binary files /dev/null and b/__pycache__/se_extractor.cpython-310.pyc differ
diff --git a/__pycache__/transforms.cpython-310.pyc b/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79109a44c6a3a0531debb8853db75e3187eba998
Binary files /dev/null and b/__pycache__/transforms.cpython-310.pyc differ
diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6d01ede5b9ff33daefc866cca73c15ac0b9c474
Binary files /dev/null and b/__pycache__/utils.cpython-310.pyc differ
diff --git a/api.py b/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b769470883761670986b13f98782255b8f7fe36c
--- /dev/null
+++ b/api.py
@@ -0,0 +1,201 @@
+import torch
+import numpy as np
+import re
+import soundfile
+import utils
+import commons
+import os
+import librosa
+from text import text_to_sequence
+from mel_processing import spectrogram_torch
+from models import SynthesizerTrn
+
+
+class OpenVoiceBaseClass(object):
+ def __init__(self,
+ config_path,
+ device='cuda:0'):
+ if 'cuda' in device:
+ assert torch.cuda.is_available()
+
+ hps = utils.get_hparams_from_file(config_path)
+
+ model = SynthesizerTrn(
+ len(getattr(hps, 'symbols', [])),
+ hps.data.filter_length // 2 + 1,
+ n_speakers=hps.data.n_speakers,
+ **hps.model,
+ ).to(device)
+
+ model.eval()
+ self.model = model
+ self.hps = hps
+ self.device = device
+
+ def load_ckpt(self, ckpt_path):
+ checkpoint_dict = torch.load(ckpt_path)
+ a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
+ print("Loaded checkpoint '{}'".format(ckpt_path))
+ print('missing/unexpected keys:', a, b)
+
+
+class BaseSpeakerTTS(OpenVoiceBaseClass):
+ language_marks = {
+ "english": "EN",
+ "chinese": "ZH",
+ }
+
+ @staticmethod
+ def get_text(text, hps, is_symbol):
+ text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+ if hps.data.add_blank:
+ text_norm = commons.intersperse(text_norm, 0)
+ text_norm = torch.LongTensor(text_norm)
+ return text_norm
+
+ @staticmethod
+ def audio_numpy_concat(segment_data_list, sr, speed=1.):
+ audio_segments = []
+ for segment_data in segment_data_list:
+ audio_segments += segment_data.reshape(-1).tolist()
+ audio_segments += [0] * int((sr * 0.05)/speed)
+ audio_segments = np.array(audio_segments).astype(np.float32)
+ return audio_segments
+
+ @staticmethod
+ def split_sentences_into_pieces(text, language_str):
+ texts = utils.split_sentence(text, language_str=language_str)
+ print(" > Text splitted to sentences.")
+ print('\n'.join(texts))
+ print(" > ===========================")
+ return texts
+
+ def tts(self, text, output_path, speaker, language='English', speed=1.0):
+ mark = self.language_marks.get(language.lower(), None)
+ assert mark is not None, f"language {language} is not supported"
+
+ texts = self.split_sentences_into_pieces(text, mark)
+
+ audio_list = []
+ for t in texts:
+ t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+ t = f'[{mark}]{t}[{mark}]'
+ stn_tst = self.get_text(t, self.hps, False)
+ device = self.device
+ speaker_id = self.hps.speakers[speaker]
+ with torch.no_grad():
+ x_tst = stn_tst.unsqueeze(0).to(device)
+ x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
+ sid = torch.LongTensor([speaker_id]).to(device)
+ audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
+ length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
+ audio_list.append(audio)
+ audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+ if output_path is None:
+ return audio
+ else:
+ soundfile.write(output_path, audio, self.hps.data.sampling_rate)
+
+
+class ToneColorConverter(OpenVoiceBaseClass):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ if kwargs.get('enable_watermark', True):
+ import wavmark
+ self.watermark_model = wavmark.load_model().to(self.device)
+ else:
+ self.watermark_model = None
+
+
+
+ def extract_se(self, ref_wav_list, se_save_path=None):
+ if isinstance(ref_wav_list, str):
+ ref_wav_list = [ref_wav_list]
+
+ device = self.device
+ hps = self.hps
+ gs = []
+
+ for fname in ref_wav_list:
+ audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
+ y = torch.FloatTensor(audio_ref)
+ y = y.to(device)
+ y = y.unsqueeze(0)
+ y = spectrogram_torch(y, hps.data.filter_length,
+ hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
+ center=False).to(device)
+ with torch.no_grad():
+ g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
+ gs.append(g.detach())
+ gs = torch.stack(gs).mean(0)
+
+ if se_save_path is not None:
+ os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
+ torch.save(gs.cpu(), se_save_path)
+
+ return gs
+
+ def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
+ hps = self.hps
+ # load audio
+ audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
+ audio = torch.tensor(audio).float()
+
+ with torch.no_grad():
+ y = torch.FloatTensor(audio).to(self.device)
+ y = y.unsqueeze(0)
+ spec = spectrogram_torch(y, hps.data.filter_length,
+ hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
+ center=False).to(self.device)
+ spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
+ audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
+ 0, 0].data.cpu().float().numpy()
+ audio = self.add_watermark(audio, message)
+ if output_path is None:
+ return audio
+ else:
+ soundfile.write(output_path, audio, hps.data.sampling_rate)
+
+ def add_watermark(self, audio, message):
+ if self.watermark_model is None:
+ return audio
+ device = self.device
+ bits = utils.string_to_bits(message).reshape(-1)
+ n_repeat = len(bits) // 32
+
+ K = 16000
+ coeff = 2
+ for n in range(n_repeat):
+ trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
+ if len(trunck) != K:
+ print('Audio too short, fail to add watermark')
+ break
+ message_npy = bits[n * 32: (n + 1) * 32]
+
+ with torch.no_grad():
+ signal = torch.FloatTensor(trunck).to(device)[None]
+ message_tensor = torch.FloatTensor(message_npy).to(device)[None]
+ signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
+ signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
+ audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
+ return audio
+
+ def detect_watermark(self, audio, n_repeat):
+ bits = []
+ K = 16000
+ coeff = 2
+ for n in range(n_repeat):
+ trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
+ if len(trunck) != K:
+ print('Audio too short, fail to detect watermark')
+ return 'Fail'
+ with torch.no_grad():
+ signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
+ message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
+ bits.append(message_decoded_npy)
+ bits = np.stack(bits).reshape(-1, 8)
+ message = utils.bits_to_string(bits)
+ return message
+
diff --git a/attentions.py b/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..355a72e5172fa86354fd848954b521dfb2bb158f
--- /dev/null
+++ b/attentions.py
@@ -0,0 +1,465 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import commons
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ window_size=4,
+ isflow=True,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ # if isflow:
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
+ # self.gin_channels = 256
+ self.cond_layer_idx = self.n_layers
+ if "gin_channels" in kwargs:
+ self.gin_channels = kwargs["gin_channels"]
+ if self.gin_channels != 0:
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+ # vits2 says 3rd block, so idx is 2 by default
+ self.cond_layer_idx = (
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
+ )
+ # logging.debug(self.gin_channels, self.cond_layer_idx)
+ assert (
+ self.cond_layer_idx < self.n_layers
+ ), "cond_layer_idx should be less than n_layers"
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ window_size=window_size,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, g=None):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ if i == self.cond_layer_idx and g is not None:
+ g = self.spk_emb_linear(g.transpose(1, 2))
+ g = g.transpose(1, 2)
+ x = x + g
+ x = x * x_mask
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=True,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+
+ self.drop = nn.Dropout(p_dropout)
+ self.self_attn_layers = nn.ModuleList()
+ self.norm_layers_0 = nn.ModuleList()
+ self.encdec_attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.self_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ proximal_bias=proximal_bias,
+ proximal_init=proximal_init,
+ )
+ )
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
+ self.encdec_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ causal=True,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, h, h_mask):
+ """
+ x: decoder input
+ h: encoder output
+ """
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+ device=x.device, dtype=x.dtype
+ )
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_0[i](x + y)
+
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ p_dropout=0.0,
+ window_size=None,
+ heads_share=True,
+ block_length=None,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert (
+ t_s == t_t
+ ), "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(
+ query / math.sqrt(self.k_channels), key_relative_embeddings
+ )
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert (
+ t_s == t_t
+ ), "Local attention is only available for self-attention."
+ block_mask = (
+ torch.ones_like(scores)
+ .triu(-self.block_length)
+ .tril(self.block_length)
+ )
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(
+ self.emb_rel_v, t_s
+ )
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = (
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+ )
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # pad along column
+ x = F.pad(
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+ )
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=0.0,
+ activation=None,
+ causal=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
diff --git a/checkpoints/checkpoints/base_speakers/EN/checkpoint.pth b/checkpoints/checkpoints/base_speakers/EN/checkpoint.pth
new file mode 100644
index 0000000000000000000000000000000000000000..fb7c26af57011437a02ebb1c4fe8ed307cc30f21
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/EN/checkpoint.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1db1ae1a5c8ded049bd1536051489aefbfad4a5077c01c2257e9e88fa1bb8422
+size 160467309
diff --git a/checkpoints/checkpoints/base_speakers/EN/config.json b/checkpoints/checkpoints/base_speakers/EN/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f7309ad10eae3c160ea0ef44261372c4f3364587
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/EN/config.json
@@ -0,0 +1,145 @@
+{
+ "data": {
+ "text_cleaners": [
+ "cjke_cleaners2"
+ ],
+ "sampling_rate": 22050,
+ "filter_length": 1024,
+ "hop_length": 256,
+ "win_length": 1024,
+ "n_mel_channels": 80,
+ "add_blank": true,
+ "cleaned_text": true,
+ "n_speakers": 10
+ },
+ "model": {
+ "inter_channels": 192,
+ "hidden_channels": 192,
+ "filter_channels": 768,
+ "n_heads": 2,
+ "n_layers": 6,
+ "n_layers_trans_flow": 3,
+ "kernel_size": 3,
+ "p_dropout": 0.1,
+ "resblock": "1",
+ "resblock_kernel_sizes": [
+ 3,
+ 7,
+ 11
+ ],
+ "resblock_dilation_sizes": [
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ]
+ ],
+ "upsample_rates": [
+ 8,
+ 8,
+ 2,
+ 2
+ ],
+ "upsample_initial_channel": 512,
+ "upsample_kernel_sizes": [
+ 16,
+ 16,
+ 4,
+ 4
+ ],
+ "n_layers_q": 3,
+ "use_spectral_norm": false,
+ "gin_channels": 256
+ },
+ "symbols": [
+ "_",
+ ",",
+ ".",
+ "!",
+ "?",
+ "-",
+ "~",
+ "\u2026",
+ "N",
+ "Q",
+ "a",
+ "b",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ "\u0251",
+ "\u00e6",
+ "\u0283",
+ "\u0291",
+ "\u00e7",
+ "\u026f",
+ "\u026a",
+ "\u0254",
+ "\u025b",
+ "\u0279",
+ "\u00f0",
+ "\u0259",
+ "\u026b",
+ "\u0265",
+ "\u0278",
+ "\u028a",
+ "\u027e",
+ "\u0292",
+ "\u03b8",
+ "\u03b2",
+ "\u014b",
+ "\u0266",
+ "\u207c",
+ "\u02b0",
+ "`",
+ "^",
+ "#",
+ "*",
+ "=",
+ "\u02c8",
+ "\u02cc",
+ "\u2192",
+ "\u2193",
+ "\u2191",
+ " "
+ ],
+ "speakers": {
+ "default": 1,
+ "whispering": 2,
+ "shouting": 3,
+ "excited": 4,
+ "cheerful": 5,
+ "terrified": 6,
+ "angry": 7,
+ "sad": 8,
+ "friendly": 9
+ }
+}
\ No newline at end of file
diff --git a/checkpoints/checkpoints/base_speakers/EN/en_default_se.pth b/checkpoints/checkpoints/base_speakers/EN/en_default_se.pth
new file mode 100644
index 0000000000000000000000000000000000000000..319d7eb4bee7b785a47f4e6191c2132dec12abcf
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/EN/en_default_se.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9cab24002eec738d0fe72cb73a34e57fbc3999c1bd4a1670a7b56ee4e3590ac9
+size 1789
diff --git a/checkpoints/checkpoints/base_speakers/EN/en_style_se.pth b/checkpoints/checkpoints/base_speakers/EN/en_style_se.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c2fd50abf058f6ab65879395b62fb7e3c0289b47
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/EN/en_style_se.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f698153be5004b90a8642d1157c89cae7dd296752a3276450ced6a17b8b98a9
+size 1783
diff --git a/checkpoints/checkpoints/base_speakers/ZH/checkpoint.pth b/checkpoints/checkpoints/base_speakers/ZH/checkpoint.pth
new file mode 100644
index 0000000000000000000000000000000000000000..fcadb5c222e9ea92fc9ada4920249fc65cad1692
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/ZH/checkpoint.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de9fb0eb749f3254130fe0172fcbb20e75f88a9b16b54dd0b73cac0dc40da7d9
+size 160467309
diff --git a/checkpoints/checkpoints/base_speakers/ZH/config.json b/checkpoints/checkpoints/base_speakers/ZH/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..130256092fb8ad00f938149bf8aa1a62aae30023
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/ZH/config.json
@@ -0,0 +1,137 @@
+{
+ "data": {
+ "text_cleaners": [
+ "cjke_cleaners2"
+ ],
+ "sampling_rate": 22050,
+ "filter_length": 1024,
+ "hop_length": 256,
+ "win_length": 1024,
+ "n_mel_channels": 80,
+ "add_blank": true,
+ "cleaned_text": true,
+ "n_speakers": 10
+ },
+ "model": {
+ "inter_channels": 192,
+ "hidden_channels": 192,
+ "filter_channels": 768,
+ "n_heads": 2,
+ "n_layers": 6,
+ "n_layers_trans_flow": 3,
+ "kernel_size": 3,
+ "p_dropout": 0.1,
+ "resblock": "1",
+ "resblock_kernel_sizes": [
+ 3,
+ 7,
+ 11
+ ],
+ "resblock_dilation_sizes": [
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ]
+ ],
+ "upsample_rates": [
+ 8,
+ 8,
+ 2,
+ 2
+ ],
+ "upsample_initial_channel": 512,
+ "upsample_kernel_sizes": [
+ 16,
+ 16,
+ 4,
+ 4
+ ],
+ "n_layers_q": 3,
+ "use_spectral_norm": false,
+ "gin_channels": 256
+ },
+ "symbols": [
+ "_",
+ ",",
+ ".",
+ "!",
+ "?",
+ "-",
+ "~",
+ "\u2026",
+ "N",
+ "Q",
+ "a",
+ "b",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ "\u0251",
+ "\u00e6",
+ "\u0283",
+ "\u0291",
+ "\u00e7",
+ "\u026f",
+ "\u026a",
+ "\u0254",
+ "\u025b",
+ "\u0279",
+ "\u00f0",
+ "\u0259",
+ "\u026b",
+ "\u0265",
+ "\u0278",
+ "\u028a",
+ "\u027e",
+ "\u0292",
+ "\u03b8",
+ "\u03b2",
+ "\u014b",
+ "\u0266",
+ "\u207c",
+ "\u02b0",
+ "`",
+ "^",
+ "#",
+ "*",
+ "=",
+ "\u02c8",
+ "\u02cc",
+ "\u2192",
+ "\u2193",
+ "\u2191",
+ " "
+ ],
+ "speakers": {
+ "default": 0
+ }
+}
\ No newline at end of file
diff --git a/checkpoints/checkpoints/base_speakers/ZH/zh_default_se.pth b/checkpoints/checkpoints/base_speakers/ZH/zh_default_se.pth
new file mode 100644
index 0000000000000000000000000000000000000000..471841ae84a31aae1c8e25c1ef4548b3e87a32bb
--- /dev/null
+++ b/checkpoints/checkpoints/base_speakers/ZH/zh_default_se.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b62e8264962059b8a84dd00b29e2fcccc92f5d3be90eec67dfa082c0cf58ccf
+size 1789
diff --git a/checkpoints/checkpoints/converter/checkpoint.pth b/checkpoints/checkpoints/converter/checkpoint.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c38ff17666bae2bae4236f85bfe2284f4885b31a
--- /dev/null
+++ b/checkpoints/checkpoints/converter/checkpoint.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89ae83aa4e3668fef64b388b789ff7b0ce0def9f801069edfc18a00ea420748d
+size 131327338
diff --git a/checkpoints/checkpoints/converter/config.json b/checkpoints/checkpoints/converter/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a163d4254b637e9fd489712db40c15aeacda169e
--- /dev/null
+++ b/checkpoints/checkpoints/converter/config.json
@@ -0,0 +1,57 @@
+{
+ "data": {
+ "sampling_rate": 22050,
+ "filter_length": 1024,
+ "hop_length": 256,
+ "win_length": 1024,
+ "n_speakers": 0
+ },
+ "model": {
+ "inter_channels": 192,
+ "hidden_channels": 192,
+ "filter_channels": 768,
+ "n_heads": 2,
+ "n_layers": 6,
+ "kernel_size": 3,
+ "p_dropout": 0.1,
+ "resblock": "1",
+ "resblock_kernel_sizes": [
+ 3,
+ 7,
+ 11
+ ],
+ "resblock_dilation_sizes": [
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ]
+ ],
+ "upsample_rates": [
+ 8,
+ 8,
+ 2,
+ 2
+ ],
+ "upsample_initial_channel": 512,
+ "upsample_kernel_sizes": [
+ 16,
+ 16,
+ 4,
+ 4
+ ],
+ "n_layers_q": 3,
+ "use_spectral_norm": false,
+ "gin_channels": 256
+ }
+}
\ No newline at end of file
diff --git a/checkpoints_1226.zip b/checkpoints_1226.zip
new file mode 100644
index 0000000000000000000000000000000000000000..75d6035034c00db7b4a69ed816143ea504d8abf8
--- /dev/null
+++ b/checkpoints_1226.zip
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d86a78fd96caca7e2e9d79c7724391b39581b861d332619a9aaa8de31ad954b7
+size 420616031
diff --git a/commons.py b/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3fa07f65b1681e1f469b04b2fe689b7c174eaaa
--- /dev/null
+++ b/commons.py
@@ -0,0 +1,160 @@
+import math
+import torch
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+ layer = pad_shape[::-1]
+ pad_shape = [item for sublist in layer for item in sublist]
+ return pad_shape
+
+
+def intersperse(lst, item):
+ result = [item] * (len(lst) * 2 + 1)
+ result[1::2] = lst
+ return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += (
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+ )
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+ num_timescales - 1
+ )
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+ )
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ layer = pad_shape[::-1]
+ pad_shape = [item for sublist in layer for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2, 3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1.0 / norm_type)
+ return total_norm
diff --git a/demo_part1.ipynb b/demo_part1.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..15d16357e29cbc3ddfa5af80f76cedb5884b5012
--- /dev/null
+++ b/demo_part1.ipynb
@@ -0,0 +1,385 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b6ee1ede",
+ "metadata": {},
+ "source": [
+ "## Voice Style Control Demo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b7f043ee",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Importing the dtw module. When using in academic works please cite:\n",
+ " T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
+ " J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import se_extractor\n",
+ "from api import BaseSpeakerTTS, ToneColorConverter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "237202b7-4e8d-4445-a5b0-41db367e4977",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/workspace/OpenVoice\n"
+ ]
+ }
+ ],
+ "source": [
+ "cd OpenVoice"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "a79d4a34-ce9c-4ae3-b79a-c730b8885df6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Archive: checkpoints_1226.zip\n",
+ " creating: checkpoints/checkpoints/\n",
+ " creating: checkpoints/checkpoints/converter/\n",
+ " inflating: checkpoints/checkpoints/converter/config.json \n",
+ " inflating: checkpoints/checkpoints/converter/checkpoint.pth \n",
+ " creating: checkpoints/checkpoints/base_speakers/\n",
+ " creating: checkpoints/checkpoints/base_speakers/ZH/\n",
+ " inflating: checkpoints/checkpoints/base_speakers/ZH/config.json \n",
+ " inflating: checkpoints/checkpoints/base_speakers/ZH/checkpoint.pth \n",
+ " inflating: checkpoints/checkpoints/base_speakers/ZH/zh_default_se.pth \n",
+ " creating: checkpoints/checkpoints/base_speakers/EN/\n",
+ " inflating: checkpoints/checkpoints/base_speakers/EN/config.json \n",
+ " inflating: checkpoints/checkpoints/base_speakers/EN/en_style_se.pth \n",
+ " inflating: checkpoints/checkpoints/base_speakers/EN/en_default_se.pth \n",
+ " inflating: checkpoints/checkpoints/base_speakers/EN/checkpoint.pth \n"
+ ]
+ }
+ ],
+ "source": [
+ "!unzip checkpoints_1226.zip -d checkpoints"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "175d6dfc-b170-46de-8a44-8ec5a4e3b333",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: huggingface_hub in /opt/conda/lib/python3.10/site-packages (0.17.3)\n",
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (3.13.1)\n",
+ "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (2023.12.2)\n",
+ "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (4.65.0)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (4.7.1)\n",
+ "Requirement already satisfied: packaging>=20.9 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (23.1)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (1.26.18)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface_hub) (2023.11.17)\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "pip install huggingface_hub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "35b568bf-9951-40b9-843f-0af4e918567f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f30315ee57c9418688ab0d9be585f948",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(HTML(value=' 6\u001b[0m base_speaker_tts \u001b[38;5;241m=\u001b[39m \u001b[43mBaseSpeakerTTS\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mckpt_base\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/config.json\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m base_speaker_tts\u001b[38;5;241m.\u001b[39mload_ckpt(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mckpt_base\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/checkpoint.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 9\u001b[0m tone_color_converter \u001b[38;5;241m=\u001b[39m ToneColorConverter(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mckpt_converter\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/config.json\u001b[39m\u001b[38;5;124m'\u001b[39m, device\u001b[38;5;241m=\u001b[39mdevice)\n",
+ "File \u001b[0;32m/workspace/OpenVoice/api.py:21\u001b[0m, in \u001b[0;36mOpenVoiceBaseClass.__init__\u001b[0;34m(self, config_path, device)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m device:\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available()\n\u001b[0;32m---> 21\u001b[0m hps \u001b[38;5;241m=\u001b[39m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_hparams_from_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m model \u001b[38;5;241m=\u001b[39m SynthesizerTrn(\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mgetattr\u001b[39m(hps, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msymbols\u001b[39m\u001b[38;5;124m'\u001b[39m, [])),\n\u001b[1;32m 25\u001b[0m hps\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mfilter_length \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 26\u001b[0m n_speakers\u001b[38;5;241m=\u001b[39mhps\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mn_speakers,\n\u001b[1;32m 27\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mhps\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 28\u001b[0m )\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 30\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n",
+ "File \u001b[0;32m/workspace/OpenVoice/utils.py:7\u001b[0m, in \u001b[0;36mget_hparams_from_file\u001b[0;34m(config_path)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_hparams_from_file\u001b[39m(config_path):\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 8\u001b[0m data \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mread()\n\u001b[1;32m 9\u001b[0m config \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(data)\n",
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'checkpoints/base_speakers/EN/config.json'"
+ ]
+ }
+ ],
+ "source": [
+ "ckpt_base = 'checkpoints/base_speakers/EN'\n",
+ "ckpt_converter = 'checkpoints/converter'\n",
+ "device = 'cuda:0'\n",
+ "output_dir = 'outputs'\n",
+ "\n",
+ "base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)\n",
+ "base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')\n",
+ "\n",
+ "tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)\n",
+ "tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')\n",
+ "\n",
+ "os.makedirs(output_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7f67740c",
+ "metadata": {},
+ "source": [
+ "### Obtain Tone Color Embedding"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f8add279",
+ "metadata": {},
+ "source": [
+ "The `source_se` is the tone color embedding of the base speaker. \n",
+ "It is an average of multiple sentences generated by the base speaker. We directly provide the result here but\n",
+ "the readers feel free to extract `source_se` by themselves."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "63ff6273",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4f71fcc3",
+ "metadata": {},
+ "source": [
+ "The `reference_speaker.mp3` below points to the short audio clip of the reference whose voice we want to clone. We provide an example here. If you use your own reference speakers, please **make sure each speaker has a unique filename.** The `se_extractor` will save the `targeted_se` using the filename of the audio and **will not automatically overwrite.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "55105eae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_speaker = 'resources/example_reference.mp3'\n",
+ "target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a40284aa",
+ "metadata": {},
+ "source": [
+ "### Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "73dc1259",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_path = f'{output_dir}/output_en_default.wav'\n",
+ "\n",
+ "# Run the base speaker tts\n",
+ "text = \"This audio is generated by OpenVoice.\"\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "base_speaker_tts.tts(text, src_path, speaker='default', language='English', speed=1.0)\n",
+ "\n",
+ "# Run the tone color converter\n",
+ "encode_message = \"@MyShell\"\n",
+ "tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e3ea28a",
+ "metadata": {},
+ "source": [
+ "**Try with different styles and speed.** The style can be controlled by the `speaker` parameter in the `base_speaker_tts.tts` method. Available choices: friendly, cheerful, excited, sad, angry, terrified, shouting, whispering. Note that the tone color embedding need to be updated. The speed can be controlled by the `speed` parameter. Let's try whispering with speed 0.9."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fd022d38",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "source_se = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)\n",
+ "save_path = f'{output_dir}/output_whispering.wav'\n",
+ "\n",
+ "# Run the base speaker tts\n",
+ "text = \"This audio is generated by OpenVoice with a half-performance model.\"\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "base_speaker_tts.tts(text, src_path, speaker='whispering', language='English', speed=0.9)\n",
+ "\n",
+ "# Run the tone color converter\n",
+ "encode_message = \"@MyShell\"\n",
+ "tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5fcfc70b",
+ "metadata": {},
+ "source": [
+ "**Try with different languages.** OpenVoice can achieve multi-lingual voice cloning by simply replace the base speaker. We provide an example with a Chinese base speaker here and we encourage the readers to try `demo_part2.ipynb` for a detailed demo."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a71d1387",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "ckpt_base = 'checkpoints/base_speakers/ZH'\n",
+ "base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)\n",
+ "base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')\n",
+ "\n",
+ "source_se = torch.load(f'{ckpt_base}/zh_default_se.pth').to(device)\n",
+ "save_path = f'{output_dir}/output_chinese.wav'\n",
+ "\n",
+ "# Run the base speaker tts\n",
+ "text = \"今天天气真好,我们一起出去吃饭吧。\"\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "base_speaker_tts.tts(text, src_path, speaker='default', language='Chinese', speed=1.0)\n",
+ "\n",
+ "# Run the tone color converter\n",
+ "encode_message = \"@MyShell\"\n",
+ "tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8e513094",
+ "metadata": {},
+ "source": [
+ "**Tech for good.** For people who will deploy OpenVoice for public usage: We offer you the option to add watermark to avoid potential misuse. Please see the ToneColorConverter class. **MyShell reserves the ability to detect whether an audio is generated by OpenVoice**, no matter whether the watermark is added or not."
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "9d70c38e1c0b038dbdffdaa4f8bfa1f6767c43760905c87a9fbe7800d18c6c35"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/demo_part2.ipynb b/demo_part2.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ac41c8a798f2b0c5283968aaa78f8e69e99946f4
--- /dev/null
+++ b/demo_part2.ipynb
@@ -0,0 +1,195 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b6ee1ede",
+ "metadata": {},
+ "source": [
+ "## Cross-Lingual Voice Clone Demo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7f043ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import se_extractor\n",
+ "from api import ToneColorConverter"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15116b59",
+ "metadata": {},
+ "source": [
+ "### Initialization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aacad912",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ckpt_converter = 'checkpoints/converter'\n",
+ "device = 'cuda:0'\n",
+ "output_dir = 'outputs'\n",
+ "\n",
+ "tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)\n",
+ "tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')\n",
+ "\n",
+ "os.makedirs(output_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3db80fcf",
+ "metadata": {},
+ "source": [
+ "In this demo, we will use OpenAI TTS as the base speaker to produce multi-lingual speech audio. The users can flexibly change the base speaker according to their own needs. Please create a file named `.env` and place OpenAI key as `OPENAI_API_KEY=xxx`. We have also provided a Chinese base speaker model (see `demo_part1.ipynb`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b245ca3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "# Please create a file named .env and place your\n",
+ "# OpenAI key as OPENAI_API_KEY=xxx\n",
+ "load_dotenv() \n",
+ "\n",
+ "client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\"))\n",
+ "\n",
+ "response = client.audio.speech.create(\n",
+ " model=\"tts-1\",\n",
+ " voice=\"nova\",\n",
+ " input=\"This audio will be used to extract the base speaker tone color embedding. \" + \\\n",
+ " \"Typically a very short audio should be sufficient, but increasing the audio \" + \\\n",
+ " \"length will also improve the output audio quality.\"\n",
+ ")\n",
+ "\n",
+ "response.stream_to_file(f\"{output_dir}/openai_source_output.mp3\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7f67740c",
+ "metadata": {},
+ "source": [
+ "### Obtain Tone Color Embedding"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f8add279",
+ "metadata": {},
+ "source": [
+ "The `source_se` is the tone color embedding of the base speaker. \n",
+ "It is an average for multiple sentences with multiple emotions\n",
+ "of the base speaker. We directly provide the result here but\n",
+ "the readers feel free to extract `source_se` by themselves."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "63ff6273",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_speaker = f\"{output_dir}/openai_source_output.mp3\"\n",
+ "source_se, audio_name = se_extractor.get_se(base_speaker, tone_color_converter, vad=True)\n",
+ "\n",
+ "reference_speaker = 'resources/example_reference.mp3'\n",
+ "target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, vad=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a40284aa",
+ "metadata": {},
+ "source": [
+ "### Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "73dc1259",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run the base speaker tts\n",
+ "text = [\n",
+ " \"MyShell is a decentralized and comprehensive platform for discovering, creating, and staking AI-native apps.\",\n",
+ " \"MyShell es una plataforma descentralizada y completa para descubrir, crear y apostar por aplicaciones nativas de IA.\",\n",
+ " \"MyShell est une plateforme décentralisée et complète pour découvrir, créer et miser sur des applications natives d'IA.\",\n",
+ " \"MyShell ist eine dezentralisierte und umfassende Plattform zum Entdecken, Erstellen und Staken von KI-nativen Apps.\",\n",
+ " \"MyShell è una piattaforma decentralizzata e completa per scoprire, creare e scommettere su app native di intelligenza artificiale.\",\n",
+ " \"MyShellは、AIネイティブアプリの発見、作成、およびステーキングのための分散型かつ包括的なプラットフォームです。\",\n",
+ " \"MyShell — это децентрализованная и всеобъемлющая платформа для обнаружения, создания и стейкинга AI-ориентированных приложений.\",\n",
+ " \"MyShell هي منصة لامركزية وشاملة لاكتشاف وإنشاء ورهان تطبيقات الذكاء الاصطناعي الأصلية.\",\n",
+ " \"MyShell是一个去中心化且全面的平台,用于发现、创建和投资AI原生应用程序。\",\n",
+ " \"MyShell एक विकेंद्रीकृत और व्यापक मंच है, जो AI-मूल ऐप्स की खोज, सृजन और स्टेकिंग के लिए है।\",\n",
+ " \"MyShell é uma plataforma descentralizada e abrangente para descobrir, criar e apostar em aplicativos nativos de IA.\"\n",
+ "]\n",
+ "src_path = f'{output_dir}/tmp.wav'\n",
+ "\n",
+ "for i, t in enumerate(text):\n",
+ "\n",
+ " response = client.audio.speech.create(\n",
+ " model=\"tts-1\",\n",
+ " voice=\"alloy\",\n",
+ " input=t,\n",
+ " )\n",
+ "\n",
+ " response.stream_to_file(src_path)\n",
+ "\n",
+ " save_path = f'{output_dir}/output_crosslingual_{i}.wav'\n",
+ "\n",
+ " # Run the tone color converter\n",
+ " encode_message = \"@MyShell\"\n",
+ " tone_color_converter.convert(\n",
+ " audio_src_path=src_path, \n",
+ " src_se=source_se, \n",
+ " tgt_se=target_se, \n",
+ " output_path=save_path,\n",
+ " message=encode_message)"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "9d70c38e1c0b038dbdffdaa4f8bfa1f6767c43760905c87a9fbe7800d18c6c35"
+ },
+ "kernelspec": {
+ "display_name": "Python 3.9.18 ('openvoice')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/mel_processing.py b/mel_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..822d7f19062497b198ae54554a3ab828c10147ad
--- /dev/null
+++ b/mel_processing.py
@@ -0,0 +1,183 @@
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+
+MAX_WAV_VALUE = 32768.0
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+ if torch.min(y) < -1.1:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.1:
+ print("max value is ", torch.max(y))
+
+ global hann_window
+ dtype_device = str(y.dtype) + "_" + str(y.device)
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+ dtype=y.dtype, device=y.device
+ )
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[wnsize_dtype_device],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=False,
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+ return spec
+
+
+def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+ # if torch.min(y) < -1.:
+ # print('min value is ', torch.min(y))
+ # if torch.max(y) > 1.:
+ # print('max value is ', torch.max(y))
+
+ global hann_window
+ dtype_device = str(y.dtype) + '_' + str(y.device)
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+
+ # ******************** original ************************#
+ # y = y.squeeze(1)
+ # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+ # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+
+ # ******************** ConvSTFT ************************#
+ freq_cutoff = n_fft // 2 + 1
+ fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
+ forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
+ forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
+
+ import torch.nn.functional as F
+
+ # if center:
+ # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
+ assert center is False
+
+ forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
+ spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
+
+
+ # ******************** Verification ************************#
+ spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+ assert torch.allclose(spec1, spec2, atol=1e-4)
+
+ spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
+ return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+ global mel_basis
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
+ if fmax_dtype_device not in mel_basis:
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+ dtype=spec.dtype, device=spec.device
+ )
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+ return spec
+
+
+def mel_spectrogram_torch(
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ global mel_basis, hann_window
+ dtype_device = str(y.dtype) + "_" + str(y.device)
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
+ if fmax_dtype_device not in mel_basis:
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+ dtype=y.dtype, device=y.device
+ )
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+ dtype=y.dtype, device=y.device
+ )
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[wnsize_dtype_device],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=False,
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
\ No newline at end of file
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..a50d01383357c4c884e760f6e4797ce5ef23c0c9
--- /dev/null
+++ b/models.py
@@ -0,0 +1,497 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import commons
+import modules
+import attentions
+
+from torch.nn import Conv1d, ConvTranspose1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from commons import init_weights, get_padding
+
+
+class TextEncoder(nn.Module):
+ def __init__(self,
+ n_vocab,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout):
+ super().__init__()
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+ self.encoder = attentions.Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout)
+ self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths):
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ x = self.encoder(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ return x, m, logs, x_mask
+
+
+class DurationPredictor(nn.Module):
+ def __init__(
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.gin_channels = gin_channels
+
+ self.drop = nn.Dropout(p_dropout)
+ self.conv_1 = nn.Conv1d(
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+ )
+ self.norm_1 = modules.LayerNorm(filter_channels)
+ self.conv_2 = nn.Conv1d(
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+ )
+ self.norm_2 = modules.LayerNorm(filter_channels)
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+ def forward(self, x, x_mask, g=None):
+ x = torch.detach(x)
+ if g is not None:
+ g = torch.detach(g)
+ x = x + self.cond(g)
+ x = self.conv_1(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_1(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_2(x)
+ x = self.drop(x)
+ x = self.proj(x * x_mask)
+ return x * x_mask
+
+class StochasticDurationPredictor(nn.Module):
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
+ super().__init__()
+ filter_channels = in_channels # it needs to be removed from future version.
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.log_flow = modules.Log()
+ self.flows = nn.ModuleList()
+ self.flows.append(modules.ElementwiseAffine(2))
+ for i in range(n_flows):
+ self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+ self.flows.append(modules.Flip())
+
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
+ self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
+ self.post_flows = nn.ModuleList()
+ self.post_flows.append(modules.ElementwiseAffine(2))
+ for i in range(4):
+ self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+ self.post_flows.append(modules.Flip())
+
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
+ self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
+
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
+ x = torch.detach(x)
+ x = self.pre(x)
+ if g is not None:
+ g = torch.detach(g)
+ x = x + self.cond(g)
+ x = self.convs(x, x_mask)
+ x = self.proj(x) * x_mask
+
+ if not reverse:
+ flows = self.flows
+ assert w is not None
+
+ logdet_tot_q = 0
+ h_w = self.post_pre(w)
+ h_w = self.post_convs(h_w, x_mask)
+ h_w = self.post_proj(h_w) * x_mask
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
+ z_q = e_q
+ for flow in self.post_flows:
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+ logdet_tot_q += logdet_q
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
+ u = torch.sigmoid(z_u) * x_mask
+ z0 = (w - u) * x_mask
+ logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
+ logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
+
+ logdet_tot = 0
+ z0, logdet = self.log_flow(z0, x_mask)
+ logdet_tot += logdet
+ z = torch.cat([z0, z1], 1)
+ for flow in flows:
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
+ logdet_tot = logdet_tot + logdet
+ nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
+ return nll + logq # [b]
+ else:
+ flows = list(reversed(self.flows))
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
+ z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
+ for flow in flows:
+ z = flow(z, x_mask, g=x, reverse=reverse)
+ z0, z1 = torch.split(z, [1, 1], 1)
+ logw = z0
+ return logw
+
+class PosteriorEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.enc = modules.WN(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths, g=None, tau=1.0):
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+ x.dtype
+ )
+ x = self.pre(x) * x_mask
+ x = self.enc(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask
+
+
+class Generator(torch.nn.Module):
+ def __init__(
+ self,
+ initial_channel,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=0,
+ ):
+ super(Generator, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.conv_pre = Conv1d(
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
+ )
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(ch, k, d))
+
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+ self.ups.apply(init_weights)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+ def forward(self, x, g=None):
+ x = self.conv_pre(x)
+ if g is not None:
+ x = x + self.cond(g)
+
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print("Removing weight norm...")
+ for layer in self.ups:
+ remove_weight_norm(layer)
+ for layer in self.resblocks:
+ layer.remove_weight_norm()
+
+
+class ReferenceEncoder(nn.Module):
+ """
+ inputs --- [N, Ty/r, n_mels*r] mels
+ outputs --- [N, ref_enc_gru_size]
+ """
+
+ def __init__(self, spec_channels, gin_channels=0, layernorm=True):
+ super().__init__()
+ self.spec_channels = spec_channels
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
+ K = len(ref_enc_filters)
+ filters = [1] + ref_enc_filters
+ convs = [
+ weight_norm(
+ nn.Conv2d(
+ in_channels=filters[i],
+ out_channels=filters[i + 1],
+ kernel_size=(3, 3),
+ stride=(2, 2),
+ padding=(1, 1),
+ )
+ )
+ for i in range(K)
+ ]
+ self.convs = nn.ModuleList(convs)
+
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
+ self.gru = nn.GRU(
+ input_size=ref_enc_filters[-1] * out_channels,
+ hidden_size=256 // 2,
+ batch_first=True,
+ )
+ self.proj = nn.Linear(128, gin_channels)
+ if layernorm:
+ self.layernorm = nn.LayerNorm(self.spec_channels)
+ else:
+ self.layernorm = None
+
+ def forward(self, inputs, mask=None):
+ N = inputs.size(0)
+
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
+ if self.layernorm is not None:
+ out = self.layernorm(out)
+
+ for conv in self.convs:
+ out = conv(out)
+ # out = wn(out)
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
+
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
+ T = out.size(1)
+ N = out.size(0)
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
+
+ self.gru.flatten_parameters()
+ memory, out = self.gru(out) # out --- [1, N, 128]
+
+ return self.proj(out.squeeze(0))
+
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
+ for i in range(n_convs):
+ L = (L - kernel_size + 2 * pad) // stride + 1
+ return L
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0):
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
+ self.flows.append(modules.Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, reverse=reverse)
+ return x
+
+class SynthesizerTrn(nn.Module):
+ """
+ Synthesizer for Training
+ """
+
+ def __init__(
+ self,
+ n_vocab,
+ spec_channels,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ n_speakers=256,
+ gin_channels=256,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.dec = Generator(
+ inter_channels,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=gin_channels,
+ )
+ self.enc_q = PosteriorEncoder(
+ spec_channels,
+ inter_channels,
+ hidden_channels,
+ 5,
+ 1,
+ 16,
+ gin_channels=gin_channels,
+ )
+
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
+
+ self.n_speakers = n_speakers
+ if n_speakers == 0:
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
+ else:
+ self.enc_p = TextEncoder(n_vocab,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout)
+ self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
+ self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
+
+ def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+ if self.n_speakers > 0:
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
+ else:
+ g = None
+
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
+ + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
+
+ w = torch.exp(logw) * x_mask * length_scale
+ w_ceil = torch.ceil(w)
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = commons.generate_path(w_ceil, attn_mask)
+
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
+
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
+ o = self.dec((z * y_mask)[:,:,:max_len], g=g)
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
+
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
+ g_src = sid_src
+ g_tgt = sid_tgt
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
+ z_p = self.flow(z, y_mask, g=g_src)
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
+ return o_hat, y_mask, (z, z_p, z_hat)
diff --git a/modules.py b/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0461728edee2171f43b0a99f082d1ca6ab7cb3f
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,598 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.nn import Conv1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+import commons
+from commons import init_weights, get_padding
+from transforms import piecewise_rational_quadratic_transform
+from attentions import Encoder
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels,
+ out_channels,
+ kernel_size,
+ n_layers,
+ p_dropout,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 0."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(
+ nn.Conv1d(
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(
+ nn.Conv1d(
+ hidden_channels,
+ hidden_channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class DDSConv(nn.Module):
+ """
+ Dilated and Depth-Separable Convolution
+ """
+
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.drop = nn.Dropout(p_dropout)
+ self.convs_sep = nn.ModuleList()
+ self.convs_1x1 = nn.ModuleList()
+ self.norms_1 = nn.ModuleList()
+ self.norms_2 = nn.ModuleList()
+ for i in range(n_layers):
+ dilation = kernel_size**i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs_sep.append(
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ )
+ )
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+ self.norms_1.append(LayerNorm(channels))
+ self.norms_2.append(LayerNorm(channels))
+
+ def forward(self, x, x_mask, g=None):
+ if g is not None:
+ x = x + g
+ for i in range(self.n_layers):
+ y = self.convs_sep[i](x * x_mask)
+ y = self.norms_1[i](y)
+ y = F.gelu(y)
+ y = self.convs_1x1[i](y)
+ y = self.norms_2[i](y)
+ y = F.gelu(y)
+ y = self.drop(y)
+ x = x + y
+ return x * x_mask
+
+
+class WN(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ p_dropout=0,
+ ):
+ super(WN, self).__init__()
+ assert kernel_size % 2 == 1
+ self.hidden_channels = hidden_channels
+ self.kernel_size = (kernel_size,)
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ cond_layer = torch.nn.Conv1d(
+ gin_channels, 2 * hidden_channels * n_layers, 1
+ )
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+ for i in range(n_layers):
+ dilation = dilation_rate**i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(
+ hidden_channels,
+ 2 * hidden_channels,
+ kernel_size,
+ dilation=dilation,
+ padding=padding,
+ )
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ torch.nn.utils.remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ torch.nn.utils.remove_weight_norm(l)
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c2(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Log(nn.Module):
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ if not reverse:
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+ else:
+ return x
+
+
+class ElementwiseAffine(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.m = nn.Parameter(torch.zeros(channels, 1))
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False,
+ ):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ self.enc = WN(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=p_dropout,
+ gin_channels=gin_channels,
+ )
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ return x
+
+
+class ConvFlow(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ filter_channels,
+ kernel_size,
+ n_layers,
+ num_bins=10,
+ tail_bound=5.0,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.num_bins = num_bins
+ self.tail_bound = tail_bound
+ self.half_channels = in_channels // 2
+
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
+ self.proj = nn.Conv1d(
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
+ )
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0)
+ h = self.convs(h, x_mask, g=g)
+ h = self.proj(h) * x_mask
+
+ b, c, t = x0.shape
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
+
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
+ self.filter_channels
+ )
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
+
+ x1, logabsdet = piecewise_rational_quadratic_transform(
+ x1,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=reverse,
+ tails="linear",
+ tail_bound=self.tail_bound,
+ )
+
+ x = torch.cat([x0, x1], 1) * x_mask
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
+ if not reverse:
+ return x, logdet
+ else:
+ return x
+
+
+class TransformerCouplingLayer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ n_layers,
+ n_heads,
+ p_dropout=0,
+ filter_channels=0,
+ mean_only=False,
+ wn_sharing_parameter=None,
+ gin_channels=0,
+ ):
+ assert n_layers == 3, n_layers
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ self.enc = (
+ Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ isflow=True,
+ gin_channels=gin_channels,
+ )
+ if wn_sharing_parameter is None
+ else wn_sharing_parameter
+ )
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ return x
+
+ x1, logabsdet = piecewise_rational_quadratic_transform(
+ x1,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=reverse,
+ tails="linear",
+ tail_bound=self.tail_bound,
+ )
+
+ x = torch.cat([x0, x1], 1) * x_mask
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
+ if not reverse:
+ return x, logdet
+ else:
+ return x
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a4139330983dccb0178328577cfc92182cc2c5bc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+librosa==0.9.1
+faster-whisper==0.9.0
+pydub==0.25.1
+wavmark==0.0.2
+numpy==1.22.0
+eng_to_ipa==0.0.2
+inflect==7.0.0
+unidecode==1.3.7
+whisper-timestamped==1.14.2
+openai
+python-dotenv
+pypinyin
+jieba
+cn2an
diff --git a/resources/example_reference.mp3 b/resources/example_reference.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..d9ff7e3f9df15bceb7ad9abe0ee726992635721b
Binary files /dev/null and b/resources/example_reference.mp3 differ
diff --git a/resources/framework.jpg b/resources/framework.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..26fb4b9b0fbdcc22183b18a5ed21e24651e6ae0b
Binary files /dev/null and b/resources/framework.jpg differ
diff --git a/resources/lepton.jpg b/resources/lepton.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5bd5601481f38589f2eb2954396911bd0aae74ad
Binary files /dev/null and b/resources/lepton.jpg differ
diff --git a/resources/myshell.jpg b/resources/myshell.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..501d7ab6b02fe714ab25e9c0d954d2a0684268fa
Binary files /dev/null and b/resources/myshell.jpg differ
diff --git a/resources/openvoicelogo.jpg b/resources/openvoicelogo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1bc9b9e38bae8ee38e998f5136a1e7a9ed967e80
Binary files /dev/null and b/resources/openvoicelogo.jpg differ
diff --git a/se_extractor.py b/se_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea2b9dbbba9fc6ea81fb4c5cdb6213f400c8aaf
--- /dev/null
+++ b/se_extractor.py
@@ -0,0 +1,139 @@
+import os
+import glob
+import torch
+from glob import glob
+import numpy as np
+from pydub import AudioSegment
+from faster_whisper import WhisperModel
+from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
+
+model_size = "medium"
+# Run on GPU with FP16
+model = None
+def split_audio_whisper(audio_path, target_dir='processed'):
+ global model
+ if model is None:
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
+ audio = AudioSegment.from_file(audio_path)
+ max_len = len(audio)
+
+ audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
+ target_folder = os.path.join(target_dir, audio_name)
+
+ segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
+ segments = list(segments)
+
+ # create directory
+ os.makedirs(target_folder, exist_ok=True)
+ wavs_folder = os.path.join(target_folder, 'wavs')
+ os.makedirs(wavs_folder, exist_ok=True)
+
+ # segments
+ s_ind = 0
+ start_time = None
+
+ for k, w in enumerate(segments):
+ # process with the time
+ if k == 0:
+ start_time = max(0, w.start)
+
+ end_time = w.end
+
+ # calculate confidence
+ if len(w.words) > 0:
+ confidence = sum([s.probability for s in w.words]) / len(w.words)
+ else:
+ confidence = 0.
+ # clean text
+ text = w.text.replace('...', '')
+
+ # left 0.08s for each audios
+ audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
+
+ # segment file name
+ fname = f"{audio_name}_seg{s_ind}.wav"
+
+ # filter out the segment shorter than 1.5s and longer than 20s
+ save = audio_seg.duration_seconds > 1.5 and \
+ audio_seg.duration_seconds < 20. and \
+ len(text) >= 2 and len(text) < 200
+
+ if save:
+ output_file = os.path.join(wavs_folder, fname)
+ audio_seg.export(output_file, format='wav')
+
+ if k < len(segments) - 1:
+ start_time = max(0, segments[k+1].start - 0.08)
+
+ s_ind = s_ind + 1
+ return wavs_folder
+
+
+def split_audio_vad(audio_path, target_dir, split_seconds=10.0):
+ SAMPLE_RATE = 16000
+ audio_vad = get_audio_tensor(audio_path)
+ segments = get_vad_segments(
+ audio_vad,
+ output_sample=True,
+ min_speech_duration=0.1,
+ min_silence_duration=1,
+ method="silero",
+ )
+ segments = [(seg["start"], seg["end"]) for seg in segments]
+ segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
+ print(segments)
+ audio_active = AudioSegment.silent(duration=0)
+ audio = AudioSegment.from_file(audio_path)
+
+ for start_time, end_time in segments:
+ audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
+
+ audio_dur = audio_active.duration_seconds
+ print(f'after vad: dur = {audio_dur}')
+ audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
+ target_folder = os.path.join(target_dir, audio_name)
+ wavs_folder = os.path.join(target_folder, 'wavs')
+ os.makedirs(wavs_folder, exist_ok=True)
+ start_time = 0.
+ count = 0
+ num_splits = int(np.round(audio_dur / split_seconds))
+ assert num_splits > 0, 'input audio is too short'
+ interval = audio_dur / num_splits
+
+ for i in range(num_splits):
+ end_time = min(start_time + interval, audio_dur)
+ if i == num_splits - 1:
+ end_time = audio_dur
+ output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
+ audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
+ audio_seg.export(output_file, format='wav')
+ start_time = end_time
+ count += 1
+ return wavs_folder
+
+
+
+
+
+def get_se(audio_path, vc_model, target_dir='processed', vad=True):
+ device = vc_model.device
+
+ audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
+ se_path = os.path.join(target_dir, audio_name, 'se.pth')
+
+ if os.path.isfile(se_path):
+ se = torch.load(se_path).to(device)
+ return se, audio_name
+ if os.path.isdir(audio_path):
+ wavs_folder = audio_path
+ elif vad:
+ wavs_folder = split_audio_vad(audio_path, target_dir)
+ else:
+ wavs_folder = split_audio_whisper(audio_path, target_dir)
+
+ audio_segs = glob(f'{wavs_folder}/*.wav')
+ if len(audio_segs) == 0:
+ raise NotImplementedError('No audio segments found!')
+
+ return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
+
diff --git a/text/__init__.py b/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f97622cb99b2f6a3b3528dfc447848e696f33c75
--- /dev/null
+++ b/text/__init__.py
@@ -0,0 +1,79 @@
+""" from https://github.com/keithito/tacotron """
+from text import cleaners
+from text.symbols import symbols
+
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+
+def text_to_sequence(text, symbols, cleaner_names):
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+ Args:
+ text: string to convert to a sequence
+ cleaner_names: names of the cleaner functions to run the text through
+ Returns:
+ List of integers corresponding to the symbols in the text
+ '''
+ sequence = []
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ clean_text = _clean_text(text, cleaner_names)
+ print(clean_text)
+ print(f" length:{len(clean_text)}")
+ for symbol in clean_text:
+ if symbol not in symbol_to_id.keys():
+ continue
+ symbol_id = symbol_to_id[symbol]
+ sequence += [symbol_id]
+ print(f" length:{len(sequence)}")
+ return sequence
+
+
+def cleaned_text_to_sequence(cleaned_text, symbols):
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+ Args:
+ text: string to convert to a sequence
+ Returns:
+ List of integers corresponding to the symbols in the text
+ '''
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
+ return sequence
+
+
+
+from text.symbols import language_tone_start_map
+def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages):
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+ Args:
+ text: string to convert to a sequence
+ Returns:
+ List of integers corresponding to the symbols in the text
+ """
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ language_id_map = {s: i for i, s in enumerate(languages)}
+ phones = [symbol_to_id[symbol] for symbol in cleaned_text]
+ tone_start = language_tone_start_map[language]
+ tones = [i + tone_start for i in tones]
+ lang_id = language_id_map[language]
+ lang_ids = [lang_id for i in phones]
+ return phones, tones, lang_ids
+
+
+def sequence_to_text(sequence):
+ '''Converts a sequence of IDs back to a string'''
+ result = ''
+ for symbol_id in sequence:
+ s = _id_to_symbol[symbol_id]
+ result += s
+ return result
+
+
+def _clean_text(text, cleaner_names):
+ for name in cleaner_names:
+ cleaner = getattr(cleaners, name)
+ if not cleaner:
+ raise Exception('Unknown cleaner: %s' % name)
+ text = cleaner(text)
+ return text
diff --git a/text/__pycache__/__init__.cpython-310.pyc b/text/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0894d3fda10df9c669bb7dd5dcd8ca1e3d9168b2
Binary files /dev/null and b/text/__pycache__/__init__.cpython-310.pyc differ
diff --git a/text/__pycache__/cleaners.cpython-310.pyc b/text/__pycache__/cleaners.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36b645343ce8ff1d3997d0d886e356306b4c5b99
Binary files /dev/null and b/text/__pycache__/cleaners.cpython-310.pyc differ
diff --git a/text/__pycache__/english.cpython-310.pyc b/text/__pycache__/english.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..353a607a00bf95f9eb1efc76d7b4d72c3cf8101b
Binary files /dev/null and b/text/__pycache__/english.cpython-310.pyc differ
diff --git a/text/__pycache__/mandarin.cpython-310.pyc b/text/__pycache__/mandarin.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..228681319ca5d2ea346fed8478501c9750425732
Binary files /dev/null and b/text/__pycache__/mandarin.cpython-310.pyc differ
diff --git a/text/__pycache__/symbols.cpython-310.pyc b/text/__pycache__/symbols.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..636ee6f4ccbf21cc353935e012e6fe17578ab69f
Binary files /dev/null and b/text/__pycache__/symbols.cpython-310.pyc differ
diff --git a/text/cleaners.py b/text/cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b0a0bfd69a8c6c37ee68866cef064730b9fbd6
--- /dev/null
+++ b/text/cleaners.py
@@ -0,0 +1,16 @@
+import re
+from text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
+from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
+
+def cjke_cleaners2(text):
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
+ lambda x: chinese_to_ipa(x.group(1))+' ', text)
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
+ lambda x: japanese_to_ipa2(x.group(1))+' ', text)
+ text = re.sub(r'\[KO\](.*?)\[KO\]',
+ lambda x: korean_to_ipa(x.group(1))+' ', text)
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
+ lambda x: english_to_ipa2(x.group(1))+' ', text)
+ text = re.sub(r'\s+$', '', text)
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
+ return text
\ No newline at end of file
diff --git a/text/english.py b/text/english.py
new file mode 100644
index 0000000000000000000000000000000000000000..6817392ba8a9eb830351de89fb7afc5ad72f5e42
--- /dev/null
+++ b/text/english.py
@@ -0,0 +1,188 @@
+""" from https://github.com/keithito/tacotron """
+
+'''
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+ 1. "english_cleaners" for English text
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+ the symbols in symbols.py to match your data).
+'''
+
+
+# Regular expression matching whitespace:
+
+
+import re
+import inflect
+from unidecode import unidecode
+import eng_to_ipa as ipa
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+_number_re = re.compile(r'[0-9]+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('mrs', 'misess'),
+ ('mr', 'mister'),
+ ('dr', 'doctor'),
+ ('st', 'saint'),
+ ('co', 'company'),
+ ('jr', 'junior'),
+ ('maj', 'major'),
+ ('gen', 'general'),
+ ('drs', 'doctors'),
+ ('rev', 'reverend'),
+ ('lt', 'lieutenant'),
+ ('hon', 'honorable'),
+ ('sgt', 'sergeant'),
+ ('capt', 'captain'),
+ ('esq', 'esquire'),
+ ('ltd', 'limited'),
+ ('col', 'colonel'),
+ ('ft', 'fort'),
+]]
+
+
+# List of (ipa, lazy ipa) pairs:
+_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('r', 'ɹ'),
+ ('æ', 'e'),
+ ('ɑ', 'a'),
+ ('ɔ', 'o'),
+ ('ð', 'z'),
+ ('θ', 's'),
+ ('ɛ', 'e'),
+ ('ɪ', 'i'),
+ ('ʊ', 'u'),
+ ('ʒ', 'ʥ'),
+ ('ʤ', 'ʥ'),
+ ('ˈ', '↓'),
+]]
+
+# List of (ipa, lazy ipa2) pairs:
+_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('r', 'ɹ'),
+ ('ð', 'z'),
+ ('θ', 's'),
+ ('ʒ', 'ʑ'),
+ ('ʤ', 'dʑ'),
+ ('ˈ', '↓'),
+]]
+
+# List of (ipa, ipa2) pairs
+_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('r', 'ɹ'),
+ ('ʤ', 'dʒ'),
+ ('ʧ', 'tʃ')
+]]
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def collapse_whitespace(text):
+ return re.sub(r'\s+', ' ', text)
+
+
+def _remove_commas(m):
+ return m.group(1).replace(',', '')
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace('.', ' point ')
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split('.')
+ if len(parts) > 2:
+ return match + ' dollars' # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ return '%s %s' % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s' % (cents, cent_unit)
+ else:
+ return 'zero dollars'
+
+
+def _expand_ordinal(m):
+ return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return 'two thousand'
+ elif num > 2000 and num < 2010:
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
+ elif num % 100 == 0:
+ return _inflect.number_to_words(num // 100) + ' hundred'
+ else:
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+ else:
+ return _inflect.number_to_words(num, andword='')
+
+
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r'\1 pounds', text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
+
+
+def mark_dark_l(text):
+ return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
+
+
+def english_to_ipa(text):
+ text = unidecode(text).lower()
+ text = expand_abbreviations(text)
+ text = normalize_numbers(text)
+ phonemes = ipa.convert(text)
+ phonemes = collapse_whitespace(phonemes)
+ return phonemes
+
+
+def english_to_lazy_ipa(text):
+ text = english_to_ipa(text)
+ for regex, replacement in _lazy_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def english_to_ipa2(text):
+ text = english_to_ipa(text)
+ text = mark_dark_l(text)
+ for regex, replacement in _ipa_to_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text.replace('...', '…')
+
+
+def english_to_lazy_ipa2(text):
+ text = english_to_ipa(text)
+ for regex, replacement in _lazy_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text
diff --git a/text/mandarin.py b/text/mandarin.py
new file mode 100644
index 0000000000000000000000000000000000000000..162e1b912dabec4b448ccd3d00d56306f82ce076
--- /dev/null
+++ b/text/mandarin.py
@@ -0,0 +1,326 @@
+import os
+import sys
+import re
+from pypinyin import lazy_pinyin, BOPOMOFO
+import jieba
+import cn2an
+import logging
+
+
+# List of (Latin alphabet, bopomofo) pairs:
+_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('a', 'ㄟˉ'),
+ ('b', 'ㄅㄧˋ'),
+ ('c', 'ㄙㄧˉ'),
+ ('d', 'ㄉㄧˋ'),
+ ('e', 'ㄧˋ'),
+ ('f', 'ㄝˊㄈㄨˋ'),
+ ('g', 'ㄐㄧˋ'),
+ ('h', 'ㄝˇㄑㄩˋ'),
+ ('i', 'ㄞˋ'),
+ ('j', 'ㄐㄟˋ'),
+ ('k', 'ㄎㄟˋ'),
+ ('l', 'ㄝˊㄛˋ'),
+ ('m', 'ㄝˊㄇㄨˋ'),
+ ('n', 'ㄣˉ'),
+ ('o', 'ㄡˉ'),
+ ('p', 'ㄆㄧˉ'),
+ ('q', 'ㄎㄧㄡˉ'),
+ ('r', 'ㄚˋ'),
+ ('s', 'ㄝˊㄙˋ'),
+ ('t', 'ㄊㄧˋ'),
+ ('u', 'ㄧㄡˉ'),
+ ('v', 'ㄨㄧˉ'),
+ ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
+ ('x', 'ㄝˉㄎㄨˋㄙˋ'),
+ ('y', 'ㄨㄞˋ'),
+ ('z', 'ㄗㄟˋ')
+]]
+
+# List of (bopomofo, romaji) pairs:
+_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ㄅㄛ', 'p⁼wo'),
+ ('ㄆㄛ', 'pʰwo'),
+ ('ㄇㄛ', 'mwo'),
+ ('ㄈㄛ', 'fwo'),
+ ('ㄅ', 'p⁼'),
+ ('ㄆ', 'pʰ'),
+ ('ㄇ', 'm'),
+ ('ㄈ', 'f'),
+ ('ㄉ', 't⁼'),
+ ('ㄊ', 'tʰ'),
+ ('ㄋ', 'n'),
+ ('ㄌ', 'l'),
+ ('ㄍ', 'k⁼'),
+ ('ㄎ', 'kʰ'),
+ ('ㄏ', 'h'),
+ ('ㄐ', 'ʧ⁼'),
+ ('ㄑ', 'ʧʰ'),
+ ('ㄒ', 'ʃ'),
+ ('ㄓ', 'ʦ`⁼'),
+ ('ㄔ', 'ʦ`ʰ'),
+ ('ㄕ', 's`'),
+ ('ㄖ', 'ɹ`'),
+ ('ㄗ', 'ʦ⁼'),
+ ('ㄘ', 'ʦʰ'),
+ ('ㄙ', 's'),
+ ('ㄚ', 'a'),
+ ('ㄛ', 'o'),
+ ('ㄜ', 'ə'),
+ ('ㄝ', 'e'),
+ ('ㄞ', 'ai'),
+ ('ㄟ', 'ei'),
+ ('ㄠ', 'au'),
+ ('ㄡ', 'ou'),
+ ('ㄧㄢ', 'yeNN'),
+ ('ㄢ', 'aNN'),
+ ('ㄧㄣ', 'iNN'),
+ ('ㄣ', 'əNN'),
+ ('ㄤ', 'aNg'),
+ ('ㄧㄥ', 'iNg'),
+ ('ㄨㄥ', 'uNg'),
+ ('ㄩㄥ', 'yuNg'),
+ ('ㄥ', 'əNg'),
+ ('ㄦ', 'əɻ'),
+ ('ㄧ', 'i'),
+ ('ㄨ', 'u'),
+ ('ㄩ', 'ɥ'),
+ ('ˉ', '→'),
+ ('ˊ', '↑'),
+ ('ˇ', '↓↑'),
+ ('ˋ', '↓'),
+ ('˙', ''),
+ (',', ','),
+ ('。', '.'),
+ ('!', '!'),
+ ('?', '?'),
+ ('—', '-')
+]]
+
+# List of (romaji, ipa) pairs:
+_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('ʃy', 'ʃ'),
+ ('ʧʰy', 'ʧʰ'),
+ ('ʧ⁼y', 'ʧ⁼'),
+ ('NN', 'n'),
+ ('Ng', 'ŋ'),
+ ('y', 'j'),
+ ('h', 'x')
+]]
+
+# List of (bopomofo, ipa) pairs:
+_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ㄅㄛ', 'p⁼wo'),
+ ('ㄆㄛ', 'pʰwo'),
+ ('ㄇㄛ', 'mwo'),
+ ('ㄈㄛ', 'fwo'),
+ ('ㄅ', 'p⁼'),
+ ('ㄆ', 'pʰ'),
+ ('ㄇ', 'm'),
+ ('ㄈ', 'f'),
+ ('ㄉ', 't⁼'),
+ ('ㄊ', 'tʰ'),
+ ('ㄋ', 'n'),
+ ('ㄌ', 'l'),
+ ('ㄍ', 'k⁼'),
+ ('ㄎ', 'kʰ'),
+ ('ㄏ', 'x'),
+ ('ㄐ', 'tʃ⁼'),
+ ('ㄑ', 'tʃʰ'),
+ ('ㄒ', 'ʃ'),
+ ('ㄓ', 'ts`⁼'),
+ ('ㄔ', 'ts`ʰ'),
+ ('ㄕ', 's`'),
+ ('ㄖ', 'ɹ`'),
+ ('ㄗ', 'ts⁼'),
+ ('ㄘ', 'tsʰ'),
+ ('ㄙ', 's'),
+ ('ㄚ', 'a'),
+ ('ㄛ', 'o'),
+ ('ㄜ', 'ə'),
+ ('ㄝ', 'ɛ'),
+ ('ㄞ', 'aɪ'),
+ ('ㄟ', 'eɪ'),
+ ('ㄠ', 'ɑʊ'),
+ ('ㄡ', 'oʊ'),
+ ('ㄧㄢ', 'jɛn'),
+ ('ㄩㄢ', 'ɥæn'),
+ ('ㄢ', 'an'),
+ ('ㄧㄣ', 'in'),
+ ('ㄩㄣ', 'ɥn'),
+ ('ㄣ', 'ən'),
+ ('ㄤ', 'ɑŋ'),
+ ('ㄧㄥ', 'iŋ'),
+ ('ㄨㄥ', 'ʊŋ'),
+ ('ㄩㄥ', 'jʊŋ'),
+ ('ㄥ', 'əŋ'),
+ ('ㄦ', 'əɻ'),
+ ('ㄧ', 'i'),
+ ('ㄨ', 'u'),
+ ('ㄩ', 'ɥ'),
+ ('ˉ', '→'),
+ ('ˊ', '↑'),
+ ('ˇ', '↓↑'),
+ ('ˋ', '↓'),
+ ('˙', ''),
+ (',', ','),
+ ('。', '.'),
+ ('!', '!'),
+ ('?', '?'),
+ ('—', '-')
+]]
+
+# List of (bopomofo, ipa2) pairs:
+_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ㄅㄛ', 'pwo'),
+ ('ㄆㄛ', 'pʰwo'),
+ ('ㄇㄛ', 'mwo'),
+ ('ㄈㄛ', 'fwo'),
+ ('ㄅ', 'p'),
+ ('ㄆ', 'pʰ'),
+ ('ㄇ', 'm'),
+ ('ㄈ', 'f'),
+ ('ㄉ', 't'),
+ ('ㄊ', 'tʰ'),
+ ('ㄋ', 'n'),
+ ('ㄌ', 'l'),
+ ('ㄍ', 'k'),
+ ('ㄎ', 'kʰ'),
+ ('ㄏ', 'h'),
+ ('ㄐ', 'tɕ'),
+ ('ㄑ', 'tɕʰ'),
+ ('ㄒ', 'ɕ'),
+ ('ㄓ', 'tʂ'),
+ ('ㄔ', 'tʂʰ'),
+ ('ㄕ', 'ʂ'),
+ ('ㄖ', 'ɻ'),
+ ('ㄗ', 'ts'),
+ ('ㄘ', 'tsʰ'),
+ ('ㄙ', 's'),
+ ('ㄚ', 'a'),
+ ('ㄛ', 'o'),
+ ('ㄜ', 'ɤ'),
+ ('ㄝ', 'ɛ'),
+ ('ㄞ', 'aɪ'),
+ ('ㄟ', 'eɪ'),
+ ('ㄠ', 'ɑʊ'),
+ ('ㄡ', 'oʊ'),
+ ('ㄧㄢ', 'jɛn'),
+ ('ㄩㄢ', 'yæn'),
+ ('ㄢ', 'an'),
+ ('ㄧㄣ', 'in'),
+ ('ㄩㄣ', 'yn'),
+ ('ㄣ', 'ən'),
+ ('ㄤ', 'ɑŋ'),
+ ('ㄧㄥ', 'iŋ'),
+ ('ㄨㄥ', 'ʊŋ'),
+ ('ㄩㄥ', 'jʊŋ'),
+ ('ㄥ', 'ɤŋ'),
+ ('ㄦ', 'əɻ'),
+ ('ㄧ', 'i'),
+ ('ㄨ', 'u'),
+ ('ㄩ', 'y'),
+ ('ˉ', '˥'),
+ ('ˊ', '˧˥'),
+ ('ˇ', '˨˩˦'),
+ ('ˋ', '˥˩'),
+ ('˙', ''),
+ (',', ','),
+ ('。', '.'),
+ ('!', '!'),
+ ('?', '?'),
+ ('—', '-')
+]]
+
+
+def number_to_chinese(text):
+ numbers = re.findall(r'\d+(?:\.?\d+)?', text)
+ for number in numbers:
+ text = text.replace(number, cn2an.an2cn(number), 1)
+ return text
+
+
+def chinese_to_bopomofo(text):
+ text = text.replace('、', ',').replace(';', ',').replace(':', ',')
+ words = jieba.lcut(text, cut_all=False)
+ text = ''
+ for word in words:
+ bopomofos = lazy_pinyin(word, BOPOMOFO)
+ if not re.search('[\u4e00-\u9fff]', word):
+ text += word
+ continue
+ for i in range(len(bopomofos)):
+ bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
+ if text != '':
+ text += ' '
+ text += ''.join(bopomofos)
+ return text
+
+
+def latin_to_bopomofo(text):
+ for regex, replacement in _latin_to_bopomofo:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def bopomofo_to_romaji(text):
+ for regex, replacement in _bopomofo_to_romaji:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def bopomofo_to_ipa(text):
+ for regex, replacement in _bopomofo_to_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def bopomofo_to_ipa2(text):
+ for regex, replacement in _bopomofo_to_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def chinese_to_romaji(text):
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_romaji(text)
+ text = re.sub('i([aoe])', r'y\1', text)
+ text = re.sub('u([aoəe])', r'w\1', text)
+ text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
+ text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
+ return text
+
+
+def chinese_to_lazy_ipa(text):
+ text = chinese_to_romaji(text)
+ for regex, replacement in _romaji_to_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def chinese_to_ipa(text):
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_ipa(text)
+ text = re.sub('i([aoe])', r'j\1', text)
+ text = re.sub('u([aoəe])', r'w\1', text)
+ text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
+ text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
+ return text
+
+
+def chinese_to_ipa2(text):
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_ipa2(text)
+ text = re.sub(r'i([aoe])', r'j\1', text)
+ text = re.sub(r'u([aoəe])', r'w\1', text)
+ text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
+ text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
+ return text
diff --git a/text/symbols.py b/text/symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..6072b09039136b6f2e63f719478d1cf15a759a8e
--- /dev/null
+++ b/text/symbols.py
@@ -0,0 +1,88 @@
+'''
+Defines the set of symbols used in text input to the model.
+'''
+
+# japanese_cleaners
+# _pad = '_'
+# _punctuation = ',.!?-'
+# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
+
+
+'''# japanese_cleaners2
+_pad = '_'
+_punctuation = ',.!?-~…'
+_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
+'''
+
+
+'''# korean_cleaners
+_pad = '_'
+_punctuation = ',.!?…~'
+_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
+'''
+
+'''# chinese_cleaners
+_pad = '_'
+_punctuation = ',。!?—…'
+_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
+'''
+
+# # zh_ja_mixture_cleaners
+# _pad = '_'
+# _punctuation = ',.!?-~…'
+# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
+
+
+'''# sanskrit_cleaners
+_pad = '_'
+_punctuation = '।'
+_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
+'''
+
+'''# cjks_cleaners
+_pad = '_'
+_punctuation = ',.!?-~…'
+_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
+'''
+
+'''# thai_cleaners
+_pad = '_'
+_punctuation = '.!? '
+_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
+'''
+
+# # cjke_cleaners2
+_pad = '_'
+_punctuation = ',.!?-~…'
+_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
+
+
+'''# shanghainese_cleaners
+_pad = '_'
+_punctuation = ',.!?…'
+_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
+'''
+
+'''# chinese_dialect_cleaners
+_pad = '_'
+_punctuation = ',.!?~…─'
+_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
+'''
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters)
+
+# Special symbol ids
+SPACE_ID = symbols.index(" ")
+
+num_ja_tones = 1
+num_kr_tones = 1
+num_zh_tones = 6
+num_en_tones = 4
+
+language_tone_start_map = {
+ "ZH": 0,
+ "JP": num_zh_tones,
+ "EN": num_zh_tones + num_ja_tones,
+ 'KR': num_zh_tones + num_ja_tones + num_en_tones,
+}
\ No newline at end of file
diff --git a/transforms.py b/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..a11f799e023864ff7082c1f49c0cc18351a13b47
--- /dev/null
+++ b/transforms.py
@@ -0,0 +1,209 @@
+import torch
+from torch.nn import functional as F
+
+import numpy as np
+
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+def piecewise_rational_quadratic_transform(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails=None,
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if tails is None:
+ spline_fn = rational_quadratic_spline
+ spline_kwargs = {}
+ else:
+ spline_fn = unconstrained_rational_quadratic_spline
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+ outputs, logabsdet = spline_fn(
+ inputs=inputs,
+ unnormalized_widths=unnormalized_widths,
+ unnormalized_heights=unnormalized_heights,
+ unnormalized_derivatives=unnormalized_derivatives,
+ inverse=inverse,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ **spline_kwargs
+ )
+ return outputs, logabsdet
+
+
+def searchsorted(bin_locations, inputs, eps=1e-6):
+ bin_locations[..., -1] += eps
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
+
+
+def unconstrained_rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails="linear",
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+ outside_interval_mask = ~inside_interval_mask
+
+ outputs = torch.zeros_like(inputs)
+ logabsdet = torch.zeros_like(inputs)
+
+ if tails == "linear":
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+ constant = np.log(np.exp(1 - min_derivative) - 1)
+ unnormalized_derivatives[..., 0] = constant
+ unnormalized_derivatives[..., -1] = constant
+
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
+ logabsdet[outside_interval_mask] = 0
+ else:
+ raise RuntimeError("{} tails are not implemented.".format(tails))
+
+ (
+ outputs[inside_interval_mask],
+ logabsdet[inside_interval_mask],
+ ) = rational_quadratic_spline(
+ inputs=inputs[inside_interval_mask],
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+ inverse=inverse,
+ left=-tail_bound,
+ right=tail_bound,
+ bottom=-tail_bound,
+ top=tail_bound,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ )
+
+ return outputs, logabsdet
+
+
+def rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ left=0.0,
+ right=1.0,
+ bottom=0.0,
+ top=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if torch.min(inputs) < left or torch.max(inputs) > right:
+ raise ValueError("Input to a transform is not within its domain")
+
+ num_bins = unnormalized_widths.shape[-1]
+
+ if min_bin_width * num_bins > 1.0:
+ raise ValueError("Minimal bin width too large for the number of bins")
+ if min_bin_height * num_bins > 1.0:
+ raise ValueError("Minimal bin height too large for the number of bins")
+
+ widths = F.softmax(unnormalized_widths, dim=-1)
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+ cumwidths = torch.cumsum(widths, dim=-1)
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+ cumwidths = (right - left) * cumwidths + left
+ cumwidths[..., 0] = left
+ cumwidths[..., -1] = right
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+ heights = F.softmax(unnormalized_heights, dim=-1)
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+ cumheights = torch.cumsum(heights, dim=-1)
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+ cumheights = (top - bottom) * cumheights + bottom
+ cumheights[..., 0] = bottom
+ cumheights[..., -1] = top
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+ if inverse:
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
+ else:
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
+
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+ delta = heights / widths
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+ if inverse:
+ a = (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ ) + input_heights * (input_delta - input_derivatives)
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ )
+ c = -input_delta * (inputs - input_cumheights)
+
+ discriminant = b.pow(2) - 4 * a * c
+ assert (discriminant >= 0).all()
+
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
+ outputs = root * input_bin_widths + input_cumwidths
+
+ theta_one_minus_theta = root * (1 - root)
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * root.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - root).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, -logabsdet
+ else:
+ theta = (inputs - input_cumwidths) / input_bin_widths
+ theta_one_minus_theta = theta * (1 - theta)
+
+ numerator = input_heights * (
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+ )
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ outputs = input_cumheights + numerator / denominator
+
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * theta.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - theta).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, logabsdet
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..747a3b7e5911403ff5a57f80f85569d180378431
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,194 @@
+import re
+import json
+import numpy as np
+
+
+def get_hparams_from_file(config_path):
+ with open(config_path, "r", encoding="utf-8") as f:
+ data = f.read()
+ config = json.loads(data)
+
+ hparams = HParams(**config)
+ return hparams
+
+class HParams:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ if type(v) == dict:
+ v = HParams(**v)
+ self[k] = v
+
+ def keys(self):
+ return self.__dict__.keys()
+
+ def items(self):
+ return self.__dict__.items()
+
+ def values(self):
+ return self.__dict__.values()
+
+ def __len__(self):
+ return len(self.__dict__)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value):
+ return setattr(self, key, value)
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def __repr__(self):
+ return self.__dict__.__repr__()
+
+
+def string_to_bits(string, pad_len=8):
+ # Convert each character to its ASCII value
+ ascii_values = [ord(char) for char in string]
+
+ # Convert ASCII values to binary representation
+ binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
+
+ # Convert binary strings to integer arrays
+ bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
+
+ # Convert list of arrays to NumPy array
+ numpy_array = np.array(bit_arrays)
+ numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
+ numpy_array_full[:, 2] = 1
+ max_len = min(pad_len, len(numpy_array))
+ numpy_array_full[:max_len] = numpy_array[:max_len]
+ return numpy_array_full
+
+
+def bits_to_string(bits_array):
+ # Convert each row of the array to a binary string
+ binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
+
+ # Convert binary strings to ASCII values
+ ascii_values = [int(binary, 2) for binary in binary_values]
+
+ # Convert ASCII values to characters
+ output_string = ''.join(chr(value) for value in ascii_values)
+
+ return output_string
+
+
+def split_sentence(text, min_len=10, language_str='[EN]'):
+ if language_str in ['EN']:
+ sentences = split_sentences_latin(text, min_len=min_len)
+ else:
+ sentences = split_sentences_zh(text, min_len=min_len)
+ return sentences
+
+def split_sentences_latin(text, min_len=10):
+ """Split Long sentences into list of short ones
+
+ Args:
+ str: Input sentences.
+
+ Returns:
+ List[str]: list of output sentences.
+ """
+ # deal with dirty sentences
+ text = re.sub('[。!?;]', '.', text)
+ text = re.sub('[,]', ',', text)
+ text = re.sub('[“”]', '"', text)
+ text = re.sub('[‘’]', "'", text)
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
+ text = re.sub('[\n\t ]+', ' ', text)
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
+ # split
+ sentences = [s.strip() for s in text.split('$#!')]
+ if len(sentences[-1]) == 0: del sentences[-1]
+
+ new_sentences = []
+ new_sent = []
+ count_len = 0
+ for ind, sent in enumerate(sentences):
+ # print(sent)
+ new_sent.append(sent)
+ count_len += len(sent.split(" "))
+ if count_len > min_len or ind == len(sentences) - 1:
+ count_len = 0
+ new_sentences.append(' '.join(new_sent))
+ new_sent = []
+ return merge_short_sentences_latin(new_sentences)
+
+
+def merge_short_sentences_latin(sens):
+ """Avoid short sentences by merging them with the following sentence.
+
+ Args:
+ List[str]: list of input sentences.
+
+ Returns:
+ List[str]: list of output sentences.
+ """
+ sens_out = []
+ for s in sens:
+ # If the previous sentense is too short, merge them with
+ # the current sentence.
+ if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
+ sens_out[-1] = sens_out[-1] + " " + s
+ else:
+ sens_out.append(s)
+ try:
+ if len(sens_out[-1].split(" ")) <= 2:
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
+ sens_out.pop(-1)
+ except:
+ pass
+ return sens_out
+
+def split_sentences_zh(text, min_len=10):
+ text = re.sub('[。!?;]', '.', text)
+ text = re.sub('[,]', ',', text)
+ # 将文本中的换行符、空格和制表符替换为空格
+ text = re.sub('[\n\t ]+', ' ', text)
+ # 在标点符号后添加一个空格
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
+ # 分隔句子并去除前后空格
+ # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
+ sentences = [s.strip() for s in text.split('$#!')]
+ if len(sentences[-1]) == 0: del sentences[-1]
+
+ new_sentences = []
+ new_sent = []
+ count_len = 0
+ for ind, sent in enumerate(sentences):
+ new_sent.append(sent)
+ count_len += len(sent)
+ if count_len > min_len or ind == len(sentences) - 1:
+ count_len = 0
+ new_sentences.append(' '.join(new_sent))
+ new_sent = []
+ return merge_short_sentences_zh(new_sentences)
+
+
+def merge_short_sentences_zh(sens):
+ # return sens
+ """Avoid short sentences by merging them with the following sentence.
+
+ Args:
+ List[str]: list of input sentences.
+
+ Returns:
+ List[str]: list of output sentences.
+ """
+ sens_out = []
+ for s in sens:
+ # If the previous sentense is too short, merge them with
+ # the current sentence.
+ if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
+ sens_out[-1] = sens_out[-1] + " " + s
+ else:
+ sens_out.append(s)
+ try:
+ if len(sens_out[-1]) <= 2:
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
+ sens_out.pop(-1)
+ except:
+ pass
+ return sens_out
\ No newline at end of file