gongqianshao Marne committed on
Commit ecdf01c (0 parents)

Duplicate from Marne/MockingBird


Co-authored-by: Sakrua <Marne@users.noreply.huggingface.co>

Files changed (49)
  1. .gitattributes +16 -0
  2. .gitignore +160 -0
  3. MANIFEST.in +4 -0
  4. README.md +35 -0
  5. app.py +86 -0
  6. data/azusa/azusa.pt +3 -0
  7. data/azusa/record.wav +3 -0
  8. data/encoder.pt +3 -0
  9. data/g_hifigan.pt +3 -0
  10. data/ltyai/ltyai.pt +3 -0
  11. data/ltyai/record.wav +3 -0
  12. data/nanmei/nanmei.pt +3 -0
  13. data/nanmei/record.wav +3 -0
  14. data/tianyi/record.wav +3 -0
  15. data/tianyi/tianyi.pt +3 -0
  16. data/wavernn.pt +3 -0
  17. mockingbirdforuse/__init__.py +120 -0
  18. mockingbirdforuse/encoder/__init__.py +0 -0
  19. mockingbirdforuse/encoder/audio.py +121 -0
  20. mockingbirdforuse/encoder/hparams.py +42 -0
  21. mockingbirdforuse/encoder/inference.py +154 -0
  22. mockingbirdforuse/encoder/model.py +145 -0
  23. mockingbirdforuse/log.py +40 -0
  24. mockingbirdforuse/synthesizer/__init__.py +0 -0
  25. mockingbirdforuse/synthesizer/gst_hyperparameters.py +19 -0
  26. mockingbirdforuse/synthesizer/hparams.py +113 -0
  27. mockingbirdforuse/synthesizer/inference.py +151 -0
  28. mockingbirdforuse/synthesizer/models/global_style_token.py +175 -0
  29. mockingbirdforuse/synthesizer/models/tacotron.py +678 -0
  30. mockingbirdforuse/synthesizer/utils/__init__.py +46 -0
  31. mockingbirdforuse/synthesizer/utils/cleaners.py +91 -0
  32. mockingbirdforuse/synthesizer/utils/logmmse.py +245 -0
  33. mockingbirdforuse/synthesizer/utils/numbers.py +70 -0
  34. mockingbirdforuse/synthesizer/utils/symbols.py +20 -0
  35. mockingbirdforuse/synthesizer/utils/text.py +74 -0
  36. mockingbirdforuse/vocoder/__init__.py +0 -0
  37. mockingbirdforuse/vocoder/distribution.py +136 -0
  38. mockingbirdforuse/vocoder/hifigan/__init__.py +0 -0
  39. mockingbirdforuse/vocoder/hifigan/hparams.py +37 -0
  40. mockingbirdforuse/vocoder/hifigan/inference.py +32 -0
  41. mockingbirdforuse/vocoder/hifigan/models.py +460 -0
  42. mockingbirdforuse/vocoder/wavernn/__init__.py +0 -0
  43. mockingbirdforuse/vocoder/wavernn/audio.py +118 -0
  44. mockingbirdforuse/vocoder/wavernn/hparams.py +53 -0
  45. mockingbirdforuse/vocoder/wavernn/inference.py +56 -0
  46. mockingbirdforuse/vocoder/wavernn/models/deepmind_version.py +180 -0
  47. mockingbirdforuse/vocoder/wavernn/models/fatchord_version.py +445 -0
  48. packages.txt +3 -0
  49. requirements.txt +13 -0
.gitattributes ADDED
@@ -0,0 +1,16 @@
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
MANIFEST.in ADDED
@@ -0,0 +1,4 @@
1
+ include assets/*
2
+ include inputs/*
3
+ include LICENSE
4
+ include requirements.txt
README.md ADDED
@@ -0,0 +1,35 @@
1
+ ---
2
+ title: MockingBird
3
+ emoji: 🏃
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.1.7
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: Marne/MockingBird
11
+ ---
12
+
13
+ # Configuration
14
+
15
+ `title`: _string_
16
+ Display title for the Space
17
+
18
+ `emoji`: _string_
19
+ Space emoji (emoji-only character allowed)
20
+
21
+ `colorFrom`: _string_
22
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
23
+
24
+ `colorTo`: _string_
25
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
26
+
27
+ `sdk`: _string_
28
+ Can be either `gradio` or `streamlit`
29
+
30
+ `app_file`: _string_
31
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
32
+ Path is relative to the root of the repository.
33
+
34
+ `pinned`: _boolean_
35
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import httpx
3
+ import torch
4
+ import gradio as gr
5
+ from tempfile import NamedTemporaryFile
6
+ from pathlib import Path
7
+
8
+ from mockingbirdforuse import MockingBird
9
+
10
+
11
+ mockingbird = MockingBird()
12
+ mockingbird_path = Path(os.path.dirname(__file__)) / "data"
13
+ base_url = "https://al.smoe.top/d/Home/source/mockingbird/"
14
+
15
+ for sy in ["encoder.pt", "g_hifigan.pt", "wavernn.pt"]:
16
+ if not os.path.exists(os.path.join(mockingbird_path, sy)):
17
+ torch.hub.download_url_to_file(f"{base_url}/{sy}", mockingbird_path / sy)
18
+
19
+ for model in ["azusa", "nanmei", "ltyai", "tianyi"]:
20
+ model_path = mockingbird_path / model
21
+ model_path.mkdir(parents=True, exist_ok=True)
22
+ for file_name in ["record.wav", f"{model}.pt"]:
23
+ if not os.path.exists(os.path.join(model_path, file_name)):
24
+ torch.hub.download_url_to_file(
25
+ f"{base_url}/{model}/{file_name}", model_path / file_name
26
+ )
27
+
28
+ mockingbird.load_model(
29
+ Path(os.path.join(mockingbird_path, "encoder.pt")),
30
+ Path(os.path.join(mockingbird_path, "g_hifigan.pt")),
31
+ Path(os.path.join(mockingbird_path, "wavernn.pt")),
32
+ )
33
+
34
+
35
+ def inference(
36
+ text: str,
37
+ model_name: str,
38
+ vocoder_type: str = "HifiGan",
39
+ style_idx: int = 0,
40
+ min_stop_token: int = 9,
41
+ steps: int = 2000,
42
+ ):
43
+ model_path = mockingbird_path / model_name
44
+ mockingbird.set_synthesizer(Path(os.path.join(model_path, f"{model_name}.pt")))
45
+ fd = NamedTemporaryFile(suffix=".wav", delete=False)
46
+ record = mockingbird.synthesize(
47
+ text=str(text),
48
+ input_wav=model_path / "record.wav",
49
+ vocoder_type=vocoder_type,
50
+ style_idx=style_idx,
51
+ min_stop_token=min_stop_token,
52
+ steps=steps,
53
+ )
54
+ with open(fd.name, "wb") as file:
55
+ file.write(record.getvalue())
56
+ return fd.name
57
+
58
+
59
+ title = "MockingBird"
60
+ description = "🚀AI拟声: 5秒内克隆您的声音并生成任意语音内容 Clone a voice in 5 seconds to generate arbitrary speech in real-time"
61
+ article = "<a href='https://github.com/babysor/MockingBird'>Github Repo</a></p>"
62
+
63
+ gr.Interface(
64
+ inference,
65
+ [
66
+ gr.Textbox(label="Input"),
67
+ gr.Radio(
68
+ ["azusa", "nanmei", "ltyai", "tianyi"],
69
+ label="model type",
70
+ value="azusa",
71
+ ),
72
+ gr.Radio(
73
+ ["HifiGan", "WaveRNN"],
74
+ label="Vocoder type",
75
+ value="HifiGan",
76
+ ),
77
+ gr.Slider(minimum=-1, maximum=9, step=1, label="style idx", value=0),
78
+ gr.Slider(minimum=3, maximum=9, label="min stop token", value=9),
79
+ gr.Slider(minimum=200, maximum=2000, label="steps", value=2000),
80
+ ],
81
+ gr.Audio(type="filepath", label="Output"),
82
+ title=title,
83
+ description=description,
84
+ article=article,
85
+ examples=[["阿梓不是你的电子播放器", "azusa", "HifiGan", 0, 9, 2000], ["不是", "nanmei", "HifiGan", 0, 9, 2000]],
86
+ ).launch()
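The `inference` handler above is plain Python and does not depend on the Gradio UI; once the checkpoints have been downloaded at startup, a direct call with the arguments of the first `examples` entry returns the path of the temporary .wav file it writes. A minimal usage sketch:

wav_path = inference("阿梓不是你的电子播放器", "azusa", "HifiGan", 0, 9, 2000)
print(wav_path)  # temporary .wav file written from the BytesIO returned by mockingbird.synthesize()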
data/azusa/azusa.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f5cc81057c8c7a5c8000ac8f5dd0335f878484640e69e2bb1f7a84d9b0bbf90
3
+ size 526153469
data/azusa/record.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:021a1fc6c6ee27a1829095e0cb886989ca9702abb6abf94c5482c28e1ec17b7f
3
+ size 778390
data/encoder.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57715adc6f36047166ab06e37b904240aee2f4d10fc88f78ed91510cf4b38666
3
+ size 17095158
data/g_hifigan.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c5b29830f9b42c481c108cb0b89d56f380928d4d46e1d30d65c92340ddc694e
3
+ size 51985448
data/ltyai/ltyai.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4bd4b759a30efd70d0064628c3b107aa7cd9d0bff8a36a242946a46d7c5235c
3
+ size 526153021
data/ltyai/record.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cff766c73cb0a033249038a8e13243ea38d06a9e3b3a9fe7c207c663a046ca2
3
+ size 1130540
data/nanmei/nanmei.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95e90985b4c6b8090d8b328e7b23078eb00cffa0464ca9982464f0000b44a2a9
3
+ size 526153469
data/nanmei/record.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adc98df66db2ed721aaf917845c705a5f242be8a9f8adfc7d496a33a7d2b51e0
3
+ size 915594
data/tianyi/record.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dba1af79c33be37282636c19660d77b10ab054708246682abd2b525c01cb578a
3
+ size 942158
data/tianyi/tianyi.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9140c057ad8f4243e47a18103e773e0f823c4423927eec67dd47a3c3e9a9293
3
+ size 526153469
data/wavernn.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169
3
+ size 53845290
mockingbirdforuse/__init__.py ADDED
@@ -0,0 +1,120 @@
1
+ import re
2
+ import librosa
3
+ import numpy as np
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+ from scipy.io import wavfile
7
+ from typing import List, Literal, Optional
8
+
9
+ from .encoder.inference import Encoder, preprocess_wav
10
+ from .synthesizer.inference import Synthesizer
11
+ from .vocoder.hifigan.inference import HifiGanVocoder
12
+ from .vocoder.wavernn.inference import WaveRNNVocoder
13
+ from .log import logger
14
+
15
+
16
+ def process_text(text: str) -> List[str]:
17
+ punctuation = "!,。、,?!," # punctuate and split/clean text
18
+ processed_texts = []
19
+ text = re.sub(r"[{}]+".format(punctuation), "\n", text)
20
+ for processed_text in text.split("\n"):
21
+ if processed_text:
22
+ processed_texts.append(processed_text.strip())
23
+ return processed_texts
24
+
25
+
26
+ class MockingBird:
27
+ def __init__(self):
28
+ self.encoder: Optional[Encoder] = None
29
+ self.gan_vocoder: Optional[HifiGanVocoder] = None
30
+ self.rnn_vocoder: Optional[WaveRNNVocoder] = None
31
+ self.synthesizer: Optional[Synthesizer] = None
32
+
33
+ def load_model(
34
+ self,
35
+ encoder_path: Path,
36
+ gan_vocoder_path: Optional[Path] = None,
37
+ rnn_vocoder_path: Optional[Path] = None,
38
+ ):
39
+ """
40
+ Set the encoder model and vocoder model paths
41
+
42
+ Args:
43
+ encoder_path (Path): path to the encoder model
44
+ gan_vocoder_path (Path): path to the HifiGan vocoder model; optional, but required when the HifiGan vocoder type is used
45
+ rnn_vocoder_path (Path): path to the WaveRNN vocoder model; optional, but required when the WaveRNN vocoder type is used
46
+ """
47
+ self.encoder = Encoder(encoder_path)
48
+ if gan_vocoder_path:
49
+ self.gan_vocoder = HifiGanVocoder(gan_vocoder_path)
50
+ if rnn_vocoder_path:
51
+ self.rnn_vocoder = WaveRNNVocoder(rnn_vocoder_path)
52
+
53
+ def set_synthesizer(self, synthesizer_path: Path):
54
+ """
55
+ Set the synthesizer model path
56
+
57
+ Args:
58
+ synthesizer_path (Path): path to the synthesizer model
59
+ """
60
+ self.synthesizer = Synthesizer(synthesizer_path)
61
+ logger.info(f"using synthesizer model: {synthesizer_path}")
62
+
63
+ def synthesize(
64
+ self,
65
+ text: str,
66
+ input_wav: Path,
67
+ vocoder_type: Literal["HifiGan", "WaveRNN"] = "HifiGan",
68
+ style_idx: int = 0,
69
+ min_stop_token: int = 5,
70
+ steps: int = 1000,
71
+ ) -> BytesIO:
72
+ """
73
+ Synthesize speech
74
+
75
+ Args:
76
+ text (str): target text
77
+ input_wav (Path): path to the target recording
78
+ vocoder_type (HifiGan / WaveRNN): vocoder model to use, defaults to HifiGan
79
+ style_idx (int, optional): style index, range -1~9, defaults to 0
80
+ min_stop_token (int, optional): accuracy, range 3~9, defaults to 5
81
+ steps (int, optional): maximum sentence length, range 200~2000, defaults to 1000
82
+ """
83
+ if not self.encoder:
84
+ raise Exception("Please set encoder path first")
85
+
86
+ if not self.synthesizer:
87
+ raise Exception("Please set synthesizer path first")
88
+
89
+ # Load input wav
90
+ wav, sample_rate = librosa.load(input_wav)
91
+
92
+ encoder_wav = preprocess_wav(wav, sample_rate)
93
+ embed, _, _ = self.encoder.embed_utterance(encoder_wav, return_partials=True)
94
+
95
+ # Load input text
96
+ texts = process_text(text)
97
+
98
+ # synthesize and vocode
99
+ embeds = [embed] * len(texts)
100
+ specs = self.synthesizer.synthesize_spectrograms(
101
+ texts,
102
+ embeds,
103
+ style_idx=style_idx,
104
+ min_stop_token=min_stop_token,
105
+ steps=steps,
106
+ )
107
+ spec = np.concatenate(specs, axis=1)
108
+ if vocoder_type == "WaveRNN":
109
+ if not self.rnn_vocoder:
110
+ raise Exception("Please set wavernn vocoder path first")
111
+ wav, sample_rate = self.rnn_vocoder.infer_waveform(spec)
112
+ else:
113
+ if not self.gan_vocoder:
114
+ raise Exception("Please set hifigan vocoder path first")
115
+ wav, sample_rate = self.gan_vocoder.infer_waveform(spec)
116
+
117
+ # Return cooked wav
118
+ out = BytesIO()
119
+ wavfile.write(out, sample_rate, wav.astype(np.float32))
120
+ return out
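The MockingBird class above is the package's public entry point. A minimal end-to-end sketch against the checkpoints shipped under data/ in this commit (paths relative to the repository root):

from pathlib import Path
from mockingbirdforuse import MockingBird

mockingbird = MockingBird()
mockingbird.load_model(
    encoder_path=Path("data/encoder.pt"),
    gan_vocoder_path=Path("data/g_hifigan.pt"),
    rnn_vocoder_path=Path("data/wavernn.pt"),
)
mockingbird.set_synthesizer(Path("data/azusa/azusa.pt"))
out = mockingbird.synthesize(
    text="阿梓不是你的电子播放器",
    input_wav=Path("data/azusa/record.wav"),
    vocoder_type="HifiGan",
)
with open("output.wav", "wb") as f:
    f.write(out.getvalue())  # synthesize() returns the finished wav as an in-memory BytesIO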
mockingbirdforuse/encoder/__init__.py ADDED
File without changes
mockingbirdforuse/encoder/audio.py ADDED
@@ -0,0 +1,121 @@
1
+ import struct
2
+ import librosa
3
+ import webrtcvad
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from typing import Optional, Union
7
+ from scipy.ndimage.morphology import binary_dilation
8
+
9
+ from .hparams import hparams as hp
10
+
11
+
12
+ def preprocess_wav(
13
+ fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None,
15
+ normalize: Optional[bool] = True,
16
+ trim_silence: Optional[bool] = True,
17
+ ):
18
+ """
19
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
20
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
21
+
22
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
23
+ just .wav), or the waveform as a numpy array of floats.
24
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
25
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
26
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
27
+ this argument will be ignored.
28
+ """
29
+ # Load the wav from disk if needed
30
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
31
+ wav, source_sr = librosa.load(str(fpath_or_wav))
32
+ else:
33
+ wav = fpath_or_wav
34
+
35
+ # Resample the wav if needed
36
+ if source_sr is not None and source_sr != hp.sampling_rate:
37
+ wav = librosa.resample(wav, orig_sr=source_sr, target_sr=hp.sampling_rate)
38
+
39
+ # Apply the preprocessing: normalize volume and shorten long silences
40
+ if normalize:
41
+ wav = normalize_volume(wav, hp.audio_norm_target_dBFS, increase_only=True)
42
+ if webrtcvad and trim_silence:
43
+ wav = trim_long_silences(wav)
44
+
45
+ return wav
46
+
47
+
48
+ def wav_to_mel_spectrogram(wav):
49
+ """
50
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
51
+ Note: this is not a log-mel spectrogram.
52
+ """
53
+ frames = librosa.feature.melspectrogram(
54
+ y=wav,
55
+ sr=hp.sampling_rate,
56
+ n_fft=int(hp.sampling_rate * hp.mel_window_length / 1000),
57
+ hop_length=int(hp.sampling_rate * hp.mel_window_step / 1000),
58
+ n_mels=hp.mel_n_channels,
59
+ )
60
+ return frames.astype(np.float32).T
61
+
62
+
63
+ def trim_long_silences(wav):
64
+ """
65
+ Ensures that segments without voice in the waveform remain no longer than a
66
+ threshold determined by the VAD parameters in params.py.
67
+
68
+ :param wav: the raw waveform as a numpy array of floats
69
+ :return: the same waveform with silences trimmed away (length <= original wav length)
70
+ """
71
+ # Compute the voice detection window size
72
+ samples_per_window = (hp.vad_window_length * hp.sampling_rate) // 1000
73
+
74
+ # Trim the end of the audio to have a multiple of the window size
75
+ wav = wav[: len(wav) - (len(wav) % samples_per_window)]
76
+
77
+ # Convert the float waveform to 16-bit mono PCM
78
+ int16_max = (2**15) - 1
79
+ pcm_wave = struct.pack(
80
+ "%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)
81
+ )
82
+
83
+ # Perform voice activation detection
84
+ voice_flags = []
85
+ vad = webrtcvad.Vad(mode=3)
86
+ for window_start in range(0, len(wav), samples_per_window):
87
+ window_end = window_start + samples_per_window
88
+ voice_flags.append(
89
+ vad.is_speech(
90
+ pcm_wave[window_start * 2 : window_end * 2],
91
+ sample_rate=hp.sampling_rate,
92
+ )
93
+ )
94
+ voice_flags = np.array(voice_flags)
95
+
96
+ # Smooth the voice detection with a moving average
97
+ def moving_average(array, width):
98
+ array_padded = np.concatenate(
99
+ (np.zeros((width - 1) // 2), array, np.zeros(width // 2))
100
+ )
101
+ ret = np.cumsum(array_padded, dtype=float)
102
+ ret[width:] = ret[width:] - ret[:-width]
103
+ return ret[width - 1 :] / width
104
+
105
+ audio_mask = moving_average(voice_flags, hp.vad_moving_average_width)
106
+ audio_mask = np.round(audio_mask).astype(np.bool8)
107
+
108
+ # Dilate the voiced regions
109
+ audio_mask = binary_dilation(audio_mask, np.ones(hp.vad_max_silence_length + 1))
110
+ audio_mask = np.repeat(audio_mask, samples_per_window)
111
+
112
+ return wav[audio_mask == True]
113
+
114
+
115
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
116
+ if increase_only and decrease_only:
117
+ raise ValueError("Both increase only and decrease only are set")
118
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
119
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
120
+ return wav
121
+ return wav * (10 ** (dBFS_change / 20))
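The helpers above form the encoder-side preprocessing chain. A short sketch of how they compose, using one of the reference recordings shipped in data/:

from mockingbirdforuse.encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("data/azusa/record.wav")  # resampled to hp.sampling_rate (16 kHz), volume-normalized, long silences trimmed
mel = wav_to_mel_spectrogram(wav)              # float32 array of shape (n_frames, hp.mel_n_channels) = (n_frames, 40)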
mockingbirdforuse/encoder/hparams.py ADDED
@@ -0,0 +1,42 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class HParams:
6
+ ## Mel-filterbank
7
+ mel_window_length = 25 # In milliseconds
8
+ mel_window_step = 10 # In milliseconds
9
+ mel_n_channels = 40
10
+
11
+ ## Audio
12
+ sampling_rate = 16000
13
+ # Number of spectrogram frames in a partial utterance
14
+ partials_n_frames = 160 # 1600 ms
15
+ # Number of spectrogram frames at inference
16
+ inference_n_frames = 80 # 800 ms
17
+
18
+ ## Voice Activation Detection
19
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
20
+ # This sets the granularity of the VAD. Should not need to be changed.
21
+ vad_window_length = 30 # In milliseconds
22
+ # Number of frames to average together when performing the moving average smoothing.
23
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
24
+ vad_moving_average_width = 8
25
+ # Maximum number of consecutive silent frames a segment can have.
26
+ vad_max_silence_length = 6
27
+
28
+ ## Audio volume normalization
29
+ audio_norm_target_dBFS = -30
30
+
31
+ ## Model parameters
32
+ model_hidden_size = 256
33
+ model_embedding_size = 256
34
+ model_num_layers = 3
35
+
36
+ ## Training parameters
37
+ learning_rate_init = 1e-4
38
+ speakers_per_batch = 64
39
+ utterances_per_speaker = 10
40
+
41
+
42
+ hparams = HParams()
mockingbirdforuse/encoder/inference.py ADDED
@@ -0,0 +1,154 @@
1
+ import torch
2
+ import numpy as np
3
+ from pathlib import Path
4
+
5
+ from . import audio
6
+ from .model import SpeakerEncoder
7
+ from .audio import preprocess_wav # We want to expose this function from here
8
+ from .hparams import hparams as hp
9
+ from ..log import logger
10
+
11
+
12
+ class Encoder:
13
+ def __init__(self, model_path: Path):
14
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self._model = SpeakerEncoder(self._device, torch.device("cpu"))
16
+ checkpoint = torch.load(model_path, self._device)
17
+ self._model.load_state_dict(checkpoint["model_state"])
18
+ self._model.eval()
19
+ logger.info(
20
+ f"Loaded encoder {model_path.name} trained to step {checkpoint['step']}"
21
+ )
22
+
23
+ def embed_frames_batch(self, frames_batch):
24
+ """
25
+ Computes embeddings for a batch of mel spectrograms.
26
+
27
+ :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
28
+ (batch_size, n_frames, n_channels)
29
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
30
+ """
31
+
32
+ frames = torch.from_numpy(frames_batch).to(self._device)
33
+ embed = self._model.forward(frames).detach().cpu().numpy()
34
+ return embed
35
+
36
+ def compute_partial_slices(
37
+ self,
38
+ n_samples,
39
+ partial_utterance_n_frames=hp.partials_n_frames,
40
+ min_pad_coverage=0.75,
41
+ overlap=0.5,
42
+ rate=None,
43
+ ):
44
+ """
45
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
46
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
47
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
48
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
49
+ defined in params_data.py.
50
+
51
+ The returned ranges may be indexing further than the length of the waveform. It is
52
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
53
+
54
+ :param n_samples: the number of samples in the waveform
55
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
56
+ utterance
57
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
58
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
59
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
60
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
61
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
62
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
63
+ utterances are entirely disjoint.
64
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
65
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
66
+ utterances.
67
+ """
68
+ assert 0 <= overlap < 1
69
+ assert 0 < min_pad_coverage <= 1
70
+
71
+ if rate != None:
72
+ samples_per_frame = int((hp.sampling_rate * hp.mel_window_step / 1000))
73
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
74
+ frame_step = int(np.round((hp.sampling_rate / rate) / samples_per_frame))
75
+ else:
76
+ samples_per_frame = int((hp.sampling_rate * hp.mel_window_step / 1000))
77
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
78
+ frame_step = max(
79
+ int(np.round(partial_utterance_n_frames * (1 - overlap))), 1
80
+ )
81
+
82
+ assert 0 < frame_step, "The rate is too high"
83
+ assert (
84
+ frame_step <= hp.partials_n_frames
85
+ ), "The rate is too low, it should be %f at least" % (
86
+ hp.sampling_rate / (samples_per_frame * hp.partials_n_frames)
87
+ )
88
+
89
+ # Compute the slices
90
+ wav_slices, mel_slices = [], []
91
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
92
+ for i in range(0, steps, frame_step):
93
+ mel_range = np.array([i, i + partial_utterance_n_frames])
94
+ wav_range = mel_range * samples_per_frame
95
+ mel_slices.append(slice(*mel_range))
96
+ wav_slices.append(slice(*wav_range))
97
+
98
+ # Evaluate whether extra padding is warranted or not
99
+ last_wav_range = wav_slices[-1]
100
+ coverage = (n_samples - last_wav_range.start) / (
101
+ last_wav_range.stop - last_wav_range.start
102
+ )
103
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
104
+ mel_slices = mel_slices[:-1]
105
+ wav_slices = wav_slices[:-1]
106
+
107
+ return wav_slices, mel_slices
108
+
109
+ def embed_utterance(
110
+ self, wav, using_partials: bool = True, return_partials: bool = False, **kwargs
111
+ ):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
116
+ :param using_partials: if True, then the utterance is split in partial utterances of
117
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
118
+ normalized average. If False, the utterance is instead computed from feeding the entire
119
+ spectrogram to the network.
120
+ :param return_partials: if True, the partial embeddings will also be returned along with the
121
+ wav slices that correspond to the partial embeddings.
122
+ :param kwargs: additional arguments to compute_partial_slices()
123
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
124
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
125
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
126
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
127
+ instead.
128
+ """
129
+ # Process the entire utterance if not using partials
130
+ if not using_partials:
131
+ frames = audio.wav_to_mel_spectrogram(wav)
132
+ embed = self.embed_frames_batch(frames[None, ...])[0]
133
+ if return_partials:
134
+ return embed, None, None
135
+ return embed
136
+
137
+ # Compute where to split the utterance into partials and pad if necessary
138
+ wave_slices, mel_slices = self.compute_partial_slices(len(wav), **kwargs)
139
+ max_wave_length = wave_slices[-1].stop
140
+ if max_wave_length >= len(wav):
141
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
142
+
143
+ # Split the utterance into partials
144
+ frames = audio.wav_to_mel_spectrogram(wav)
145
+ frames_batch = np.array([frames[s] for s in mel_slices])
146
+ partial_embeds = self.embed_frames_batch(frames_batch)
147
+
148
+ # Compute the utterance embedding from the partial embeddings
149
+ raw_embed = np.mean(partial_embeds, axis=0)
150
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
151
+
152
+ if return_partials:
153
+ return embed, partial_embeds, wave_slices
154
+ return embed
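Together with the preprocessing module, the Encoder above turns a recording into the fixed-size speaker embedding that the synthesizer consumes. A minimal sketch with the encoder checkpoint from data/:

from pathlib import Path
from mockingbirdforuse.encoder.inference import Encoder, preprocess_wav

encoder = Encoder(Path("data/encoder.pt"))
wav = preprocess_wav("data/azusa/record.wav")
embed = encoder.embed_utterance(wav)  # L2-normalized float32 vector of shape (model_embedding_size,) = (256,)
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)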
mockingbirdforuse/encoder/model.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch import nn
4
+ from scipy.optimize import brentq
5
+ from sklearn.metrics import roc_curve
6
+ from scipy.interpolate import interp1d
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.utils.clip_grad import clip_grad_norm_
9
+
10
+ from .hparams import hparams as hp
11
+
12
+
13
+ class SpeakerEncoder(nn.Module):
14
+ def __init__(self, device, loss_device):
15
+ super().__init__()
16
+ self.loss_device = loss_device
17
+
18
+ # Network definition
19
+ self.lstm = nn.LSTM(
20
+ input_size=hp.mel_n_channels,
21
+ hidden_size=hp.model_hidden_size,
22
+ num_layers=hp.model_num_layers,
23
+ batch_first=True,
24
+ ).to(device)
25
+ self.linear = nn.Linear(
26
+ in_features=hp.model_hidden_size, out_features=hp.model_embedding_size
27
+ ).to(device)
28
+ self.relu = torch.nn.ReLU().to(device)
29
+
30
+ # Cosine similarity scaling (with fixed initial parameter values)
31
+ self.similarity_weight = Parameter(torch.tensor([10.0])).to(loss_device)
32
+ self.similarity_bias = Parameter(torch.tensor([-5.0])).to(loss_device)
33
+
34
+ # Loss
35
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
36
+
37
+ def do_gradient_ops(self):
38
+ # Gradient scale
39
+ self.similarity_weight.grad *= 0.01
40
+ self.similarity_bias.grad *= 0.01
41
+
42
+ # Gradient clipping
43
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
44
+
45
+ def forward(self, utterances, hidden_init=None):
46
+ """
47
+ Computes the embeddings of a batch of utterance spectrograms.
48
+
49
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
50
+ (batch_size, n_frames, n_channels)
51
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
52
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
53
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
54
+ """
55
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
56
+ # and the final cell state.
57
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
58
+
59
+ # We take only the hidden state of the last layer
60
+ embeds_raw = self.relu(self.linear(hidden[-1]))
61
+
62
+ # L2-normalize it
63
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
64
+
65
+ return embeds
66
+
67
+ def similarity_matrix(self, embeds):
68
+ """
69
+ Computes the similarity matrix according to section 2.1 of GE2E.
70
+
71
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
72
+ utterances_per_speaker, embedding_size)
73
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
74
+ utterances_per_speaker, speakers_per_batch)
75
+ """
76
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
77
+
78
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
79
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
80
+ centroids_incl = centroids_incl.clone() / (
81
+ torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5
82
+ )
83
+
84
+ # Exclusive centroids (1 per utterance)
85
+ centroids_excl = torch.sum(embeds, dim=1, keepdim=True) - embeds
86
+ centroids_excl /= utterances_per_speaker - 1
87
+ centroids_excl = centroids_excl.clone() / (
88
+ torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5
89
+ )
90
+
91
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
92
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
93
+ # We vectorize the computation for efficiency.
94
+ sim_matrix = torch.zeros(
95
+ speakers_per_batch, utterances_per_speaker, speakers_per_batch
96
+ ).to(self.loss_device)
97
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int32)
98
+ for j in range(speakers_per_batch):
99
+ mask = np.where(mask_matrix[j])[0]
100
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
101
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
102
+
103
+ ## Even more vectorized version (slower maybe because of transpose)
104
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
105
+ # ).to(self.loss_device)
106
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
107
+ # mask = np.where(1 - eye)
108
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
109
+ # mask = np.where(eye)
110
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
111
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
112
+
113
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
114
+ return sim_matrix
115
+
116
+ def loss(self, embeds):
117
+ """
118
+ Computes the softmax loss according to section 2.1 of GE2E.
119
+
120
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
121
+ utterances_per_speaker, embedding_size)
122
+ :return: the loss and the EER for this batch of embeddings.
123
+ """
124
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
125
+
126
+ # Loss
127
+ sim_matrix = self.similarity_matrix(embeds)
128
+ sim_matrix = sim_matrix.reshape(
129
+ (speakers_per_batch * utterances_per_speaker, speakers_per_batch)
130
+ )
131
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
132
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
133
+ loss = self.loss_fn(sim_matrix, target)
134
+
135
+ # EER (not backpropagated)
136
+ with torch.no_grad():
137
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int32)[0]
138
+ labels = np.array([inv_argmax(i) for i in ground_truth])
139
+ preds = sim_matrix.detach().cpu().numpy()
140
+
141
+ # Snippet from https://yangcha.github.io/EER-ROC/
142
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
143
+ eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
144
+
145
+ return loss, eer
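The GE2E loss above expects one embedding per (speaker, utterance) pair. A shape sketch on CPU with small random stand-in embeddings (the real batch sizes come from encoder/hparams.py):

import torch
from mockingbirdforuse.encoder.model import SpeakerEncoder

model = SpeakerEncoder(torch.device("cpu"), torch.device("cpu"))
embeds = torch.rand(4, 5, 256)  # (speakers_per_batch, utterances_per_speaker, embedding_size)
loss, eer = model.loss(embeds)  # GE2E softmax loss and the equal-error rate for this batch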
mockingbirdforuse/log.py ADDED
@@ -0,0 +1,40 @@
1
+ import sys
2
+ import loguru
3
+
4
+ from typing import TYPE_CHECKING, Union
5
+
6
+ if TYPE_CHECKING:
7
+ from loguru import Logger
8
+
9
+ logger: "Logger" = loguru.logger
10
+
11
+
12
+ class Filter:
13
+ def __init__(self) -> None:
14
+ self.level: Union[int, str] = "DEBUG"
15
+
16
+ def __call__(self, record):
17
+ module_name: str = record["name"]
18
+ record["name"] = module_name.split(".")[0]
19
+ levelno = (
20
+ logger.level(self.level).no if isinstance(self.level, str) else self.level
21
+ )
22
+ return record["level"].no >= levelno
23
+
24
+
25
+ logger.remove()
26
+ default_filter: Filter = Filter()
27
+ default_format: str = (
28
+ "<g>{time:MM-DD HH:mm:ss}</g> "
29
+ "[<lvl>{level}</lvl>] "
30
+ "<c><u>{name}</u></c> | "
31
+ "{message}"
32
+ )
33
+ logger.add(
34
+ sys.stdout,
35
+ level=0,
36
+ colorize=True,
37
+ diagnose=False,
38
+ filter=default_filter,
39
+ format=default_format,
40
+ )
mockingbirdforuse/synthesizer/__init__.py ADDED
File without changes
mockingbirdforuse/synthesizer/gst_hyperparameters.py ADDED
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class GSTHyperparameters:
6
+ E = 512
7
+
8
+ # reference encoder
9
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
10
+
11
+ # style token layer
12
+ token_num = 10
13
+ # token_emb_size = 256
14
+ num_heads = 8
15
+
16
+ n_mels = 256 # Number of Mel banks to generate
17
+
18
+
19
+ hparams = GSTHyperparameters()
mockingbirdforuse/synthesizer/hparams.py ADDED
@@ -0,0 +1,113 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class HParams:
6
+ ### Signal Processing (used in both synthesizer and vocoder)
7
+ sample_rate = 16000
8
+ n_fft = 800
9
+ num_mels = 80
10
+ hop_size = 200
11
+ """Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)"""
12
+ win_size = 800
13
+ """Tacotron uses 50 ms frame length (set to sample_rate * 0.050)"""
14
+ fmin = 55
15
+ min_level_db = -100
16
+ ref_level_db = 20
17
+ max_abs_value = 4.0
18
+ """Gradient explodes if too big, premature convergence if too small."""
19
+ preemphasis = 0.97
20
+ """Filter coefficient to use if preemphasize is True"""
21
+ preemphasize = True
22
+
23
+ ### Tacotron Text-to-Speech (TTS)
24
+ tts_embed_dims = 512
25
+ """Embedding dimension for the graphemes/phoneme inputs"""
26
+ tts_encoder_dims = 256
27
+ tts_decoder_dims = 128
28
+ tts_postnet_dims = 512
29
+ tts_encoder_K = 5
30
+ tts_lstm_dims = 1024
31
+ tts_postnet_K = 5
32
+ tts_num_highways = 4
33
+ tts_dropout = 0.5
34
+ tts_cleaner_names = ["basic_cleaners"]
35
+ tts_stop_threshold = -3.4
36
+ """
37
+ Value below which audio generation ends.
38
+ For example, for a range of [-4, 4], this
39
+ will terminate the sequence at the first
40
+ frame that has all values < -3.4
41
+ """
42
+
43
+ ### Tacotron Training
44
+ tts_schedule = [
45
+ (2, 1e-3, 10_000, 12),
46
+ (2, 5e-4, 15_000, 12),
47
+ (2, 2e-4, 20_000, 12),
48
+ (2, 1e-4, 30_000, 12),
49
+ (2, 5e-5, 40_000, 12),
50
+ (2, 1e-5, 60_000, 12),
51
+ (2, 5e-6, 160_000, 12),
52
+ (2, 3e-6, 320_000, 12),
53
+ (2, 1e-6, 640_000, 12),
54
+ ]
55
+ """
56
+ Progressive training schedule
57
+ (r, lr, step, batch_size)
58
+ r = reduction factor (# of mel frames synthesized for each decoder iteration)
59
+ lr = learning rate
60
+ """
61
+
62
+ tts_clip_grad_norm = 1.0
63
+ """clips the gradient norm to prevent explosion - set to None if not needed"""
64
+ tts_eval_interval = 500
65
+ """
66
+ Number of steps between model evaluation (sample generation)
67
+ Set to -1 to generate after completing epoch, or 0 to disable
68
+ """
69
+ tts_eval_num_samples = 1
70
+ """Makes this number of samples"""
71
+ tts_finetune_layers = []
72
+ """For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj"""
73
+
74
+ ### Data Preprocessing
75
+ max_mel_frames = 900
76
+ rescale = True
77
+ rescaling_max = 0.9
78
+ synthesis_batch_size = 16
79
+ """For vocoder preprocessing and inference."""
80
+
81
+ ### Mel Visualization and Griffin-Lim
82
+ signal_normalization = True
83
+ power = 1.5
84
+ griffin_lim_iters = 60
85
+
86
+ ### Audio processing options
87
+ fmax = 7600
88
+ """Should not exceed (sample_rate // 2)"""
89
+ allow_clipping_in_normalization = True
90
+ """Used when signal_normalization = True"""
91
+ clip_mels_length = True
92
+ """If true, discards samples exceeding max_mel_frames"""
93
+ use_lws = False
94
+ """Fast spectrogram phase recovery using local weighted sums"""
95
+ symmetric_mels = True
96
+ """Sets mel range to [-max_abs_value, max_abs_value] if True, and [0, max_abs_value] if False"""
97
+ trim_silence = True
98
+ """Use with sample_rate of 16000 for best results"""
99
+
100
+ ### SV2TTS
101
+ speaker_embedding_size = 256
102
+ """Dimension for the speaker embedding"""
103
+ silence_min_duration_split = 0.4
104
+ """Duration in seconds of a silence for an utterance to be split"""
105
+ utterance_min_duration = 1.6
106
+ """Duration in seconds below which utterances are discarded"""
107
+ use_gst = True
108
+ """Whether to use global style token"""
109
+ use_ser_for_gst = True
110
+ """Whether to use speaker embedding referenced for global style token"""
111
+
112
+
113
+ hparams = HParams()
mockingbirdforuse/synthesizer/inference.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import librosa
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Union, List
6
+ from pypinyin import lazy_pinyin, Style
7
+
8
+ from .hparams import hparams as hp
9
+ from .utils.symbols import symbols
10
+ from .models.tacotron import Tacotron
11
+ from .utils.text import text_to_sequence
12
+ from .utils.logmmse import denoise, profile_noise
13
+ from ..log import logger
14
+
15
+
16
+ class Synthesizer:
17
+ def __init__(self, model_path: Path):
18
+ # Check for GPU
19
+ if torch.cuda.is_available():
20
+ self.device = torch.device("cuda")
21
+ else:
22
+ self.device = torch.device("cpu")
23
+ logger.info(f"Synthesizer using device: {self.device}")
24
+
25
+ self._model = Tacotron(
26
+ embed_dims=hp.tts_embed_dims,
27
+ num_chars=len(symbols),
28
+ encoder_dims=hp.tts_encoder_dims,
29
+ decoder_dims=hp.tts_decoder_dims,
30
+ n_mels=hp.num_mels,
31
+ fft_bins=hp.num_mels,
32
+ postnet_dims=hp.tts_postnet_dims,
33
+ encoder_K=hp.tts_encoder_K,
34
+ lstm_dims=hp.tts_lstm_dims,
35
+ postnet_K=hp.tts_postnet_K,
36
+ num_highways=hp.tts_num_highways,
37
+ dropout=hp.tts_dropout,
38
+ stop_threshold=hp.tts_stop_threshold,
39
+ speaker_embedding_size=hp.speaker_embedding_size,
40
+ ).to(self.device)
41
+
42
+ self._model.load(model_path, self.device)
43
+ self._model.eval()
44
+
45
+ logger.info(
46
+ 'Loaded synthesizer "%s" trained to step %d'
47
+ % (model_path.name, self._model.state_dict()["step"])
48
+ )
49
+
50
+ def synthesize_spectrograms(
51
+ self,
52
+ texts: List[str],
53
+ embeddings: Union[np.ndarray, List[np.ndarray]],
54
+ return_alignments=False,
55
+ style_idx=0,
56
+ min_stop_token=5,
57
+ steps=2000,
58
+ ):
59
+ """
60
+ Synthesizes mel spectrograms from texts and speaker embeddings.
61
+
62
+ :param texts: a list of N text prompts to be synthesized
63
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
64
+ :param return_alignments: if True, a matrix representing the alignments between the
65
+ characters
66
+ and each decoder output step will be returned for each spectrogram
67
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
68
+ sequence length of spectrogram i, and possibly the alignments.
69
+ """
70
+
71
+ logger.debug("Read " + str(texts))
72
+ texts = [
73
+ " ".join(lazy_pinyin(v, style=Style.TONE3, neutral_tone_with_five=True))
74
+ for v in texts
75
+ ]
76
+ logger.debug("Synthesizing " + str(texts))
77
+ # Preprocess text inputs
78
+ inputs = [text_to_sequence(text, hp.tts_cleaner_names) for text in texts]
79
+ if not isinstance(embeddings, list):
80
+ embeddings = [embeddings]
81
+
82
+ # Batch inputs
83
+ batched_inputs = [
84
+ inputs[i : i + hp.synthesis_batch_size]
85
+ for i in range(0, len(inputs), hp.synthesis_batch_size)
86
+ ]
87
+ batched_embeds = [
88
+ embeddings[i : i + hp.synthesis_batch_size]
89
+ for i in range(0, len(embeddings), hp.synthesis_batch_size)
90
+ ]
91
+
92
+ specs = []
93
+ alignments = []
94
+ for i, batch in enumerate(batched_inputs, 1):
95
+ logger.debug(f"\n| Generating {i}/{len(batched_inputs)}")
96
+
97
+ # Pad texts so they are all the same length
98
+ text_lens = [len(text) for text in batch]
99
+ max_text_len = max(text_lens)
100
+ chars = [pad1d(text, max_text_len) for text in batch]
101
+ chars = np.stack(chars)
102
+
103
+ # Stack speaker embeddings into 2D array for batch processing
104
+ speaker_embeds = np.stack(batched_embeds[i - 1])
105
+
106
+ # Convert to tensor
107
+ chars = torch.tensor(chars).long().to(self.device)
108
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
109
+
110
+ # Inference
111
+ _, mels, alignments = self._model.generate(
112
+ chars,
113
+ speaker_embeddings,
114
+ style_idx=style_idx,
115
+ min_stop_token=min_stop_token,
116
+ steps=steps,
117
+ )
118
+ mels = mels.detach().cpu().numpy()
119
+ for m in mels:
120
+ # Trim silence from end of each spectrogram
121
+ while np.max(m[:, -1]) < hp.tts_stop_threshold:
122
+ m = m[:, :-1]
123
+ specs.append(m)
124
+
125
+ logger.debug("\n\nDone.\n")
126
+ return (specs, alignments) if return_alignments else specs
127
+
128
+ @staticmethod
129
+ def load_preprocess_wav(fpath):
130
+ """
131
+ Loads and preprocesses an audio file under the same conditions that were used to
132
+ train the synthesizer.
133
+ """
134
+ wav = librosa.load(path=str(fpath), sr=hp.sample_rate)[0]
135
+ if hp.rescale:
136
+ wav = wav / np.abs(wav).max() * hp.rescaling_max
137
+ # denoise
138
+ if len(wav) > hp.sample_rate * (0.3 + 0.1):
139
+ noise_wav = np.concatenate(
140
+ [
141
+ wav[: int(hp.sample_rate * 0.15)],
142
+ wav[-int(hp.sample_rate * 0.15) :],
143
+ ]
144
+ )
145
+ profile = profile_noise(noise_wav, hp.sample_rate)
146
+ wav = denoise(wav, profile)
147
+ return wav
148
+
149
+
150
+ def pad1d(x, max_len, pad_value=0):
151
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
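Given a speaker embedding from the encoder, the Synthesizer above produces the mel spectrograms that are then handed to a vocoder (see mockingbirdforuse/__init__.py). A minimal sketch with a random stand-in embedding and one of the synthesizer checkpoints from data/:

import numpy as np
from pathlib import Path
from mockingbirdforuse.synthesizer.inference import Synthesizer

synthesizer = Synthesizer(Path("data/azusa/azusa.pt"))
embed = np.random.rand(256).astype(np.float32)  # stand-in; normally the output of Encoder.embed_utterance()
specs = synthesizer.synthesize_spectrograms(["你好"], [embed], style_idx=0, min_stop_token=5, steps=2000)
spec = np.concatenate(specs, axis=1)            # (num_mels=80, total_frames), ready for the HifiGan or WaveRNN vocoder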
mockingbirdforuse/synthesizer/models/global_style_token.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ from torch.nn.parameter import Parameter
5
+ import torch.nn.functional as tFunctional
6
+
7
+ from ..hparams import hparams as hp
8
+ from ..gst_hyperparameters import hparams as gst_hp
9
+
10
+
11
+ class GlobalStyleToken(nn.Module):
12
+ """
13
+ inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
14
+ speaker_embedding: speaker mel spectrograms [batch_size, num_spec_frames, num_mel]
15
+ outputs: [batch_size, embedding_dim]
16
+ """
17
+
18
+ def __init__(self, speaker_embedding_dim=None):
19
+
20
+ super().__init__()
21
+ self.encoder = ReferenceEncoder()
22
+ self.stl = STL(speaker_embedding_dim)
23
+
24
+ def forward(self, inputs, speaker_embedding=None):
25
+ enc_out = self.encoder(inputs)
26
+ # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
27
+ if hp.use_ser_for_gst and speaker_embedding is not None:
28
+ enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
29
+ style_embed = self.stl(enc_out)
30
+
31
+ return style_embed
32
+
33
+
34
+ class ReferenceEncoder(nn.Module):
35
+ """
36
+ inputs --- [N, Ty/r, n_mels*r] mels
37
+ outputs --- [N, ref_enc_gru_size]
38
+ """
39
+
40
+ def __init__(self):
41
+
42
+ super().__init__()
43
+ K = len(gst_hp.ref_enc_filters)
44
+ filters = [1] + gst_hp.ref_enc_filters
45
+ convs = [
46
+ nn.Conv2d(
47
+ in_channels=filters[i],
48
+ out_channels=filters[i + 1],
49
+ kernel_size=(3, 3),
50
+ stride=(2, 2),
51
+ padding=(1, 1),
52
+ )
53
+ for i in range(K)
54
+ ]
55
+ self.convs = nn.ModuleList(convs)
56
+ self.bns = nn.ModuleList(
57
+ [nn.BatchNorm2d(num_features=gst_hp.ref_enc_filters[i]) for i in range(K)]
58
+ )
59
+
60
+ out_channels = self.calculate_channels(gst_hp.n_mels, 3, 2, 1, K)
61
+ self.gru = nn.GRU(
62
+ input_size=gst_hp.ref_enc_filters[-1] * out_channels,
63
+ hidden_size=gst_hp.E // 2,
64
+ batch_first=True,
65
+ )
66
+
67
+ def forward(self, inputs):
68
+ N = inputs.size(0)
69
+ out = inputs.view(N, 1, -1, gst_hp.n_mels) # [N, 1, Ty, n_mels]
70
+ for conv, bn in zip(self.convs, self.bns):
71
+ out = conv(out)
72
+ out = bn(out)
73
+ out = tFunctional.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
74
+
75
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
76
+ T = out.size(1)
77
+ N = out.size(0)
78
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
79
+
80
+ self.gru.flatten_parameters()
81
+ memory, out = self.gru(out) # out --- [1, N, E//2]
82
+
83
+ return out.squeeze(0)
84
+
85
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
86
+ for i in range(n_convs):
87
+ L = (L - kernel_size + 2 * pad) // stride + 1
88
+ return L
89
+
90
+
91
+ class STL(nn.Module):
92
+ """
93
+ inputs --- [N, E//2]
94
+ """
95
+
96
+ def __init__(self, speaker_embedding_dim=None):
97
+
98
+ super().__init__()
99
+ self.embed = Parameter(
100
+ torch.FloatTensor(gst_hp.token_num, gst_hp.E // gst_hp.num_heads)
101
+ )
102
+ d_q = gst_hp.E // 2
103
+ d_k = gst_hp.E // gst_hp.num_heads
104
+ # self.attention = MultiHeadAttention(gst_hp.num_heads, d_model, d_q, d_v)
105
+ if hp.use_ser_for_gst and speaker_embedding_dim is not None:
106
+ d_q += speaker_embedding_dim
107
+ self.attention = MultiHeadAttention(
108
+ query_dim=d_q, key_dim=d_k, num_units=gst_hp.E, num_heads=gst_hp.num_heads
109
+ )
110
+
111
+ init.normal_(self.embed, mean=0, std=0.5)
112
+
113
+ def forward(self, inputs):
114
+ N = inputs.size(0)
115
+ query = inputs.unsqueeze(1) # [N, 1, E//2]
116
+ keys = (
117
+ torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)
118
+ ) # [N, token_num, E // num_heads]
119
+ style_embed = self.attention(query, keys)
120
+
121
+ return style_embed
122
+
123
+
124
+ class MultiHeadAttention(nn.Module):
125
+ """
126
+ input:
127
+ query --- [N, T_q, query_dim]
128
+ key --- [N, T_k, key_dim]
129
+ output:
130
+ out --- [N, T_q, num_units]
131
+ """
132
+
133
+ def __init__(self, query_dim, key_dim, num_units, num_heads):
134
+
135
+ super().__init__()
136
+ self.num_units = num_units
137
+ self.num_heads = num_heads
138
+ self.key_dim = key_dim
139
+
140
+ self.W_query = nn.Linear(
141
+ in_features=query_dim, out_features=num_units, bias=False
142
+ )
143
+ self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
144
+ self.W_value = nn.Linear(
145
+ in_features=key_dim, out_features=num_units, bias=False
146
+ )
147
+
148
+ def forward(self, query, key):
149
+ querys = self.W_query(query) # [N, T_q, num_units]
150
+ keys = self.W_key(key) # [N, T_k, num_units]
151
+ values = self.W_value(key)
152
+
153
+ split_size = self.num_units // self.num_heads
154
+ querys = torch.stack(
155
+ torch.split(querys, split_size, dim=2), dim=0
156
+ ) # [h, N, T_q, num_units/h]
157
+ keys = torch.stack(
158
+ torch.split(keys, split_size, dim=2), dim=0
159
+ ) # [h, N, T_k, num_units/h]
160
+ values = torch.stack(
161
+ torch.split(values, split_size, dim=2), dim=0
162
+ ) # [h, N, T_k, num_units/h]
163
+
164
+ # score = softmax(QK^T / (d_k ** 0.5))
165
+ scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
166
+ scores = scores / (self.key_dim**0.5)
167
+ scores = tFunctional.softmax(scores, dim=3)
168
+
169
+ # out = score * V
170
+ out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
171
+ out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
172
+ 0
173
+ ) # [N, T_q, num_units]
174
+
175
+ return out
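For reference, the tensor shapes flowing through the GST module above with a toy batch (gst_hp.n_mels = 256 and gst_hp.E = 512 per gst_hyperparameters.py; the speaker embedding is concatenated because hp.use_ser_for_gst is enabled):

import torch
from mockingbirdforuse.synthesizer.models.global_style_token import GlobalStyleToken

gst = GlobalStyleToken(speaker_embedding_dim=256)
mels = torch.rand(2, 128, 256)  # [batch, num_spec_frames, gst_hp.n_mels]
spk = torch.rand(2, 256)        # speaker embedding concatenated to the reference encoder output
style = gst(mels, spk)          # style embedding of shape [batch, 1, gst_hp.E] = [2, 1, 512]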
mockingbirdforuse/synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,678 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from ..hparams import hparams as hp
7
+ from .global_style_token import GlobalStyleToken
8
+ from ..gst_hyperparameters import hparams as gst_hp
9
+ from ...log import logger
10
+
11
+
12
+ class HighwayNetwork(nn.Module):
13
+ def __init__(self, size):
14
+ super().__init__()
15
+ self.W1 = nn.Linear(size, size)
16
+ self.W2 = nn.Linear(size, size)
17
+ self.W1.bias.data.fill_(0.0)
18
+
19
+ def forward(self, x):
20
+ x1 = self.W1(x)
21
+ x2 = self.W2(x)
22
+ g = torch.sigmoid(x2)
23
+ y = g * F.relu(x1) + (1.0 - g) * x
24
+ return y
25
+
26
+
27
+ class Encoder(nn.Module):
28
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
29
+ super().__init__()
30
+ prenet_dims = (encoder_dims, encoder_dims)
31
+ cbhg_channels = encoder_dims
32
+ self.embedding = nn.Embedding(num_chars, embed_dims)
33
+ self.pre_net = PreNet(
34
+ embed_dims,
35
+ fc1_dims=prenet_dims[0],
36
+ fc2_dims=prenet_dims[1],
37
+ dropout=dropout,
38
+ )
39
+ self.cbhg = CBHG(
40
+ K=K,
41
+ in_channels=cbhg_channels,
42
+ channels=cbhg_channels,
43
+ proj_channels=[cbhg_channels, cbhg_channels],
44
+ num_highways=num_highways,
45
+ )
46
+
47
+ def forward(self, x, speaker_embedding=None):
48
+ x = self.embedding(x)
49
+ x = self.pre_net(x)
50
+ x.transpose_(1, 2)
51
+ x = self.cbhg(x)
52
+ if speaker_embedding is not None:
53
+ x = self.add_speaker_embedding(x, speaker_embedding)
54
+ return x
55
+
56
+ def add_speaker_embedding(self, x, speaker_embedding):
57
+ # SV2TTS
58
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
59
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
60
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
61
+ # This concats the speaker embedding for each char in the encoder output
62
+
63
+ # Save the dimensions as human-readable names
64
+ batch_size = x.size()[0]
65
+ num_chars = x.size()[1]
66
+
67
+ if speaker_embedding.dim() == 1:
68
+ idx = 0
69
+ else:
70
+ idx = 1
71
+
72
+ # Start by making a copy of each speaker embedding to match the input text length
73
+ # The output of this has size (batch_size, num_chars * speaker_embedding_size)
74
+ speaker_embedding_size = speaker_embedding.size()[idx]
75
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
76
+
77
+ # Reshape it and transpose
78
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
79
+ e = e.transpose(1, 2)
80
+
81
+ # Concatenate the tiled speaker embedding with the encoder output
82
+ x = torch.cat((x, e), 2)
83
+ return x
84
+
85
+
86
+ class BatchNormConv(nn.Module):
87
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
88
+ super().__init__()
89
+ self.conv = nn.Conv1d(
90
+ in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False
91
+ )
92
+ self.bnorm = nn.BatchNorm1d(out_channels)
93
+ self.relu = relu
94
+
95
+ def forward(self, x):
96
+ x = self.conv(x)
97
+ x = F.relu(x) if self.relu is True else x
98
+ return self.bnorm(x)
99
+
100
+
101
+ class CBHG(nn.Module):
102
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
103
+ super().__init__()
104
+
105
+ # List of all rnns to call `flatten_parameters()` on
106
+ self._to_flatten = []
107
+
108
+ self.bank_kernels = [i for i in range(1, K + 1)]
109
+ self.conv1d_bank = nn.ModuleList()
110
+ for k in self.bank_kernels:
111
+ conv = BatchNormConv(in_channels, channels, k)
112
+ self.conv1d_bank.append(conv)
113
+
114
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
115
+
116
+ self.conv_project1 = BatchNormConv(
117
+ len(self.bank_kernels) * channels, proj_channels[0], 3
118
+ )
119
+ self.conv_project2 = BatchNormConv(
120
+ proj_channels[0], proj_channels[1], 3, relu=False
121
+ )
122
+
123
+ # Fix the highway input if necessary
124
+ if proj_channels[-1] != channels:
125
+ self.highway_mismatch = True
126
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
127
+ else:
128
+ self.highway_mismatch = False
129
+
130
+ self.highways = nn.ModuleList()
131
+ for i in range(num_highways):
132
+ hn = HighwayNetwork(channels)
133
+ self.highways.append(hn)
134
+
135
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
136
+ self._to_flatten.append(self.rnn)
137
+
138
+ # Avoid fragmentation of RNN parameters and associated warning
139
+ self._flatten_parameters()
140
+
141
+ def forward(self, x):
142
+ # Although we `_flatten_parameters()` on init, when using DataParallel
143
+ # the model gets replicated, making it no longer guaranteed that the
144
+ # weights are contiguous in GPU memory. Hence, we must call it again
145
+ self.rnn.flatten_parameters()
146
+
147
+ # Save these for later
148
+ residual = x
149
+ seq_len = x.size(-1)
150
+ conv_bank = []
151
+
152
+ # Convolution Bank
153
+ for conv in self.conv1d_bank:
154
+ c = conv(x) # Convolution
155
+ conv_bank.append(c[:, :, :seq_len])
156
+
157
+ # Stack along the channel axis
158
+ conv_bank = torch.cat(conv_bank, dim=1)
159
+
160
+ # dump the last padding to fit residual
161
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
162
+
163
+ # Conv1d projections
164
+ x = self.conv_project1(x)
165
+ x = self.conv_project2(x)
166
+
167
+ # Residual Connect
168
+ x = x + residual
169
+
170
+ # Through the highways
171
+ x = x.transpose(1, 2)
172
+ if self.highway_mismatch is True:
173
+ x = self.pre_highway(x)
174
+ for h in self.highways:
175
+ x = h(x)
176
+
177
+ # And then the RNN
178
+ x, _ = self.rnn(x)
179
+ return x
180
+
181
+ def _flatten_parameters(self):
182
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
183
+ to improve efficiency and avoid PyTorch yelling at us."""
184
+ [m.flatten_parameters() for m in self._to_flatten]
185
+
186
+
187
+ class PreNet(nn.Module):
188
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
189
+ super().__init__()
190
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
191
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
192
+ self.p = dropout
193
+
194
+ def forward(self, x):
195
+ x = self.fc1(x)
196
+ x = F.relu(x)
197
+ x = F.dropout(x, self.p, training=True)
198
+ x = self.fc2(x)
199
+ x = F.relu(x)
200
+ x = F.dropout(x, self.p, training=True)
201
+ return x
202
+
203
+
204
+ class Attention(nn.Module):
205
+ def __init__(self, attn_dims):
206
+ super().__init__()
207
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
208
+ self.v = nn.Linear(attn_dims, 1, bias=False)
209
+
210
+ def forward(self, encoder_seq_proj, query, t):
211
+ # Transform the query vector
212
+ query_proj = self.W(query).unsqueeze(1)
213
+
214
+ # Compute the scores
215
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
216
+ scores = F.softmax(u, dim=1)
217
+
218
+ return scores.transpose(1, 2)
219
+
220
+
221
+ class LSA(nn.Module):
222
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
223
+ super().__init__()
224
+ self.conv = nn.Conv1d(
225
+ 1,
226
+ filters,
227
+ padding=(kernel_size - 1) // 2,
228
+ kernel_size=kernel_size,
229
+ bias=True,
230
+ )
231
+ self.L = nn.Linear(filters, attn_dim, bias=False)
232
+ self.W = nn.Linear(
233
+ attn_dim, attn_dim, bias=True
234
+ ) # Include the attention bias in this term
235
+ self.v = nn.Linear(attn_dim, 1, bias=False)
236
+ self.cumulative = None
237
+ self.attention = None
238
+
239
+ def init_attention(self, encoder_seq_proj):
240
+ device = encoder_seq_proj.device # use same device as parameters
241
+ b, t, c = encoder_seq_proj.size()
242
+ self.cumulative = torch.zeros(b, t, device=device)
243
+ self.attention = torch.zeros(b, t, device=device)
244
+
245
+ def forward(self, encoder_seq_proj, query, t, chars):
246
+
247
+ if t == 0:
248
+ self.init_attention(encoder_seq_proj)
249
+
250
+ processed_query = self.W(query).unsqueeze(1)
251
+
252
+ location = self.cumulative.unsqueeze(1)
253
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
254
+
255
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
256
+ u = u.squeeze(-1)
257
+
258
+ # Mask zero padding chars
259
+ u = u * (chars != 0).float()
260
+
261
+ # Smooth Attention
262
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
263
+ scores = F.softmax(u, dim=1)
264
+ self.attention = scores
265
+ self.cumulative = self.cumulative + self.attention
266
+
267
+ return scores.unsqueeze(-1).transpose(1, 2)
268
+
269
+
270
+ class Decoder(nn.Module):
271
+ # Class variable because its value doesn't change between instances,
272
+ # yet it ought to be scoped by the class because it's a property of a Decoder
273
+ max_r = 20
274
+
275
+ def __init__(
276
+ self,
277
+ n_mels,
278
+ encoder_dims,
279
+ decoder_dims,
280
+ lstm_dims,
281
+ dropout,
282
+ speaker_embedding_size,
283
+ ):
284
+ super().__init__()
285
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
286
+ self.n_mels = n_mels
287
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
288
+ self.prenet = PreNet(
289
+ n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], dropout=dropout
290
+ )
291
+ self.attn_net = LSA(decoder_dims)
292
+ if hp.use_gst:
293
+ speaker_embedding_size += gst_hp.E
294
+ self.attn_rnn = nn.GRUCell(
295
+ encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims
296
+ )
297
+ self.rnn_input = nn.Linear(
298
+ encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims
299
+ )
300
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
301
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
302
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
303
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
304
+
305
+ def zoneout(self, prev, current, device, p=0.1):
306
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
307
+ return prev * mask + current * (1 - mask)
308
+
309
+ def forward(
310
+ self,
311
+ encoder_seq,
312
+ encoder_seq_proj,
313
+ prenet_in,
314
+ hidden_states,
315
+ cell_states,
316
+ context_vec,
317
+ t,
318
+ chars,
319
+ ):
320
+
321
+ # Need this for reshaping mels
322
+ batch_size = encoder_seq.size(0)
323
+ device = encoder_seq.device
324
+ # Unpack the hidden and cell states
325
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
326
+ rnn1_cell, rnn2_cell = cell_states
327
+
328
+ # PreNet for the Attention RNN
329
+ prenet_out = self.prenet(prenet_in)
330
+
331
+ # Compute the Attention RNN hidden state
332
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
333
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
334
+
335
+ # Compute the attention scores
336
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
337
+
338
+ # Dot product to create the context vector
339
+ context_vec = scores @ encoder_seq
340
+ context_vec = context_vec.squeeze(1)
341
+
342
+ # Concat Attention RNN output w. Context Vector & project
343
+ x = torch.cat([context_vec, attn_hidden], dim=1)
344
+ x = self.rnn_input(x)
345
+
346
+ # Compute first Residual RNN
347
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
348
+ if self.training:
349
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
350
+ else:
351
+ rnn1_hidden = rnn1_hidden_next
352
+ x = x + rnn1_hidden
353
+
354
+ # Compute second Residual RNN
355
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
356
+ if self.training:
357
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
358
+ else:
359
+ rnn2_hidden = rnn2_hidden_next
360
+ x = x + rnn2_hidden
361
+
362
+ # Project Mels
363
+ mels = self.mel_proj(x)
364
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, : self.r]
365
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
366
+ cell_states = (rnn1_cell, rnn2_cell)
367
+
368
+ # Stop token prediction
369
+ s = torch.cat((x, context_vec), dim=1)
370
+ s = self.stop_proj(s)
371
+ stop_tokens = torch.sigmoid(s)
372
+
373
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
374
+
375
+
376
+ class Tacotron(nn.Module):
377
+ def __init__(
378
+ self,
379
+ embed_dims,
380
+ num_chars,
381
+ encoder_dims,
382
+ decoder_dims,
383
+ n_mels,
384
+ fft_bins,
385
+ postnet_dims,
386
+ encoder_K,
387
+ lstm_dims,
388
+ postnet_K,
389
+ num_highways,
390
+ dropout,
391
+ stop_threshold,
392
+ speaker_embedding_size,
393
+ ):
394
+ super().__init__()
395
+ self.n_mels = n_mels
396
+ self.lstm_dims = lstm_dims
397
+ self.encoder_dims = encoder_dims
398
+ self.decoder_dims = decoder_dims
399
+ self.speaker_embedding_size = speaker_embedding_size
400
+ self.encoder = Encoder(
401
+ embed_dims, num_chars, encoder_dims, encoder_K, num_highways, dropout
402
+ )
403
+ project_dims = encoder_dims + speaker_embedding_size
404
+ if hp.use_gst:
405
+ project_dims += gst_hp.E
406
+ self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
407
+ if hp.use_gst:
408
+ self.gst = GlobalStyleToken(speaker_embedding_size)
409
+ self.decoder = Decoder(
410
+ n_mels,
411
+ encoder_dims,
412
+ decoder_dims,
413
+ lstm_dims,
414
+ dropout,
415
+ speaker_embedding_size,
416
+ )
417
+ self.postnet = CBHG(
418
+ postnet_K, n_mels, postnet_dims, [postnet_dims, fft_bins], num_highways
419
+ )
420
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
421
+
422
+ self.init_model()
423
+ self.num_params()
424
+
425
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
426
+ self.register_buffer(
427
+ "stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32)
428
+ )
429
+
430
+ @property
431
+ def r(self):
432
+ return self.decoder.r.item()
433
+
434
+ @r.setter
435
+ def r(self, value):
436
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
437
+
438
+ @staticmethod
439
+ def _concat_speaker_embedding(outputs, speaker_embeddings):
440
+ speaker_embeddings_ = speaker_embeddings.expand(
441
+ outputs.size(0), outputs.size(1), -1
442
+ )
443
+ outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
444
+ return outputs
445
+
446
+ def forward(self, texts, mels, speaker_embedding):
447
+ device = texts.device # use same device as parameters
448
+
449
+ self.step += 1
450
+ batch_size, _, steps = mels.size()
451
+
452
+ # Initialise all hidden states and pack into tuple
453
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
454
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
455
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
456
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
457
+
458
+ # Initialise all lstm cell states and pack into tuple
459
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
460
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
461
+ cell_states = (rnn1_cell, rnn2_cell)
462
+
463
+ # <GO> Frame for start of decoder loop
464
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
465
+
466
+ # Need an initial context vector
467
+ size = self.encoder_dims + self.speaker_embedding_size
468
+ if hp.use_gst:
469
+ size += gst_hp.E
470
+ context_vec = torch.zeros(batch_size, size, device=device)
471
+
472
+ # SV2TTS: Run the encoder with the speaker embedding
473
+ # The projection avoids unnecessary matmuls in the decoder loop
474
+ encoder_seq = self.encoder(texts, speaker_embedding)
475
+ # put after encoder
476
+ if hp.use_gst and self.gst is not None:
477
+ style_embed = self.gst(
478
+ speaker_embedding, speaker_embedding
479
+ ) # during training, the speaker embedding serves as both the style input and the reference
480
+ # style_embed = style_embed.expand_as(encoder_seq)
481
+ # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
482
+ encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
483
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
484
+
485
+ # Need a couple of lists for outputs
486
+ mel_outputs, attn_scores, stop_outputs = [], [], []
487
+
488
+ # Run the decoder loop
489
+ for t in range(0, steps, self.r):
490
+ prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
491
+ (
492
+ mel_frames,
493
+ scores,
494
+ hidden_states,
495
+ cell_states,
496
+ context_vec,
497
+ stop_tokens,
498
+ ) = self.decoder(
499
+ encoder_seq,
500
+ encoder_seq_proj,
501
+ prenet_in,
502
+ hidden_states,
503
+ cell_states,
504
+ context_vec,
505
+ t,
506
+ texts,
507
+ )
508
+ mel_outputs.append(mel_frames)
509
+ attn_scores.append(scores)
510
+ stop_outputs.extend([stop_tokens] * self.r)
511
+
512
+ # Concat the mel outputs into sequence
513
+ mel_outputs = torch.cat(mel_outputs, dim=2)
514
+
515
+ # Post-Process for Linear Spectrograms
516
+ postnet_out = self.postnet(mel_outputs)
517
+ linear = self.post_proj(postnet_out)
518
+ linear = linear.transpose(1, 2)
519
+
520
+ # For easy visualisation
521
+ attn_scores = torch.cat(attn_scores, 1)
522
+ # attn_scores = attn_scores.cpu().data.numpy()
523
+ stop_outputs = torch.cat(stop_outputs, 1)
524
+
525
+ return mel_outputs, linear, attn_scores, stop_outputs
526
+
527
+ def generate(
528
+ self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5
529
+ ):
530
+ self.eval()
531
+ device = x.device # use same device as parameters
532
+
533
+ batch_size, _ = x.size()
534
+
535
+ # Need to initialise all hidden states and pack into tuple for tidiness
536
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
537
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
538
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
539
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
540
+
541
+ # Need to initialise all lstm cell states and pack into tuple for tidiness
542
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
543
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
544
+ cell_states = (rnn1_cell, rnn2_cell)
545
+
546
+ # Need a <GO> Frame for start of decoder loop
547
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
548
+
549
+ # Need an initial context vector
550
+ size = self.encoder_dims + self.speaker_embedding_size
551
+ if hp.use_gst:
552
+ size += gst_hp.E
553
+ context_vec = torch.zeros(batch_size, size, device=device)
554
+
555
+ # SV2TTS: Run the encoder with the speaker embedding
556
+ # The projection avoids unnecessary matmuls in the decoder loop
557
+ encoder_seq = self.encoder(x, speaker_embedding)
558
+
559
+ # put after encoder
560
+ if hp.use_gst and self.gst is not None:
561
+ if style_idx >= 0 and style_idx < 10:
562
+ query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
563
+ if device.type == "cuda":
564
+ query = query.cuda()
565
+ gst_embed = torch.tanh(self.gst.stl.embed)
566
+ key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
567
+ style_embed = self.gst.stl.attention(query, key)
568
+ else:
569
+ speaker_embedding_style = torch.zeros(
570
+ speaker_embedding.size()[0], 1, self.speaker_embedding_size
571
+ ).to(device)
572
+ style_embed = self.gst(speaker_embedding_style, speaker_embedding)
573
+ encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
574
+ # style_embed = style_embed.expand_as(encoder_seq)
575
+ # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
576
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
577
+
578
+ # Need a couple of lists for outputs
579
+ mel_outputs, attn_scores, stop_outputs = [], [], []
580
+
581
+ # Run the decoder loop
582
+ for t in range(0, steps, self.r):
583
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
584
+ (
585
+ mel_frames,
586
+ scores,
587
+ hidden_states,
588
+ cell_states,
589
+ context_vec,
590
+ stop_tokens,
591
+ ) = self.decoder(
592
+ encoder_seq,
593
+ encoder_seq_proj,
594
+ prenet_in,
595
+ hidden_states,
596
+ cell_states,
597
+ context_vec,
598
+ t,
599
+ x,
600
+ )
601
+ mel_outputs.append(mel_frames)
602
+ attn_scores.append(scores)
603
+ stop_outputs.extend([stop_tokens] * self.r)
604
+ # Stop the loop when all stop tokens in batch exceed threshold
605
+ if (stop_tokens * 10 > min_stop_token).all() and t > 10:
606
+ break
607
+
608
+ # Concat the mel outputs into sequence
609
+ mel_outputs = torch.cat(mel_outputs, dim=2)
610
+
611
+ # Post-Process for Linear Spectrograms
612
+ postnet_out = self.postnet(mel_outputs)
613
+ linear = self.post_proj(postnet_out)
614
+
615
+ linear = linear.transpose(1, 2)
616
+
617
+ # For easy visualisation
618
+ attn_scores = torch.cat(attn_scores, 1)
619
+ stop_outputs = torch.cat(stop_outputs, 1)
620
+
621
+ self.train()
622
+
623
+ return mel_outputs, linear, attn_scores
624
+
625
+ def init_model(self):
626
+ for p in self.parameters():
627
+ if p.dim() > 1:
628
+ nn.init.xavier_uniform_(p)
629
+
630
+ def finetune_partial(self, whitelist_layers):
631
+ self.zero_grad()
632
+ for name, child in self.named_children():
633
+ if name in whitelist_layers:
634
+ logger.debug("Trainable Layer: %s" % name)
635
+ logger.debug(
636
+ "Trainable Parameters: %.3f"
637
+ % sum([np.prod(p.size()) for p in child.parameters()])
638
+ )
639
+ for param in child.parameters():
640
+ param.requires_grad = False
641
+
642
+ def get_step(self):
643
+ return self.step.data.item()
644
+
645
+ def reset_step(self):
646
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
647
+ self.step = self.step.data.new_tensor(1)
648
+
649
+ def load(self, path, device, optimizer=None):
650
+ # Use device of model params as location for loaded state
651
+ checkpoint = torch.load(str(path), map_location=device)
652
+ self.load_state_dict(checkpoint["model_state"], strict=False)
653
+
654
+ if "optimizer_state" in checkpoint and optimizer is not None:
655
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
656
+
657
+ def save(self, path, optimizer=None):
658
+ if optimizer is not None:
659
+ torch.save(
660
+ {
661
+ "model_state": self.state_dict(),
662
+ "optimizer_state": optimizer.state_dict(),
663
+ },
664
+ str(path),
665
+ )
666
+ else:
667
+ torch.save(
668
+ {
669
+ "model_state": self.state_dict(),
670
+ },
671
+ str(path),
672
+ )
673
+
674
+ def num_params(self):
675
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
676
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
677
+ logger.debug("Trainable Parameters: %.3fM" % parameters)
678
+ return parameters
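
The Decoder above regularises its residual LSTMs with zoneout (applied only in training mode). A self-contained sketch of that mixing rule, with made-up tensors:

    import torch

    def zoneout(prev, current, p=0.1):
        # with probability p an element keeps its previous value, otherwise it takes the new one
        mask = torch.zeros_like(prev).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    prev, current = torch.zeros(2, 4), torch.ones(2, 4)
    print(zoneout(prev, current))  # mostly 1s (updated), a few 0s (kept)
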
mockingbirdforuse/synthesizer/utils/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+
8
+ def data_parallel_workaround(model, *input):
9
+ global _output_ref
10
+ global _replicas_ref
11
+ device_ids = list(range(torch.cuda.device_count()))
12
+ output_device = device_ids[0]
13
+ replicas = torch.nn.parallel.replicate(model, device_ids)
14
+ # input.shape = (num_args, batch, ...)
15
+ inputs = torch.nn.parallel.scatter(input, device_ids)
16
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
17
+ replicas = replicas[: len(inputs)]
18
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
19
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
20
+ _output_ref = outputs
21
+ _replicas_ref = replicas
22
+ return y_hat
23
+
24
+
25
+ class ValueWindow:
26
+ def __init__(self, window_size=100):
27
+ self._window_size = window_size
28
+ self._values = []
29
+
30
+ def append(self, x):
31
+ self._values = self._values[-(self._window_size - 1) :] + [x]
32
+
33
+ @property
34
+ def sum(self):
35
+ return sum(self._values)
36
+
37
+ @property
38
+ def count(self):
39
+ return len(self._values)
40
+
41
+ @property
42
+ def average(self):
43
+ return self.sum / max(1, self.count)
44
+
45
+ def reset(self):
46
+ self._values = []
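
A minimal usage sketch for ValueWindow, assuming the package is importable under the path shown in this commit:

    from mockingbirdforuse.synthesizer.utils import ValueWindow

    losses = ValueWindow(window_size=3)
    for v in [1.0, 2.0, 3.0, 4.0]:
        losses.append(v)
    print(losses.count, losses.average)  # 3 3.0 -- only the last three values are kept
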
mockingbirdforuse/synthesizer/utils/cleaners.py ADDED
@@ -0,0 +1,91 @@
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+
13
+ import re
14
+ from unidecode import unidecode
15
+ from .numbers import normalize_numbers
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [
22
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
23
+ for x in [
24
+ ("mrs", "misess"),
25
+ ("mr", "mister"),
26
+ ("dr", "doctor"),
27
+ ("st", "saint"),
28
+ ("co", "company"),
29
+ ("jr", "junior"),
30
+ ("maj", "major"),
31
+ ("gen", "general"),
32
+ ("drs", "doctors"),
33
+ ("rev", "reverend"),
34
+ ("lt", "lieutenant"),
35
+ ("hon", "honorable"),
36
+ ("sgt", "sergeant"),
37
+ ("capt", "captain"),
38
+ ("esq", "esquire"),
39
+ ("ltd", "limited"),
40
+ ("col", "colonel"),
41
+ ("ft", "fort"),
42
+ ]
43
+ ]
44
+
45
+
46
+ def expand_abbreviations(text):
47
+ for regex, replacement in _abbreviations:
48
+ text = re.sub(regex, replacement, text)
49
+ return text
50
+
51
+
52
+ def expand_numbers(text):
53
+ return normalize_numbers(text)
54
+
55
+
56
+ def lowercase(text):
57
+ """lowercase input tokens."""
58
+ return text.lower()
59
+
60
+
61
+ def collapse_whitespace(text):
62
+ return re.sub(_whitespace_re, " ", text)
63
+
64
+
65
+ def convert_to_ascii(text):
66
+ return unidecode(text)
67
+
68
+
69
+ def basic_cleaners(text):
70
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
71
+ text = lowercase(text)
72
+ text = collapse_whitespace(text)
73
+ return text
74
+
75
+
76
+ def transliteration_cleaners(text):
77
+ """Pipeline for non-English text that transliterates to ASCII."""
78
+ text = convert_to_ascii(text)
79
+ text = lowercase(text)
80
+ text = collapse_whitespace(text)
81
+ return text
82
+
83
+
84
+ def english_cleaners(text):
85
+ """Pipeline for English text, including number and abbreviation expansion."""
86
+ text = convert_to_ascii(text)
87
+ text = lowercase(text)
88
+ text = expand_numbers(text)
89
+ text = expand_abbreviations(text)
90
+ text = collapse_whitespace(text)
91
+ return text
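
A short sketch of how the cleaner pipelines compose (illustrative; the input strings are made up and the exact expansions come from unidecode/inflect):

    from mockingbirdforuse.synthesizer.utils.cleaners import basic_cleaners, english_cleaners

    print(basic_cleaners("Hello   World"))            # "hello world" -- lowercase + whitespace collapse
    print(english_cleaners("Dr. Smith paid $16.50"))  # also expands abbreviations, currency and numbers
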
mockingbirdforuse/synthesizer/utils/logmmse.py ADDED
@@ -0,0 +1,245 @@
1
+ # The MIT License (MIT)
2
+ #
3
+ # Copyright (c) 2015 braindead
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ #
23
+ #
24
+ # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
25
+ # simply modified the interface to meet my needs.
26
+
27
+
28
+ import numpy as np
29
+ import math
30
+ from scipy.special import expn
31
+ from collections import namedtuple
32
+
33
+ NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
34
+
35
+
36
+ def profile_noise(noise, sampling_rate, window_size=0):
37
+ """
38
+ Creates a profile of the noise in a given waveform.
39
+
40
+ :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
41
+ :param sampling_rate: the sampling rate of the audio
42
+ :param window_size: the size of the window the logmmse algorithm operates on. A default value
43
+ will be picked if left as 0.
44
+ :return: a NoiseProfile object
45
+ """
46
+ noise, dtype = to_float(noise)
47
+ noise += np.finfo(np.float64).eps
48
+
49
+ if window_size == 0:
50
+ window_size = int(math.floor(0.02 * sampling_rate))
51
+
52
+ if window_size % 2 == 1:
53
+ window_size = window_size + 1
54
+
55
+ perc = 50
56
+ len1 = int(math.floor(window_size * perc / 100))
57
+ len2 = int(window_size - len1)
58
+
59
+ win = np.hanning(window_size)
60
+ win = win * len2 / np.sum(win)
61
+ n_fft = 2 * window_size
62
+
63
+ noise_mean = np.zeros(n_fft)
64
+ n_frames = len(noise) // window_size
65
+ for j in range(0, window_size * n_frames, window_size):
66
+ noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
67
+ noise_mu2 = (noise_mean / n_frames) ** 2
68
+
69
+ return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
70
+
71
+
72
+ def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
73
+ """
74
+ Cleans the noise from a speech waveform given a noise profile. The waveform must have the
75
+ same sampling rate as the one used to create the noise profile.
76
+
77
+ :param wav: a speech waveform as a numpy array of floats or ints.
78
+ :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
79
+ the same) waveform.
80
+ :param eta: voice threshold for noise update. While the voice activation detection value is
81
+ below this threshold, the noise profile will be continuously updated throughout the audio.
82
+ Set to 0 to disable updating the noise profile.
83
+ :return: the clean wav as a numpy array of floats or ints of the same length.
84
+ """
85
+ wav, dtype = to_float(wav)
86
+ wav += np.finfo(np.float64).eps
87
+ p = noise_profile
88
+
89
+ nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
90
+ x_final = np.zeros(nframes * p.len2)
91
+
92
+ aa = 0.98
93
+ mu = 0.98
94
+ ksi_min = 10 ** (-25 / 10)
95
+
96
+ x_old = np.zeros(p.len1)
97
+ xk_prev = np.zeros(p.len1)
98
+ noise_mu2 = p.noise_mu2
99
+ for k in range(0, nframes * p.len2, p.len2):
100
+ insign = p.win * wav[k:k + p.window_size]
101
+
102
+ spec = np.fft.fft(insign, p.n_fft, axis=0)
103
+ sig = np.absolute(spec)
104
+ sig2 = sig ** 2
105
+
106
+ gammak = np.minimum(sig2 / noise_mu2, 40)
107
+
108
+ if xk_prev.all() == 0:
109
+ ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
110
+ else:
111
+ ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
112
+ ksi = np.maximum(ksi_min, ksi)
113
+
114
+ log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
115
+ vad_decision = np.sum(log_sigma_k) / p.window_size
116
+ if vad_decision < eta:
117
+ noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
118
+
119
+ a = ksi / (1 + ksi)
120
+ vk = a * gammak
121
+ ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
122
+ hw = a * np.exp(ei_vk)
123
+ sig = sig * hw
124
+ xk_prev = sig ** 2
125
+ xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
126
+ xi_w = np.real(xi_w)
127
+
128
+ x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
129
+ x_old = xi_w[p.len1:p.window_size]
130
+
131
+ output = from_float(x_final, dtype)
132
+ output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
133
+ return output
134
+
135
+
136
+ ## Alternative VAD algorithm to webrtcvad. It has the advantage of not requiring installation of that
137
+ ## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
138
+ ## webrtcvad
139
+ # def vad(wav, sampling_rate, eta=0.15, window_size=0):
140
+ # """
141
+ # TODO: fix doc
142
+ # Creates a profile of the noise in a given waveform.
143
+ #
144
+ # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
145
+ # :param sampling_rate: the sampling rate of the audio
146
+ # :param window_size: the size of the window the logmmse algorithm operates on. A default value
147
+ # will be picked if left as 0.
148
+ # :param eta: voice threshold for noise update. While the voice activation detection value is
149
+ # below this threshold, the noise profile will be continuously updated throughout the audio.
150
+ # Set to 0 to disable updating the noise profile.
151
+ # """
152
+ # wav, dtype = to_float(wav)
153
+ # wav += np.finfo(np.float64).eps
154
+ #
155
+ # if window_size == 0:
156
+ # window_size = int(math.floor(0.02 * sampling_rate))
157
+ #
158
+ # if window_size % 2 == 1:
159
+ # window_size = window_size + 1
160
+ #
161
+ # perc = 50
162
+ # len1 = int(math.floor(window_size * perc / 100))
163
+ # len2 = int(window_size - len1)
164
+ #
165
+ # win = np.hanning(window_size)
166
+ # win = win * len2 / np.sum(win)
167
+ # n_fft = 2 * window_size
168
+ #
169
+ # wav_mean = np.zeros(n_fft)
170
+ # n_frames = len(wav) // window_size
171
+ # for j in range(0, window_size * n_frames, window_size):
172
+ # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
173
+ # noise_mu2 = (wav_mean / n_frames) ** 2
174
+ #
175
+ # wav, dtype = to_float(wav)
176
+ # wav += np.finfo(np.float64).eps
177
+ #
178
+ # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
179
+ # vad = np.zeros(nframes * len2, dtype=np.bool)
180
+ #
181
+ # aa = 0.98
182
+ # mu = 0.98
183
+ # ksi_min = 10 ** (-25 / 10)
184
+ #
185
+ # xk_prev = np.zeros(len1)
186
+ # noise_mu2 = noise_mu2
187
+ # for k in range(0, nframes * len2, len2):
188
+ # insign = win * wav[k:k + window_size]
189
+ #
190
+ # spec = np.fft.fft(insign, n_fft, axis=0)
191
+ # sig = np.absolute(spec)
192
+ # sig2 = sig ** 2
193
+ #
194
+ # gammak = np.minimum(sig2 / noise_mu2, 40)
195
+ #
196
+ # if xk_prev.all() == 0:
197
+ # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
198
+ # else:
199
+ # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
200
+ # ksi = np.maximum(ksi_min, ksi)
201
+ #
202
+ # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
203
+ # vad_decision = np.sum(log_sigma_k) / window_size
204
+ # if vad_decision < eta:
205
+ # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
206
+ #
207
+ # a = ksi / (1 + ksi)
208
+ # vk = a * gammak
209
+ # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
210
+ # hw = a * np.exp(ei_vk)
211
+ # sig = sig * hw
212
+ # xk_prev = sig ** 2
213
+ #
214
+ # vad[k:k + len2] = vad_decision >= eta
215
+ #
216
+ # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
217
+ # return vad
218
+
219
+
220
+ def to_float(_input):
221
+ if _input.dtype == np.float64:
222
+ return _input, _input.dtype
223
+ elif _input.dtype == np.float32:
224
+ return _input.astype(np.float64), _input.dtype
225
+ elif _input.dtype == np.uint8:
226
+ return (_input - 128) / 128., _input.dtype
227
+ elif _input.dtype == np.int16:
228
+ return _input / 32768., _input.dtype
229
+ elif _input.dtype == np.int32:
230
+ return _input / 2147483648., _input.dtype
231
+ raise ValueError('Unsupported wave file format')
232
+
233
+
234
+ def from_float(_input, dtype):
235
+ if dtype == np.float64:
236
+ return _input  # already float64; return the array alone, matching the other branches
237
+ elif dtype == np.float32:
238
+ return _input.astype(np.float32)
239
+ elif dtype == np.uint8:
240
+ return ((_input * 128) + 128).astype(np.uint8)
241
+ elif dtype == np.int16:
242
+ return (_input * 32768).astype(np.int16)
243
+ elif dtype == np.int32:
244
+ return (_input * 2147483648).astype(np.int32)
245
+ raise ValueError('Unsupported wave file format')
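
A usage sketch for the two logmmse entry points (illustrative only; both waveforms are synthetic stand-ins):

    import numpy as np
    from mockingbirdforuse.synthesizer.utils.logmmse import profile_noise, denoise

    sr = 16000
    noise = (np.random.randn(sr) * 0.01).astype(np.float32)  # 1 s of noise-only audio
    noisy = np.random.randn(2 * sr).astype(np.float32)       # stand-in for a noisy recording

    profile = profile_noise(noise, sr)  # estimate the noise spectrum
    clean = denoise(noisy, profile)     # denoised waveform, same length as the input
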
mockingbirdforuse/synthesizer/utils/numbers.py ADDED
@@ -0,0 +1,70 @@
1
+ import re
2
+ import inflect
3
+
4
+ _inflect = inflect.engine()
5
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
6
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
7
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
8
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
9
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
10
+ _number_re = re.compile(r"[0-9]+")
11
+
12
+
13
+ def _remove_commas(m):
14
+ return m.group(1).replace(",", "")
15
+
16
+
17
+ def _expand_decimal_point(m):
18
+ return m.group(1).replace(".", " point ")
19
+
20
+
21
+ def _expand_dollars(m):
22
+ match = m.group(1)
23
+ parts = match.split(".")
24
+ if len(parts) > 2:
25
+ return match + " dollars" # Unexpected format
26
+ dollars = int(parts[0]) if parts[0] else 0
27
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
28
+ if dollars and cents:
29
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
30
+ cent_unit = "cent" if cents == 1 else "cents"
31
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
32
+ elif dollars:
33
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
34
+ return "%s %s" % (dollars, dollar_unit)
35
+ elif cents:
36
+ cent_unit = "cent" if cents == 1 else "cents"
37
+ return "%s %s" % (cents, cent_unit)
38
+ else:
39
+ return "zero dollars"
40
+
41
+
42
+ def _expand_ordinal(m):
43
+ return _inflect.number_to_words(m.group(0))
44
+
45
+
46
+ def _expand_number(m):
47
+ num = int(m.group(0))
48
+ if num > 1000 and num < 3000:
49
+ if num == 2000:
50
+ return "two thousand"
51
+ elif num > 2000 and num < 2010:
52
+ return "two thousand " + _inflect.number_to_words(num % 100)
53
+ elif num % 100 == 0:
54
+ return _inflect.number_to_words(num // 100) + " hundred"
55
+ else:
56
+ return _inflect.number_to_words(
57
+ num, andword="", zero="oh", group=2
58
+ ).replace(", ", " ")
59
+ else:
60
+ return _inflect.number_to_words(num, andword="")
61
+
62
+
63
+ def normalize_numbers(text):
64
+ text = re.sub(_comma_number_re, _remove_commas, text)
65
+ text = re.sub(_pounds_re, r"\1 pounds", text)
66
+ text = re.sub(_dollars_re, _expand_dollars, text)
67
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
68
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
69
+ text = re.sub(_number_re, _expand_number, text)
70
+ return text
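
normalize_numbers chains the substitutions above in a fixed order (commas, pounds, dollars, decimals, ordinals, plain numbers). A quick sketch:

    from mockingbirdforuse.synthesizer.utils.numbers import normalize_numbers

    print(normalize_numbers("I paid $3 for 2 books in 2019."))
    # currency, cardinals and years are spelled out, e.g. "three dollars", "two", "twenty nineteen"
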
mockingbirdforuse/synthesizer/utils/symbols.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ Defines the set of symbols used in text input to the model.
3
+
4
+ The default is a set of ASCII characters that works well for English or text that has been run
5
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+ _characters = (
12
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!'(),-.:;? "
13
+ )
14
+
15
+ # _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old one if you want to train old model
16
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
17
+ # _arpabet = ["@" + s for s in cmudict.valid_symbols]
18
+
19
+ # Export all symbols:
20
+ symbols = [_pad, _eos] + list(_characters) # + _arpabet
mockingbirdforuse/synthesizer/utils/text.py ADDED
@@ -0,0 +1,74 @@
1
+ from .symbols import symbols
2
+ from . import cleaners
3
+ import re
4
+
5
+ # Mappings from symbol to numeric ID and vice versa:
6
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
7
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
8
+
9
+ # Regular expression matching text enclosed in curly braces:
10
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
11
+
12
+
13
+ def text_to_sequence(text, cleaner_names):
14
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
15
+
16
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
17
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
18
+
19
+ Args:
20
+ text: string to convert to a sequence
21
+ cleaner_names: names of the cleaner functions to run the text through
22
+
23
+ Returns:
24
+ List of integers corresponding to the symbols in the text
25
+ """
26
+ sequence = []
27
+
28
+ # Check for curly braces and treat their contents as ARPAbet:
29
+ while len(text):
30
+ m = _curly_re.match(text)
31
+ if not m:
32
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
33
+ break
34
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
35
+ sequence += _arpabet_to_sequence(m.group(2))
36
+ text = m.group(3)
37
+
38
+ # Append EOS token
39
+ sequence.append(_symbol_to_id["~"])
40
+ return sequence
41
+
42
+
43
+ def sequence_to_text(sequence):
44
+ """Converts a sequence of IDs back to a string"""
45
+ result = ""
46
+ for symbol_id in sequence:
47
+ if symbol_id in _id_to_symbol:
48
+ s = _id_to_symbol[symbol_id]
49
+ # Enclose ARPAbet back in curly braces:
50
+ if len(s) > 1 and s[0] == "@":
51
+ s = "{%s}" % s[1:]
52
+ result += s
53
+ return result.replace("}{", " ")
54
+
55
+
56
+ def _clean_text(text, cleaner_names):
57
+ for name in cleaner_names:
58
+ cleaner = getattr(cleaners, name)
59
+ if not cleaner:
60
+ raise Exception("Unknown cleaner: %s" % name)
61
+ text = cleaner(text)
62
+ return text
63
+
64
+
65
+ def _symbols_to_sequence(symbols):
66
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
67
+
68
+
69
+ def _arpabet_to_sequence(text):
70
+ return _symbols_to_sequence(["@" + s for s in text.split()])
71
+
72
+
73
+ def _should_keep_symbol(s):
74
+ return s in _symbol_to_id and s not in ("_", "~")
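
A round-trip sketch for this text front end (illustrative only):

    from mockingbirdforuse.synthesizer.utils.text import text_to_sequence, sequence_to_text

    seq = text_to_sequence("Hello, world!", ["english_cleaners"])
    print(seq)                    # symbol IDs, ending with the EOS id for "~"
    print(sequence_to_text(seq))  # "hello, world!~"
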
mockingbirdforuse/vocoder/__init__.py ADDED
File without changes
mockingbirdforuse/vocoder/distribution.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def log_sum_exp(x):
7
+ """numerically stable log_sum_exp implementation that prevents overflow"""
8
+ # TF ordering
9
+ axis = len(x.size()) - 1
10
+ m, _ = torch.max(x, dim=axis)
11
+ m2, _ = torch.max(x, dim=axis, keepdim=True)
12
+ return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
13
+
14
+
15
+ # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
16
+ def discretized_mix_logistic_loss(
17
+ y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
18
+ ):
19
+ if log_scale_min is None:
20
+ log_scale_min = float(np.log(1e-14))
21
+ y_hat = y_hat.permute(0, 2, 1)
22
+ assert y_hat.dim() == 3
23
+ assert y_hat.size(1) % 3 == 0
24
+ nr_mix = y_hat.size(1) // 3
25
+
26
+ # (B x T x C)
27
+ y_hat = y_hat.transpose(1, 2)
28
+
29
+ # unpack parameters. (B, T, num_mixtures) x 3
30
+ logit_probs = y_hat[:, :, :nr_mix]
31
+ means = y_hat[:, :, nr_mix : 2 * nr_mix]
32
+ log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
33
+
34
+ # B x T x 1 -> B x T x num_mixtures
35
+ y = y.expand_as(means)
36
+
37
+ centered_y = y - means
38
+ inv_stdv = torch.exp(-log_scales)
39
+ plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
40
+ cdf_plus = torch.sigmoid(plus_in)
41
+ min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
42
+ cdf_min = torch.sigmoid(min_in)
43
+
44
+ # log probability for edge case of 0 (before scaling)
45
+ # equivalent: torch.log(F.sigmoid(plus_in))
46
+ log_cdf_plus = plus_in - F.softplus(plus_in)
47
+
48
+ # log probability for edge case of 255 (before scaling)
49
+ # equivalent: (1 - F.sigmoid(min_in)).log()
50
+ log_one_minus_cdf_min = -F.softplus(min_in)
51
+
52
+ # probability for all other cases
53
+ cdf_delta = cdf_plus - cdf_min
54
+
55
+ mid_in = inv_stdv * centered_y
56
+ # log probability in the center of the bin, to be used in extreme cases
57
+ # (not actually used in our code)
58
+ log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
59
+
60
+ # tf equivalent
61
+ """
62
+ log_probs = tf.where(x < -0.999, log_cdf_plus,
63
+ tf.where(x > 0.999, log_one_minus_cdf_min,
64
+ tf.where(cdf_delta > 1e-5,
65
+ tf.log(tf.maximum(cdf_delta, 1e-12)),
66
+ log_pdf_mid - np.log(127.5))))
67
+ """
68
+ # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
69
+ # for num_classes=65536 case? 1e-7? not sure..
70
+ inner_inner_cond = (cdf_delta > 1e-5).float()
71
+
72
+ inner_inner_out = inner_inner_cond * torch.log(
73
+ torch.clamp(cdf_delta, min=1e-12)
74
+ ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
75
+ inner_cond = (y > 0.999).float()
76
+ inner_out = (
77
+ inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
78
+ )
79
+ cond = (y < -0.999).float()
80
+ log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
81
+
82
+ log_probs = log_probs + F.log_softmax(logit_probs, -1)
83
+
84
+ if reduce:
85
+ return -torch.mean(log_sum_exp(log_probs))
86
+ else:
87
+ return -log_sum_exp(log_probs).unsqueeze(-1)
88
+
89
+
90
+ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
91
+ """
92
+ Sample from discretized mixture of logistic distributions
93
+ Args:
94
+ y (Tensor): B x C x T
95
+ log_scale_min (float): Log scale minimum value
96
+ Returns:
97
+ Tensor: sample in range of [-1, 1].
98
+ """
99
+ if log_scale_min is None:
100
+ log_scale_min = float(np.log(1e-14))
101
+ assert y.size(1) % 3 == 0
102
+ nr_mix = y.size(1) // 3
103
+
104
+ # B x T x C
105
+ y = y.transpose(1, 2)
106
+ logit_probs = y[:, :, :nr_mix]
107
+
108
+ # sample mixture indicator from softmax
109
+ temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
110
+ temp = logit_probs.data - torch.log(-torch.log(temp))
111
+ _, argmax = temp.max(dim=-1)
112
+
113
+ # (B, T) -> (B, T, nr_mix)
114
+ one_hot = to_one_hot(argmax, nr_mix)
115
+ # select logistic parameters
116
+ means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
117
+ log_scales = torch.clamp(
118
+ torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
119
+ )
120
+ # sample from logistic & clip to interval
121
+ # we don't actually round to the nearest 8bit value when sampling
122
+ u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
123
+ x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
124
+
125
+ x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
126
+
127
+ return x
128
+
129
+
130
+ def to_one_hot(tensor, n, fill_with=1.0):
131
+ # we perform one-hot encoding with respect to the last axis
132
+ one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
133
+ if tensor.is_cuda:
134
+ one_hot = one_hot.cuda()
135
+ one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
136
+ return one_hot
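
A shape and range check for the mixture-of-logistics sampler above (illustrative; the mixture count and sequence length are arbitrary):

    import torch
    from mockingbirdforuse.vocoder.distribution import sample_from_discretized_mix_logistic

    nr_mix = 10
    y = torch.randn(1, 3 * nr_mix, 100)          # B x C x T, with C = 3 * num_mixtures
    x = sample_from_discretized_mix_logistic(y)
    print(x.shape, float(x.min()) >= -1.0, float(x.max()) <= 1.0)  # torch.Size([1, 100]) True True
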
mockingbirdforuse/vocoder/hifigan/__init__.py ADDED
File without changes
mockingbirdforuse/vocoder/hifigan/hparams.py ADDED
@@ -0,0 +1,37 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class HParams:
6
+ resblock = "1"
7
+ num_gpus = 0
8
+ batch_size = 16
9
+ learning_rate = 0.0002
10
+ adam_b1 = 0.8
11
+ adam_b2 = 0.99
12
+ lr_decay = 0.999
13
+ seed = 1234
14
+
15
+ upsample_rates = [5, 5, 4, 2]
16
+ upsample_kernel_sizes = [10, 10, 8, 4]
17
+ upsample_initial_channel = 512
18
+ resblock_kernel_sizes = [3, 7, 11]
19
+ resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
20
+
21
+ segment_size = 6400
22
+ num_mels = 80
23
+ num_freq = 1025
24
+ n_fft = 1024
25
+ hop_size = 200
26
+ win_size = 800
27
+
28
+ sampling_rate = 16000
29
+
30
+ fmin = 0
31
+ fmax = 7600
32
+ fmax_for_loss = None
33
+
34
+ num_workers = 4
35
+
36
+
37
+ hparams = HParams()
mockingbirdforuse/vocoder/hifigan/inference.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ from pathlib import Path
3
+
4
+ from .hparams import hparams as hp
5
+ from .models import Generator
6
+ from ...log import logger
7
+
8
+
9
+ class HifiGanVocoder:
10
+ def __init__(self, model_path: Path):
11
+ torch.manual_seed(hp.seed)
12
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.generator = Generator(hp).to(self._device)
14
+
15
+ logger.debug("Loading '{}'".format(model_path))
16
+ state_dict_g = torch.load(model_path, map_location=self._device)
17
+ logger.debug("Complete.")
18
+
19
+ self.generator.load_state_dict(state_dict_g["generator"])
20
+ self.generator.eval()
21
+ self.generator.remove_weight_norm()
22
+
23
+ def infer_waveform(self, mel):
24
+ mel = torch.FloatTensor(mel).to(self._device)
25
+ mel = mel.unsqueeze(0)
26
+
27
+ with torch.no_grad():
28
+ y_g_hat = self.generator(mel)
29
+ audio = y_g_hat.squeeze()
30
+ audio = audio.cpu().numpy()
31
+
32
+ return audio, hp.sampling_rate
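
A usage sketch for HifiGanVocoder (illustrative only; the checkpoint path is assumed to be the g_hifigan.pt file shipped in this repo, and the mel input here is random rather than a real synthesizer output):

    import numpy as np
    from pathlib import Path
    from mockingbirdforuse.vocoder.hifigan.inference import HifiGanVocoder

    vocoder = HifiGanVocoder(Path("data/g_hifigan.pt"))  # assumed checkpoint location
    mel = np.random.randn(80, 200).astype(np.float32)    # (num_mels, frames)
    wav, sr = vocoder.infer_waveform(mel)                 # float waveform, sr == 16000
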
mockingbirdforuse/vocoder/hifigan/models.py ADDED
@@ -0,0 +1,460 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.spectral_norm import spectral_norm
5
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
6
+ from torch.nn.utils.weight_norm import weight_norm, remove_weight_norm
7
+ from ...log import logger
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+
12
+ def init_weights(m, mean=0.0, std=0.01):
13
+ classname = m.__class__.__name__
14
+ if classname.find("Conv") != -1:
15
+ m.weight.data.normal_(mean, std)
16
+
17
+
18
+ def get_padding(kernel_size, dilation=1):
19
+ return int((kernel_size * dilation - dilation) / 2)
20
+
21
+
22
+ class ResBlock1(torch.nn.Module):
23
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
24
+ super(ResBlock1, self).__init__()
25
+ self.h = h
26
+ self.convs1 = nn.ModuleList(
27
+ [
28
+ weight_norm(
29
+ Conv1d(
30
+ channels,
31
+ channels,
32
+ kernel_size,
33
+ 1,
34
+ dilation=dilation[0],
35
+ padding=get_padding(kernel_size, dilation[0]),
36
+ )
37
+ ),
38
+ weight_norm(
39
+ Conv1d(
40
+ channels,
41
+ channels,
42
+ kernel_size,
43
+ 1,
44
+ dilation=dilation[1],
45
+ padding=get_padding(kernel_size, dilation[1]),
46
+ )
47
+ ),
48
+ weight_norm(
49
+ Conv1d(
50
+ channels,
51
+ channels,
52
+ kernel_size,
53
+ 1,
54
+ dilation=dilation[2],
55
+ padding=get_padding(kernel_size, dilation[2]),
56
+ )
57
+ ),
58
+ ]
59
+ )
60
+ self.convs1.apply(init_weights)
61
+
62
+ self.convs2 = nn.ModuleList(
63
+ [
64
+ weight_norm(
65
+ Conv1d(
66
+ channels,
67
+ channels,
68
+ kernel_size,
69
+ 1,
70
+ dilation=1,
71
+ padding=get_padding(kernel_size, 1),
72
+ )
73
+ ),
74
+ weight_norm(
75
+ Conv1d(
76
+ channels,
77
+ channels,
78
+ kernel_size,
79
+ 1,
80
+ dilation=1,
81
+ padding=get_padding(kernel_size, 1),
82
+ )
83
+ ),
84
+ weight_norm(
85
+ Conv1d(
86
+ channels,
87
+ channels,
88
+ kernel_size,
89
+ 1,
90
+ dilation=1,
91
+ padding=get_padding(kernel_size, 1),
92
+ )
93
+ ),
94
+ ]
95
+ )
96
+ self.convs2.apply(init_weights)
97
+
98
+ def forward(self, x):
99
+ for c1, c2 in zip(self.convs1, self.convs2):
100
+ xt = F.leaky_relu(x, LRELU_SLOPE)
101
+ xt = c1(xt)
102
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
103
+ xt = c2(xt)
104
+ x = xt + x
105
+ return x
106
+
107
+ def remove_weight_norm(self):
108
+ for l in self.convs1:
109
+ remove_weight_norm(l)
110
+ for l in self.convs2:
111
+ remove_weight_norm(l)
112
+
113
+
114
+ class ResBlock2(torch.nn.Module):
115
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
116
+ super(ResBlock2, self).__init__()
117
+ self.h = h
118
+ self.convs = nn.ModuleList(
119
+ [
120
+ weight_norm(
121
+ Conv1d(
122
+ channels,
123
+ channels,
124
+ kernel_size,
125
+ 1,
126
+ dilation=dilation[0],
127
+ padding=get_padding(kernel_size, dilation[0]),
128
+ )
129
+ ),
130
+ weight_norm(
131
+ Conv1d(
132
+ channels,
133
+ channels,
134
+ kernel_size,
135
+ 1,
136
+ dilation=dilation[1],
137
+ padding=get_padding(kernel_size, dilation[1]),
138
+ )
139
+ ),
140
+ ]
141
+ )
142
+ self.convs.apply(init_weights)
143
+
144
+ def forward(self, x):
145
+ for c in self.convs:
146
+ xt = F.leaky_relu(x, LRELU_SLOPE)
147
+ xt = c(xt)
148
+ x = xt + x
149
+ return x
150
+
151
+ def remove_weight_norm(self):
152
+ for l in self.convs:
153
+ remove_weight_norm(l)
154
+
155
+
156
+ class InterpolationBlock(torch.nn.Module):
157
+ def __init__(
158
+ self, scale_factor, mode="nearest", align_corners=None, downsample=False
159
+ ):
160
+ super(InterpolationBlock, self).__init__()
161
+ self.downsample = downsample
162
+ self.scale_factor = scale_factor
163
+ self.mode = mode
164
+ self.align_corners = align_corners
165
+
166
+ def forward(self, x):
167
+ outputs = F.interpolate(
168
+ x,
169
+ size=x.shape[-1] * self.scale_factor
170
+ if not self.downsample
171
+ else x.shape[-1] // self.scale_factor,
172
+ mode=self.mode,
173
+ align_corners=self.align_corners,
174
+ recompute_scale_factor=False,
175
+ )
176
+ return outputs
177
+
178
+
179
+ class Generator(torch.nn.Module):
180
+ def __init__(self, h):
181
+ super(Generator, self).__init__()
182
+ self.h = h
183
+ self.num_kernels = len(h.resblock_kernel_sizes)
184
+ self.num_upsamples = len(h.upsample_rates)
185
+ self.conv_pre = weight_norm(
186
+ Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
187
+ )
188
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
189
+
190
+ self.ups = nn.ModuleList()
191
+ # for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
192
+ # # self.ups.append(weight_norm(
193
+ # # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
194
+ # # k, u, padding=(k-u)//2)))
195
+ if h.sampling_rate == 24000:
196
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
197
+ self.ups.append(
198
+ torch.nn.Sequential(
199
+ InterpolationBlock(u),
200
+ weight_norm(
201
+ torch.nn.Conv1d(
202
+ h.upsample_initial_channel // (2**i),
203
+ h.upsample_initial_channel // (2 ** (i + 1)),
204
+ k,
205
+ padding=(k - 1) // 2,
206
+ )
207
+ ),
208
+ )
209
+ )
210
+ else:
211
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
212
+ self.ups.append(
213
+ weight_norm(
214
+ ConvTranspose1d(
215
+ h.upsample_initial_channel // (2**i),
216
+ h.upsample_initial_channel // (2 ** (i + 1)),
217
+ k,
218
+ u,
219
+ padding=(u // 2 + u % 2),
220
+ output_padding=u % 2,
221
+ )
222
+ )
223
+ )
224
+ self.resblocks = nn.ModuleList()
225
+ ch = 0
226
+ for i in range(len(self.ups)):
227
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
228
+ for j, (k, d) in enumerate(
229
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
230
+ ):
231
+ self.resblocks.append(resblock(h, ch, k, d))
232
+
233
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
234
+ self.ups.apply(init_weights)
235
+ self.conv_post.apply(init_weights)
236
+
237
+ def forward(self, x):
238
+ x = self.conv_pre(x)
239
+ for i in range(self.num_upsamples):
240
+ x = F.leaky_relu(x, LRELU_SLOPE)
241
+ x = self.ups[i](x)
242
+ xs = None
243
+ for j in range(self.num_kernels):
244
+ if xs is None:
245
+ xs = self.resblocks[i * self.num_kernels + j](x)
246
+ else:
247
+ xs += self.resblocks[i * self.num_kernels + j](x)
248
+ x = xs / self.num_kernels
249
+ x = F.leaky_relu(x)
250
+ x = self.conv_post(x)
251
+ x = torch.tanh(x)
252
+ return x
253
+
254
+ def remove_weight_norm(self):
255
+ logger.debug("Removing weight norm...")
256
+ for module in self.ups:
257
+ if self.h.sampling_rate == 24000:
258
+ remove_weight_norm(module[-1])
259
+ else:
260
+ remove_weight_norm(module)
261
+ for module in self.resblocks:
262
+ module.remove_weight_norm()
263
+ remove_weight_norm(self.conv_pre)
264
+ remove_weight_norm(self.conv_post)
265
+
266
+
267
+ class DiscriminatorP(torch.nn.Module):
268
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
269
+ super(DiscriminatorP, self).__init__()
270
+ self.period = period
271
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
272
+ self.convs = nn.ModuleList(
273
+ [
274
+ norm_f(
275
+ Conv2d(
276
+ 1,
277
+ 32,
278
+ (kernel_size, 1),
279
+ (stride, 1),
280
+ padding=(get_padding(5, 1), 0),
281
+ )
282
+ ),
283
+ norm_f(
284
+ Conv2d(
285
+ 32,
286
+ 128,
287
+ (kernel_size, 1),
288
+ (stride, 1),
289
+ padding=(get_padding(5, 1), 0),
290
+ )
291
+ ),
292
+ norm_f(
293
+ Conv2d(
294
+ 128,
295
+ 512,
296
+ (kernel_size, 1),
297
+ (stride, 1),
298
+ padding=(get_padding(5, 1), 0),
299
+ )
300
+ ),
301
+ norm_f(
302
+ Conv2d(
303
+ 512,
304
+ 1024,
305
+ (kernel_size, 1),
306
+ (stride, 1),
307
+ padding=(get_padding(5, 1), 0),
308
+ )
309
+ ),
310
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
311
+ ]
312
+ )
313
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
314
+
315
+ def forward(self, x):
316
+ fmap = []
317
+
318
+ # 1d to 2d
319
+ b, c, t = x.shape
320
+ if t % self.period != 0: # pad first
321
+ n_pad = self.period - (t % self.period)
322
+ x = F.pad(x, (0, n_pad), "reflect")
323
+ t = t + n_pad
324
+ x = x.view(b, c, t // self.period, self.period)
325
+
326
+ for l in self.convs:
327
+ x = l(x)
328
+ x = F.leaky_relu(x, LRELU_SLOPE)
329
+ fmap.append(x)
330
+ x = self.conv_post(x)
331
+ fmap.append(x)
332
+ x = torch.flatten(x, 1, -1)
333
+
334
+ return x, fmap
335
+
336
+
337
+ class MultiPeriodDiscriminator(torch.nn.Module):
338
+ def __init__(self):
339
+ super(MultiPeriodDiscriminator, self).__init__()
340
+ self.discriminators = nn.ModuleList(
341
+ [
342
+ DiscriminatorP(2),
343
+ DiscriminatorP(3),
344
+ DiscriminatorP(5),
345
+ DiscriminatorP(7),
346
+ DiscriminatorP(11),
347
+ ]
348
+ )
349
+
350
+ def forward(self, y, y_hat):
351
+ y_d_rs = []
352
+ y_d_gs = []
353
+ fmap_rs = []
354
+ fmap_gs = []
355
+ for i, d in enumerate(self.discriminators):
356
+ y_d_r, fmap_r = d(y)
357
+ y_d_g, fmap_g = d(y_hat)
358
+ y_d_rs.append(y_d_r)
359
+ fmap_rs.append(fmap_r)
360
+ y_d_gs.append(y_d_g)
361
+ fmap_gs.append(fmap_g)
362
+
363
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
364
+
365
+
366
+ class DiscriminatorS(torch.nn.Module):
367
+ def __init__(self, use_spectral_norm=False):
368
+ super(DiscriminatorS, self).__init__()
369
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
370
+ self.convs = nn.ModuleList(
371
+ [
372
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
373
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
374
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
375
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
376
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
377
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
378
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
379
+ ]
380
+ )
381
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
382
+
383
+ def forward(self, x):
384
+ fmap = []
385
+ for l in self.convs:
386
+ x = l(x)
387
+ x = F.leaky_relu(x, LRELU_SLOPE)
388
+ fmap.append(x)
389
+ x = self.conv_post(x)
390
+ fmap.append(x)
391
+ x = torch.flatten(x, 1, -1)
392
+
393
+ return x, fmap
394
+
395
+
396
+ class MultiScaleDiscriminator(torch.nn.Module):
397
+ def __init__(self):
398
+ super(MultiScaleDiscriminator, self).__init__()
399
+ self.discriminators = nn.ModuleList(
400
+ [
401
+ DiscriminatorS(use_spectral_norm=True),
402
+ DiscriminatorS(),
403
+ DiscriminatorS(),
404
+ ]
405
+ )
406
+ self.meanpools = nn.ModuleList(
407
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
408
+ )
409
+
410
+ def forward(self, y, y_hat):
411
+ y_d_rs = []
412
+ y_d_gs = []
413
+ fmap_rs = []
414
+ fmap_gs = []
415
+ for i, d in enumerate(self.discriminators):
416
+ if i != 0:
417
+ y = self.meanpools[i - 1](y)
418
+ y_hat = self.meanpools[i - 1](y_hat)
419
+ y_d_r, fmap_r = d(y)
420
+ y_d_g, fmap_g = d(y_hat)
421
+ y_d_rs.append(y_d_r)
422
+ fmap_rs.append(fmap_r)
423
+ y_d_gs.append(y_d_g)
424
+ fmap_gs.append(fmap_g)
425
+
426
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
427
+
428
+
429
+ def feature_loss(fmap_r, fmap_g):
430
+ loss = 0
431
+ for dr, dg in zip(fmap_r, fmap_g):
432
+ for rl, gl in zip(dr, dg):
433
+ loss += torch.mean(torch.abs(rl - gl))
434
+
435
+ return loss * 2
436
+
437
+
438
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
439
+ loss = 0
440
+ r_losses = []
441
+ g_losses = []
442
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
443
+ r_loss = torch.mean((1 - dr) ** 2)
444
+ g_loss = torch.mean(dg**2)
445
+ loss += r_loss + g_loss
446
+ r_losses.append(r_loss.item())
447
+ g_losses.append(g_loss.item())
448
+
449
+ return loss, r_losses, g_losses
450
+
451
+
452
+ def generator_loss(disc_outputs):
453
+ loss = 0
454
+ gen_losses = []
455
+ for dg in disc_outputs:
456
+ l = torch.mean((1 - dg) ** 2)
457
+ gen_losses.append(l)
458
+ loss += l
459
+
460
+ return loss, gen_losses
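
For orientation, a minimal training-step sketch (not part of the committed files) of how the Generator, the two discriminators and the three loss helpers above typically fit together in HiFi-GAN-style training. `h`, the real audio batch and the optimizers are assumed to exist elsewhere, and a hop of 256 samples per mel frame is assumed so the real and generated lengths line up.

import torch

generator = Generator(h)                        # h: a hifigan hparams object as used above
mpd = MultiPeriodDiscriminator()
msd = MultiScaleDiscriminator()

mel = torch.randn(1, 80, 32)                    # (batch, n_mels, frames)
y = torch.randn(1, 1, 32 * 256)                 # real audio; 256 = assumed prod(h.upsample_rates)
y_hat = generator(mel)                          # (batch, 1, frames * prod(h.upsample_rates))

# Discriminator step: score real audio against detached generated audio
y_dp_r, y_dp_g, _, _ = mpd(y, y_hat.detach())
y_ds_r, y_ds_g, _, _ = msd(y, y_hat.detach())
loss_disc = discriminator_loss(y_dp_r, y_dp_g)[0] + discriminator_loss(y_ds_r, y_ds_g)[0]

# Generator step: adversarial terms plus feature matching on the discriminator maps
_, y_dp_g, fmap_p_r, fmap_p_g = mpd(y, y_hat)
_, y_ds_g, fmap_s_r, fmap_s_g = msd(y, y_hat)
loss_gen = (generator_loss(y_dp_g)[0] + generator_loss(y_ds_g)[0]
            + feature_loss(fmap_p_r, fmap_p_g) + feature_loss(fmap_s_r, fmap_s_g))
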
mockingbirdforuse/vocoder/wavernn/__init__.py ADDED
File without changes
mockingbirdforuse/vocoder/wavernn/audio.py ADDED
@@ -0,0 +1,118 @@
1
+ import math
2
+ import librosa
3
+ import numpy as np
4
+ import soundfile as sf
5
+ from scipy.signal import lfilter
6
+
7
+ from .hparams import hparams as hp
8
+
9
+
10
+ def label_2_float(x, bits):
11
+ return 2 * x / (2**bits - 1.0) - 1.0
12
+
13
+
14
+ def float_2_label(x, bits):
15
+ assert abs(x).max() <= 1.0
16
+ x = (x + 1.0) * (2**bits - 1) / 2
17
+ return x.clip(0, 2**bits - 1)
18
+
19
+
20
+ def load_wav(path):
21
+ return librosa.load(str(path), sr=hp.sample_rate)[0]
22
+
23
+
24
+ def save_wav(x, path):
25
+ sf.write(path, x.astype(np.float32), hp.sample_rate)
26
+
27
+
28
+ def split_signal(x):
29
+ unsigned = x + 2**15
30
+ coarse = unsigned // 256
31
+ fine = unsigned % 256
32
+ return coarse, fine
33
+
34
+
35
+ def combine_signal(coarse, fine):
36
+ return coarse * 256 + fine - 2**15
37
+
38
+
39
+ def encode_16bits(x):
40
+ return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
41
+
42
+
43
+ mel_basis = None
44
+
45
+
46
+ def linear_to_mel(spectrogram):
47
+ global mel_basis
48
+ if mel_basis is None:
49
+ mel_basis = build_mel_basis()
50
+ return np.dot(mel_basis, spectrogram)
51
+
52
+
53
+ def build_mel_basis():
54
+ return librosa.filters.mel(
55
+ sr=hp.sample_rate,
56
+ n_fft=hp.n_fft,
57
+ n_mels=hp.num_mels,
58
+ fmin=hp.fmin,
59
+ )
60
+
61
+
62
+ def normalize(S):
63
+ return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)
64
+
65
+
66
+ def denormalize(S):
67
+ return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db
68
+
69
+
70
+ def amp_to_db(x):
71
+ return 20 * np.log10(np.maximum(1e-5, x))
72
+
73
+
74
+ def db_to_amp(x):
75
+ return np.power(10.0, x * 0.05)
76
+
77
+
78
+ def spectrogram(y):
79
+ D = stft(y)
80
+ S = amp_to_db(np.abs(D)) - hp.ref_level_db
81
+ return normalize(S)
82
+
83
+
84
+ def melspectrogram(y):
85
+ D = stft(y)
86
+ S = amp_to_db(linear_to_mel(np.abs(D)))
87
+ return normalize(S)
88
+
89
+
90
+ def stft(y):
91
+ return librosa.stft(
92
+ y=y,
93
+ n_fft=hp.n_fft,
94
+ hop_length=hp.hop_length,
95
+ win_length=hp.win_length,
96
+ )
97
+
98
+
99
+ def pre_emphasis(x):
100
+ return lfilter([1, -hp.preemphasis], [1], x)
101
+
102
+
103
+ def de_emphasis(x):
104
+ return lfilter([1], [1, -hp.preemphasis], x)
105
+
106
+
107
+ def encode_mu_law(x, mu):
108
+ mu = mu - 1
109
+ fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
110
+ return np.floor((fx + 1) / 2 * mu + 0.5)
111
+
112
+
113
+ def decode_mu_law(y, mu, from_labels=True):
114
+ if from_labels:
115
+ y = label_2_float(y, math.log2(mu))
116
+ mu = mu - 1
117
+ x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
118
+ return x
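
Quick illustration (not part of the commit) of the helpers above; "sample.wav" is a placeholder path and 2**9 mirrors the 9-bit setting used by the WaveRNN hparams that follow.

wav = load_wav("sample.wav")                    # resampled to hp.sample_rate, float in [-1, 1]
mel = melspectrogram(wav)                       # (num_mels, frames), normalised to [0, 1]

labels = encode_mu_law(wav, mu=2**9)            # integer labels in [0, 511]
recovered = decode_mu_law(labels, 2**9, from_labels=True)
assert recovered.shape == wav.shape             # lossy round trip, same length
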
mockingbirdforuse/vocoder/wavernn/hparams.py ADDED
@@ -0,0 +1,53 @@
1
+ from dataclasses import dataclass
2
+ from ...synthesizer.hparams import hparams as _syn_hp
3
+
4
+
5
+ @dataclass
6
+ class HParams:
7
+ # Audio settings------------------------------------------------------------------------
8
+ # Match the values of the synthesizer
9
+ sample_rate = _syn_hp.sample_rate
10
+ n_fft = _syn_hp.n_fft
11
+ num_mels = _syn_hp.num_mels
12
+ hop_length = _syn_hp.hop_size
13
+ win_length = _syn_hp.win_size
14
+ fmin = _syn_hp.fmin
15
+ min_level_db = _syn_hp.min_level_db
16
+ ref_level_db = _syn_hp.ref_level_db
17
+ mel_max_abs_value = _syn_hp.max_abs_value
18
+ preemphasis = _syn_hp.preemphasis
19
+ apply_preemphasis = _syn_hp.preemphasize
20
+
21
+ bits = 9 # bit depth of signal
22
+ mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode
23
+ # below
24
+
25
+ # WAVERNN / VOCODER --------------------------------------------------------------------------------
26
+ voc_mode = "RAW" # either 'RAW' (softmax on raw bits) or 'MOL' (sample from
27
+ # mixture of logistics)
28
+ voc_upsample_factors = (
29
+ 5,
30
+ 5,
31
+ 8,
32
+ ) # NB - this needs to correctly factorise hop_length
33
+ voc_rnn_dims = 512
34
+ voc_fc_dims = 512
35
+ voc_compute_dims = 128
36
+ voc_res_out_dims = 128
37
+ voc_res_blocks = 10
38
+
39
+ # Training
40
+ voc_batch_size = 100
41
+ voc_lr = 1e-4
42
+ voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint
43
+ voc_pad = 2 # this will pad the input so that the resnet can 'see' wider
44
+ # than input length
45
+ voc_seq_len = hop_length * 5 # must be a multiple of hop_length
46
+
47
+ # Generating / Synthesizing
48
+ voc_gen_batched = True # very fast (realtime+) single utterance batched generation
49
+ voc_target = 8000 # target number of samples to be generated in each batch entry
50
+ voc_overlap = 400 # number of samples for crossfading between batches
51
+
52
+
53
+ hparams = HParams()
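
Illustrative check only: the comments above require voc_upsample_factors to factorise the hop length exactly and voc_seq_len to be a multiple of it. This snippet just states those invariants; whether it passes depends on the synthesizer hparams they inherit from.

import math

assert math.prod(hparams.voc_upsample_factors) == hparams.hop_length
assert hparams.voc_seq_len % hparams.hop_length == 0
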
mockingbirdforuse/vocoder/wavernn/inference.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ from pathlib import Path
3
+
4
+ from .hparams import hparams as hp
5
+ from .models.fatchord_version import WaveRNN
6
+ from ...log import logger
7
+
8
+
9
+ class WaveRNNVocoder:
10
+ def __init__(self, model_path: Path):
11
+ logger.debug("Building Wave-RNN")
12
+ self._model = WaveRNN(
13
+ rnn_dims=hp.voc_rnn_dims,
14
+ fc_dims=hp.voc_fc_dims,
15
+ bits=hp.bits,
16
+ pad=hp.voc_pad,
17
+ upsample_factors=hp.voc_upsample_factors,
18
+ feat_dims=hp.num_mels,
19
+ compute_dims=hp.voc_compute_dims,
20
+ res_out_dims=hp.voc_res_out_dims,
21
+ res_blocks=hp.voc_res_blocks,
22
+ hop_length=hp.hop_length,
23
+ sample_rate=hp.sample_rate,
24
+ mode=hp.voc_mode,
25
+ )
26
+
27
+ if torch.cuda.is_available():
28
+ self._model = self._model.cuda()
29
+ self._device = torch.device("cuda")
30
+ else:
31
+ self._device = torch.device("cpu")
32
+
33
+ logger.debug("Loading model weights at %s" % model_path)
34
+ checkpoint = torch.load(model_path, self._device)
35
+ self._model.load_state_dict(checkpoint["model_state"])
36
+ self._model.eval()
37
+
38
+ def infer_waveform(
39
+ self, mel, normalize=True, batched=True, target=8000, overlap=800
40
+ ):
41
+ """
42
+ Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
43
+ that of the synthesizer!)
44
+
45
+ :param normalize: scale the mel by the synthesizer's max_abs_value before vocoding
46
+ :param batched: fold the spectrogram and generate in parallel batches (faster)
47
+ :param target: target number of samples generated per batch entry
48
+ :param overlap: number of samples used to crossfade between batches
49
+ :return: the generated waveform and the sample rate
50
+ """
51
+
52
+ if normalize:
53
+ mel = mel / hp.mel_max_abs_value
54
+ mel = torch.from_numpy(mel[None, ...])
55
+ wav = self._model.generate(mel, batched, target, overlap, hp.mu_law)
56
+ return wav, hp.sample_rate
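
Hypothetical usage sketch: "mel.npy" is a placeholder (in the app the mel comes from the synthesizer), while data/wavernn.pt is the checkpoint bundled in this repo.

import numpy as np
from pathlib import Path

vocoder = WaveRNNVocoder(Path("data/wavernn.pt"))
mel = np.load("mel.npy")                        # (num_mels, frames) produced by the synthesizer
wav, sample_rate = vocoder.infer_waveform(mel, batched=True, target=8000, overlap=400)
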
mockingbirdforuse/vocoder/wavernn/models/deepmind_version.py ADDED
@@ -0,0 +1,180 @@
1
+ import time
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.parameter import Parameter
7
+
8
+ from ..audio import combine_signal
9
+ from ....log import logger
10
+
11
+
12
+ class WaveRNN(nn.Module):
13
+ def __init__(self, hidden_size=896, quantisation=256):
14
+ super(WaveRNN, self).__init__()
15
+
16
+ self.hidden_size = hidden_size
17
+ self.split_size = hidden_size // 2
18
+
19
+ # The main matmul
20
+ self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
21
+
22
+ # Output fc layers
23
+ self.O1 = nn.Linear(self.split_size, self.split_size)
24
+ self.O2 = nn.Linear(self.split_size, quantisation)
25
+ self.O3 = nn.Linear(self.split_size, self.split_size)
26
+ self.O4 = nn.Linear(self.split_size, quantisation)
27
+
28
+ # Input fc layers
29
+ self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
30
+ self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)
31
+
32
+ # biases for the gates
33
+ self.bias_u = Parameter(torch.zeros(self.hidden_size))
34
+ self.bias_r = Parameter(torch.zeros(self.hidden_size))
35
+ self.bias_e = Parameter(torch.zeros(self.hidden_size))
36
+
37
+ # display num params
38
+ self.num_params()
39
+
40
+ def forward(self, prev_y, prev_hidden, current_coarse):
41
+
42
+ # Main matmul - the projection is split 3 ways
43
+ R_hidden = self.R(prev_hidden)
44
+ (
45
+ R_u,
46
+ R_r,
47
+ R_e,
48
+ ) = torch.split(R_hidden, self.hidden_size, dim=1)
49
+
50
+ # Project the prev input
51
+ coarse_input_proj = self.I_coarse(prev_y)
52
+ I_coarse_u, I_coarse_r, I_coarse_e = torch.split(
53
+ coarse_input_proj, self.split_size, dim=1
54
+ )
55
+
56
+ # Project the prev input and current coarse sample
57
+ fine_input = torch.cat([prev_y, current_coarse], dim=1)
58
+ fine_input_proj = self.I_fine(fine_input)
59
+ I_fine_u, I_fine_r, I_fine_e = torch.split(
60
+ fine_input_proj, self.split_size, dim=1
61
+ )
62
+
63
+ # concatenate for the gates
64
+ I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
65
+ I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
66
+ I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)
67
+
68
+ # Compute all gates for coarse and fine
69
+ u = torch.sigmoid(R_u + I_u + self.bias_u)
70
+ r = torch.sigmoid(R_r + I_r + self.bias_r)
71
+ e = torch.tanh(r * R_e + I_e + self.bias_e)
72
+ hidden = u * prev_hidden + (1.0 - u) * e
73
+
74
+ # Split the hidden state
75
+ hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
76
+
77
+ # Compute outputs
78
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
79
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
80
+
81
+ return out_coarse, out_fine, hidden
82
+
83
+ def generate(self, seq_len):
84
+ with torch.no_grad():
85
+ # First split up the biases for the gates
86
+ b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
87
+ b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
88
+ b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)
89
+
90
+ # Lists for the two output seqs
91
+ c_outputs, f_outputs = [], []
92
+
93
+ # Some initial inputs
94
+ out_coarse = torch.LongTensor([0]).cuda()
95
+ out_fine = torch.LongTensor([0]).cuda()
96
+
97
+ # We'll need a hidden state
98
+ hidden = self.init_hidden()
99
+
100
+ # Need a clock for display
101
+ start = time.time()
102
+
103
+ # Loop for generation
104
+ for i in range(seq_len):
105
+
106
+ # Split into two hidden states
107
+ hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
108
+
109
+ # Scale and concat previous predictions
110
+ out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.0
111
+ out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.0
112
+ prev_outputs = torch.cat([out_coarse, out_fine], dim=1)
113
+
114
+ # Project input
115
+ coarse_input_proj = self.I_coarse(prev_outputs)
116
+ I_coarse_u, I_coarse_r, I_coarse_e = torch.split(
117
+ coarse_input_proj, self.split_size, dim=1
118
+ )
119
+
120
+ # Project hidden state and split 6 ways
121
+ R_hidden = self.R(hidden)
122
+ (
123
+ R_coarse_u,
124
+ R_fine_u,
125
+ R_coarse_r,
126
+ R_fine_r,
127
+ R_coarse_e,
128
+ R_fine_e,
129
+ ) = torch.split(R_hidden, self.split_size, dim=1)
130
+
131
+ # Compute the coarse gates
132
+ u = torch.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
133
+ r = torch.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
134
+ e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
135
+ hidden_coarse = u * hidden_coarse + (1.0 - u) * e
136
+
137
+ # Compute the coarse output
138
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
139
+ posterior = F.softmax(out_coarse, dim=1)
140
+ distrib = torch.distributions.Categorical(posterior)
141
+ out_coarse = distrib.sample()
142
+ c_outputs.append(out_coarse)
143
+
144
+ # Project the [prev outputs and predicted coarse sample]
145
+ coarse_pred = out_coarse.float() / 127.5 - 1.0
146
+ fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
147
+ fine_input_proj = self.I_fine(fine_input)
148
+ I_fine_u, I_fine_r, I_fine_e = torch.split(
149
+ fine_input_proj, self.split_size, dim=1
150
+ )
151
+
152
+ # Compute the fine gates
153
+ u = torch.sigmoid(R_fine_u + I_fine_u + b_fine_u)
154
+ r = torch.sigmoid(R_fine_r + I_fine_r + b_fine_r)
155
+ e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
156
+ hidden_fine = u * hidden_fine + (1.0 - u) * e
157
+
158
+ # Compute the fine output
159
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
160
+ posterior = F.softmax(out_fine, dim=1)
161
+ distrib = torch.distributions.Categorical(posterior)
162
+ out_fine = distrib.sample()
163
+ f_outputs.append(out_fine)
164
+
165
+ # Put the hidden state back together
166
+ hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)
167
+
168
+ coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
169
+ fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
170
+ output = combine_signal(coarse, fine)
171
+
172
+ return output, coarse, fine
173
+
174
+ def init_hidden(self, batch_size=1):
175
+ return torch.zeros(batch_size, self.hidden_size).cuda()
176
+
177
+ def num_params(self):
178
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
179
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
180
+ logger.debug("Trainable Parameters: %.3f million" % parameters)
mockingbirdforuse/vocoder/wavernn/models/fatchord_version.py ADDED
@@ -0,0 +1,445 @@
1
+ import time
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.parameter import Parameter
7
+
8
+ from ..audio import de_emphasis, decode_mu_law
9
+ from ..hparams import hparams as hp
10
+ from ...distribution import sample_from_discretized_mix_logistic
11
+ from ....log import logger
12
+
13
+
14
+ class ResBlock(nn.Module):
15
+ def __init__(self, dims):
16
+ super().__init__()
17
+ self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
18
+ self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
19
+ self.batch_norm1 = nn.BatchNorm1d(dims)
20
+ self.batch_norm2 = nn.BatchNorm1d(dims)
21
+
22
+ def forward(self, x):
23
+ residual = x
24
+ x = self.conv1(x)
25
+ x = self.batch_norm1(x)
26
+ x = F.relu(x)
27
+ x = self.conv2(x)
28
+ x = self.batch_norm2(x)
29
+ return x + residual
30
+
31
+
32
+ class MelResNet(nn.Module):
33
+ def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
34
+ super().__init__()
35
+ k_size = pad * 2 + 1
36
+ self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
37
+ self.batch_norm = nn.BatchNorm1d(compute_dims)
38
+ self.layers = nn.ModuleList()
39
+ for i in range(res_blocks):
40
+ self.layers.append(ResBlock(compute_dims))
41
+ self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
42
+
43
+ def forward(self, x):
44
+ x = self.conv_in(x)
45
+ x = self.batch_norm(x)
46
+ x = F.relu(x)
47
+ for f in self.layers:
48
+ x = f(x)
49
+ x = self.conv_out(x)
50
+ return x
51
+
52
+
53
+ class Stretch2d(nn.Module):
54
+ def __init__(self, x_scale, y_scale):
55
+ super().__init__()
56
+ self.x_scale = x_scale
57
+ self.y_scale = y_scale
58
+
59
+ def forward(self, x):
60
+ b, c, h, w = x.size()
61
+ x = x.unsqueeze(-1).unsqueeze(3)
62
+ x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
63
+ return x.view(b, c, h * self.y_scale, w * self.x_scale)
64
+
65
+
66
+ class UpsampleNetwork(nn.Module):
67
+ def __init__(
68
+ self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad
69
+ ):
70
+ super().__init__()
71
+ total_scale = np.cumprod(upsample_scales)[-1]
72
+ self.indent = pad * total_scale
73
+ self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
74
+ self.resnet_stretch = Stretch2d(total_scale, 1)
75
+ self.up_layers = nn.ModuleList()
76
+ for scale in upsample_scales:
77
+ k_size = (1, scale * 2 + 1)
78
+ padding = (0, scale)
79
+ stretch = Stretch2d(scale, 1)
80
+ conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
81
+ conv.weight.data.fill_(1.0 / k_size[1])
82
+ self.up_layers.append(stretch)
83
+ self.up_layers.append(conv)
84
+
85
+ def forward(self, m):
86
+ aux = self.resnet(m).unsqueeze(1)
87
+ aux = self.resnet_stretch(aux)
88
+ aux = aux.squeeze(1)
89
+ m = m.unsqueeze(1)
90
+ for f in self.up_layers:
91
+ m = f(m)
92
+ m = m.squeeze(1)[:, :, self.indent : -self.indent]
93
+ return m.transpose(1, 2), aux.transpose(1, 2)
94
+
95
+
96
+ class WaveRNN(nn.Module):
97
+ def __init__(
98
+ self,
99
+ rnn_dims,
100
+ fc_dims,
101
+ bits,
102
+ pad,
103
+ upsample_factors,
104
+ feat_dims,
105
+ compute_dims,
106
+ res_out_dims,
107
+ res_blocks,
108
+ hop_length,
109
+ sample_rate,
110
+ mode="RAW",
111
+ ):
112
+ super().__init__()
113
+ self.mode = mode
114
+ self.pad = pad
115
+ if self.mode == "RAW":
116
+ self.n_classes = 2**bits
117
+ elif self.mode == "MOL":
118
+ self.n_classes = 30
119
+ else:
120
+ raise RuntimeError("Unknown model mode value - ", self.mode)
121
+
122
+ self.rnn_dims = rnn_dims
123
+ self.aux_dims = res_out_dims // 4
124
+ self.hop_length = hop_length
125
+ self.sample_rate = sample_rate
126
+
127
+ self.upsample = UpsampleNetwork(
128
+ feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad
129
+ )
130
+ self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
131
+ self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
132
+ self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
133
+ self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
134
+ self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
135
+ self.fc3 = nn.Linear(fc_dims, self.n_classes)
136
+
137
+ self.step = Parameter(torch.zeros(1).long(), requires_grad=False)
138
+ self.num_params()
139
+
140
+ def forward(self, x, mels):
141
+ self.step += 1
142
+ bsize = x.size(0)
143
+ if torch.cuda.is_available():
144
+ h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
145
+ h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
146
+ else:
147
+ h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
148
+ h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
149
+ mels, aux = self.upsample(mels)
150
+
151
+ aux_idx = [self.aux_dims * i for i in range(5)]
152
+ a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
153
+ a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
154
+ a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
155
+ a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
156
+
157
+ x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
158
+ x = self.I(x)
159
+ res = x
160
+ x, _ = self.rnn1(x, h1)
161
+
162
+ x = x + res
163
+ res = x
164
+ x = torch.cat([x, a2], dim=2)
165
+ x, _ = self.rnn2(x, h2)
166
+
167
+ x = x + res
168
+ x = torch.cat([x, a3], dim=2)
169
+ x = F.relu(self.fc1(x))
170
+
171
+ x = torch.cat([x, a4], dim=2)
172
+ x = F.relu(self.fc2(x))
173
+ return self.fc3(x)
174
+
175
+ def generate(self, mels, batched, target, overlap, mu_law):
176
+ mu_law = mu_law if self.mode == "RAW" else False
177
+
178
+ self.eval()
179
+ output = []
180
+ start = time.time()
181
+ rnn1 = self.get_gru_cell(self.rnn1)
182
+ rnn2 = self.get_gru_cell(self.rnn2)
183
+
184
+ with torch.no_grad():
185
+ if torch.cuda.is_available():
186
+ mels = mels.cuda()
187
+ else:
188
+ mels = mels.cpu()
189
+ wave_len = (mels.size(-1) - 1) * self.hop_length
190
+ mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
191
+ mels, aux = self.upsample(mels.transpose(1, 2))
192
+
193
+ if batched:
194
+ mels = self.fold_with_overlap(mels, target, overlap)
195
+ aux = self.fold_with_overlap(aux, target, overlap)
196
+
197
+ b_size, seq_len, _ = mels.size()
198
+
199
+ if torch.cuda.is_available():
200
+ h1 = torch.zeros(b_size, self.rnn_dims).cuda()
201
+ h2 = torch.zeros(b_size, self.rnn_dims).cuda()
202
+ x = torch.zeros(b_size, 1).cuda()
203
+ else:
204
+ h1 = torch.zeros(b_size, self.rnn_dims).cpu()
205
+ h2 = torch.zeros(b_size, self.rnn_dims).cpu()
206
+ x = torch.zeros(b_size, 1).cpu()
207
+
208
+ d = self.aux_dims
209
+ aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
210
+
211
+ for i in range(seq_len):
212
+
213
+ m_t = mels[:, i, :]
214
+
215
+ a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
216
+
217
+ x = torch.cat([x, m_t, a1_t], dim=1)
218
+ x = self.I(x)
219
+ h1 = rnn1(x, h1)
220
+
221
+ x = x + h1
222
+ inp = torch.cat([x, a2_t], dim=1)
223
+ h2 = rnn2(inp, h2)
224
+
225
+ x = x + h2
226
+ x = torch.cat([x, a3_t], dim=1)
227
+ x = F.relu(self.fc1(x))
228
+
229
+ x = torch.cat([x, a4_t], dim=1)
230
+ x = F.relu(self.fc2(x))
231
+
232
+ logits = self.fc3(x)
233
+
234
+ if self.mode == "MOL":
235
+ sample = sample_from_discretized_mix_logistic(
236
+ logits.unsqueeze(0).transpose(1, 2)
237
+ )
238
+ output.append(sample.view(-1))
239
+ if torch.cuda.is_available():
240
+ # x = torch.FloatTensor([[sample]]).cuda()
241
+ x = sample.transpose(0, 1).cuda()
242
+ else:
243
+ x = sample.transpose(0, 1)
244
+
245
+ elif self.mode == "RAW":
246
+ posterior = F.softmax(logits, dim=1)
247
+ distrib = torch.distributions.Categorical(posterior)
248
+
249
+ sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
250
+ output.append(sample)
251
+ x = sample.unsqueeze(-1)
252
+ else:
253
+ raise RuntimeError("Unknown model mode value - ", self.mode)
254
+
255
+ output = torch.stack(output).transpose(0, 1)
256
+ output = output.cpu().numpy()
257
+ output = output.astype(np.float64)
258
+
259
+ if batched:
260
+ output = self.xfade_and_unfold(output, target, overlap)
261
+ else:
262
+ output = output[0]
263
+
264
+ if mu_law:
265
+ output = decode_mu_law(output, self.n_classes, False)
266
+ if hp.apply_preemphasis:
267
+ output = de_emphasis(output)
268
+
269
+ # Fade-out at the end to avoid signal cutting out suddenly
270
+ fade_out = np.linspace(1, 0, 20 * self.hop_length)
271
+ output = output[:wave_len]
272
+ output[-20 * self.hop_length :] *= fade_out
273
+
274
+ self.train()
275
+
276
+ return output
277
+
278
+ def get_gru_cell(self, gru):
279
+ gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
280
+ gru_cell.weight_hh.data = gru.weight_hh_l0.data
281
+ gru_cell.weight_ih.data = gru.weight_ih_l0.data
282
+ gru_cell.bias_hh.data = gru.bias_hh_l0.data
283
+ gru_cell.bias_ih.data = gru.bias_ih_l0.data
284
+ return gru_cell
285
+
286
+ def pad_tensor(self, x, pad, side="both"):
287
+ # NB - this is just a quick method I need right now
288
+ # i.e., it won't generalise to other shapes/dims
289
+ b, t, c = x.size()
290
+ total = t + 2 * pad if side == "both" else t + pad
291
+ if torch.cuda.is_available():
292
+ padded = torch.zeros(b, total, c).cuda()
293
+ else:
294
+ padded = torch.zeros(b, total, c).cpu()
295
+ if side == "before" or side == "both":
296
+ padded[:, pad : pad + t, :] = x
297
+ elif side == "after":
298
+ padded[:, :t, :] = x
299
+ return padded
300
+
301
+ def fold_with_overlap(self, x, target, overlap):
302
+
303
+ """Fold the tensor with overlap for quick batched inference.
304
+ Overlap will be used for crossfading in xfade_and_unfold()
305
+
306
+ Args:
307
+ x (tensor) : Upsampled conditioning features.
308
+ shape=(1, timesteps, features)
309
+ target (int) : Target timesteps for each index of batch
310
+ overlap (int) : Timesteps for both xfade and rnn warmup
311
+
312
+ Return:
313
+ (tensor) : shape=(num_folds, target + 2 * overlap, features)
314
+
315
+ Details:
316
+ x = [[h1, h2, ... hn]]
317
+
318
+ Where each h is a vector of conditioning features
319
+
320
+ Eg: target=2, overlap=1 with x.size(1)=10
321
+
322
+ folded = [[h1, h2, h3, h4],
323
+ [h4, h5, h6, h7],
324
+ [h7, h8, h9, h10]]
325
+ """
326
+
327
+ _, total_len, features = x.size()
328
+
329
+ # Calculate variables needed
330
+ num_folds = (total_len - overlap) // (target + overlap)
331
+ extended_len = num_folds * (overlap + target) + overlap
332
+ remaining = total_len - extended_len
333
+
334
+ # Pad if some time steps poking out
335
+ if remaining != 0:
336
+ num_folds += 1
337
+ padding = target + 2 * overlap - remaining
338
+ x = self.pad_tensor(x, padding, side="after")
339
+
340
+ if torch.cuda.is_available():
341
+ folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
342
+ else:
343
+ folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu()
344
+
345
+ # Get the values for the folded tensor
346
+ for i in range(num_folds):
347
+ start = i * (target + overlap)
348
+ end = start + target + 2 * overlap
349
+ folded[i] = x[:, start:end, :]
350
+
351
+ return folded
352
+
353
+ def xfade_and_unfold(self, y, target, overlap):
354
+
355
+ """Applies a crossfade and unfolds into a 1d array.
356
+
357
+ Args:
358
+ y (ndarray) : Batched sequences of audio samples
359
+ shape=(num_folds, target + 2 * overlap)
360
+ dtype=np.float64
361
+ overlap (int) : Timesteps for both xfade and rnn warmup
362
+
363
+ Return:
364
+ (ndarray) : audio samples in a 1d array
365
+ shape=(total_len)
366
+ dtype=np.float64
367
+
368
+ Details:
369
+ y = [[seq1],
370
+ [seq2],
371
+ [seq3]]
372
+
373
+ Apply a gain envelope at both ends of the sequences
374
+
375
+ y = [[seq1_in, seq1_target, seq1_out],
376
+ [seq2_in, seq2_target, seq2_out],
377
+ [seq3_in, seq3_target, seq3_out]]
378
+
379
+ Stagger and add up the groups of samples:
380
+
381
+ [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
382
+
383
+ """
384
+
385
+ num_folds, length = y.shape
386
+ target = length - 2 * overlap
387
+ total_len = num_folds * (target + overlap) + overlap
388
+
389
+ # Need some silence for the rnn warmup
390
+ silence_len = overlap // 2
391
+ fade_len = overlap - silence_len
392
+ silence = np.zeros((silence_len), dtype=np.float64)
393
+
394
+ # Equal power crossfade
395
+ t = np.linspace(-1, 1, fade_len, dtype=np.float64)
396
+ fade_in = np.sqrt(0.5 * (1 + t))
397
+ fade_out = np.sqrt(0.5 * (1 - t))
398
+
399
+ # Concat the silence to the fades
400
+ fade_in = np.concatenate([silence, fade_in])
401
+ fade_out = np.concatenate([fade_out, silence])
402
+
403
+ # Apply the gain to the overlap samples
404
+ y[:, :overlap] *= fade_in
405
+ y[:, -overlap:] *= fade_out
406
+
407
+ unfolded = np.zeros((total_len), dtype=np.float64)
408
+
409
+ # Loop to add up all the samples
410
+ for i in range(num_folds):
411
+ start = i * (target + overlap)
412
+ end = start + target + 2 * overlap
413
+ unfolded[start:end] += y[i]
414
+
415
+ return unfolded
416
+
417
+ def get_step(self):
418
+ return self.step.data.item()
419
+
420
+ def checkpoint(self, model_dir, optimizer):
421
+ k_steps = self.get_step() // 1000
422
+ self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer)
423
+
424
+ def load(self, path, optimizer):
425
+ checkpoint = torch.load(path)
426
+ if "optimizer_state" in checkpoint:
427
+ self.load_state_dict(checkpoint["model_state"])
428
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
429
+ else:
430
+ # Backwards compatibility
431
+ self.load_state_dict(checkpoint)
432
+
433
+ def save(self, path, optimizer):
434
+ torch.save(
435
+ {
436
+ "model_state": self.state_dict(),
437
+ "optimizer_state": optimizer.state_dict(),
438
+ },
439
+ path,
440
+ )
441
+
442
+ def num_params(self):
443
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
444
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
445
+ logger.debug("Trainable Parameters: %.3fM" % parameters)
packages.txt ADDED
@@ -0,0 +1,3 @@
1
+ ffmpeg
2
+ libsm6
3
+ libxext6
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ torch
2
+ torchvision
3
+ numpy
4
+ numba
5
+ opencv-python-headless
6
+ scipy
7
+ pypinyin
8
+ librosa
9
+ webrtcvad
10
+ Unidecode
11
+ inflect
12
+ loguru
13
+ gradio