Upload 204 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .dockerignore +4 -0
- .gitattributes +4 -35
- .github/FUNDING.yml +1 -0
- .github/ISSUE_TEMPLATE/issue.md +17 -0
- .gitignore +26 -0
- .vscode/launch.json +77 -0
- .vscode/settings.json +3 -0
- CODE_OF_CONDUCT.md +130 -0
- Dockerfile +17 -0
- LICENSE.txt +24 -0
- README-CN.md +283 -0
- README-LINUX-CN.md +223 -0
- archived_untest_files/demo_cli.py +225 -0
- control/__init__.py +0 -0
- control/cli/__init__.py +0 -0
- control/cli/encoder_preprocess.py +64 -0
- control/cli/encoder_train.py +47 -0
- control/cli/ppg2mel_train.py +67 -0
- control/cli/pre4ppg.py +49 -0
- control/cli/synthesizer_train.py +40 -0
- control/cli/train_ppg2mel.py +66 -0
- control/cli/vocoder_preprocess.py +59 -0
- control/cli/vocoder_train.py +92 -0
- control/mkgui/__init__.py +0 -0
- control/mkgui/app.py +151 -0
- control/mkgui/app_vc.py +166 -0
- control/mkgui/base/__init__.py +2 -0
- control/mkgui/base/api/__init__.py +1 -0
- control/mkgui/base/api/fastapi_utils.py +102 -0
- control/mkgui/base/components/__init__.py +0 -0
- control/mkgui/base/components/outputs.py +43 -0
- control/mkgui/base/components/types.py +46 -0
- control/mkgui/base/core.py +203 -0
- control/mkgui/base/ui/__init__.py +1 -0
- control/mkgui/base/ui/schema_utils.py +135 -0
- control/mkgui/base/ui/streamlit_ui.py +933 -0
- control/mkgui/base/ui/streamlit_utils.py +13 -0
- control/mkgui/preprocess.py +96 -0
- control/mkgui/static/mb.png +0 -0
- control/mkgui/train.py +106 -0
- control/mkgui/train_vc.py +155 -0
- control/toolbox/__init__.py +476 -0
- control/toolbox/assets/mb.png +0 -0
- control/toolbox/ui.py +700 -0
- control/toolbox/utterance.py +5 -0
- data/ckpt/encoder/pretrained.pt +3 -0
- data/ckpt/vocoder/pretrained/config_16k.json +31 -0
- data/ckpt/vocoder/pretrained/g_hifigan.pt +3 -0
- data/ckpt/vocoder/pretrained/pretrained.pt +3 -0
- data/samples/T0055G0013S0005.wav +0 -0
.dockerignore
ADDED
@@ -0,0 +1,4 @@
+*/saved_models
+!vocoder/saved_models/pretrained/**
+!encoder/saved_models/pretrained.pt
+/datasets
.gitattributes
CHANGED
@@ -1,35 +1,4 @@
-*.
-
-
-
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ipynb linguist-vendored
+data/ckpt/encoder/pretrained.pt filter=lfs diff=lfs merge=lfs -text
+data/ckpt/vocoder/pretrained/g_hifigan.pt filter=lfs diff=lfs merge=lfs -text
+data/ckpt/vocoder/pretrained/pretrained.pt filter=lfs diff=lfs merge=lfs -text
.github/FUNDING.yml
ADDED
@@ -0,0 +1 @@
+github: babysor
.github/ISSUE_TEMPLATE/issue.md
ADDED
@@ -0,0 +1,17 @@
+---
+name: Issue
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Summary [one-sentence problem description]**
+A clear and concise description of what the issue is.
+
+**Env & To Reproduce [environment and reproduction]**
+Describe the environment, code version, and model you are using
+
+**Screenshots [if any]**
+If applicable, add screenshots to help
.gitignore
ADDED
@@ -0,0 +1,26 @@
+*.pyc
+*.aux
+*.log
+*.out
+*.synctex.gz
+*.suo
+*__pycache__
+*.idea
+*.ipynb_checkpoints
+*.pickle
+*.npy
+*.blg
+*.bbl
+*.bcf
+*.toc
+*.sh
+data/ckpt/*/*
+!data/ckpt/encoder/pretrained.pt
+!data/ckpt/vocoder/pretrained/
+wavs
+log
+!/docker-entrypoint.sh
+!/datasets_download/*.sh
+/datasets
+monotonic_align/build
+monotonic_align/monotonic_align
.vscode/launch.json
ADDED
@@ -0,0 +1,77 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Web",
+            "type": "python",
+            "request": "launch",
+            "program": "web.py",
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "Python: Vocoder Preprocess",
+            "type": "python",
+            "request": "launch",
+            "program": "control\\cli\\vocoder_preprocess.py",
+            "cwd": "${workspaceFolder}",
+            "console": "integratedTerminal",
+            "args": ["..\\audiodata"]
+        },
+        {
+            "name": "Python: Vocoder Train",
+            "type": "python",
+            "request": "launch",
+            "program": "control\\cli\\vocoder_train.py",
+            "cwd": "${workspaceFolder}",
+            "console": "integratedTerminal",
+            "args": ["dev", "..\\audiodata"]
+        },
+        {
+            "name": "Python: Demo Box",
+            "type": "python",
+            "request": "launch",
+            "program": "demo_toolbox.py",
+            "cwd": "${workspaceFolder}",
+            "console": "integratedTerminal",
+            "args": ["-d","..\\audiodata"]
+        },
+        {
+            "name": "Python: Demo Box VC",
+            "type": "python",
+            "request": "launch",
+            "program": "demo_toolbox.py",
+            "cwd": "${workspaceFolder}",
+            "console": "integratedTerminal",
+            "args": ["-d","..\\audiodata","-vc"]
+        },
+        {
+            "name": "Python: Synth Train",
+            "type": "python",
+            "request": "launch",
+            "program": "train.py",
+            "console": "integratedTerminal",
+            "args": ["--type", "vits"]
+        },
+        {
+            "name": "Python: PPG Convert",
+            "type": "python",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
+                "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
+            ]
+        },
+        {
+            "name": "Python: Vits Train",
+            "type": "python",
+            "request": "launch",
+            "program": "train.py",
+            "console": "integratedTerminal",
+            "args": ["--type", "vits"]
+        },
+    ]
+}
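The launch configurations above are thin wrappers around the project's command-line entry points. For reference, a few shell equivalents are sketched below; the `..\audiodata` directory is just the placeholder value used in the config, not a path this repository creates.

```shell
# Equivalent of "Python: Vocoder Preprocess" (Windows-style paths, as in the config)
python control\cli\vocoder_preprocess.py ..\audiodata

# Equivalent of "Python: Demo Box" and "Python: Demo Box VC"
python demo_toolbox.py -d ..\audiodata
python demo_toolbox.py -d ..\audiodata -vc

# Equivalent of "Python: Synth Train" / "Python: Vits Train"
python train.py --type vits
```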
.vscode/settings.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "python.formatting.provider": "black"
+}
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,130 @@
+# Contributor Covenant Code of Conduct
+## First of all
+Don't be evil, never
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+  and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+  overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+  advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+  address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+babysor00@gmail.com.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
Dockerfile
ADDED
@@ -0,0 +1,17 @@
+FROM pytorch/pytorch:latest
+
+RUN apt-get update && apt-get install -y build-essential ffmpeg parallel aria2 && apt-get clean
+
+COPY ./requirements.txt /workspace/requirements.txt
+
+RUN pip install -r requirements.txt && pip install webrtcvad-wheels
+
+COPY . /workspace
+
+VOLUME [ "/datasets", "/workspace/synthesizer/saved_models/" ]
+
+ENV DATASET_MIRROR=default FORCE_RETRAIN=false TRAIN_DATASETS=aidatatang_200zh\ magicdata\ aishell3\ data_aishell TRAIN_SKIP_EXISTING=true
+
+EXPOSE 8080
+
+ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ]
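A minimal sketch of building and running this image follows; the `mockingbird` image tag and the host-side `./datasets` and `./saved_models` directories are illustrative names, not anything defined in this diff.

```shell
# Build from the repository root (the Dockerfile above copies . into /workspace).
docker build -t mockingbird .

# Run it, mounting the two volumes the Dockerfile declares and publishing the exposed web port.
docker run -it \
    -v "$PWD/datasets:/datasets" \
    -v "$PWD/saved_models:/workspace/synthesizer/saved_models" \
    -p 8080:8080 \
    mockingbird
```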
LICENSE.txt
ADDED
@@ -0,0 +1,24 @@
+MIT License
+
+Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+Original work Copyright (c) 2015 braindead (https://github.com/braindead)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README-CN.md
ADDED
@@ -0,0 +1,283 @@
+## Real-Time Voice Cloning - Chinese/Mandarin
+![mockingbird](https://user-images.githubusercontent.com/12797292/131216767-6eb251d6-14fc-4951-8324-2722f0cd4c63.jpg)
+
+[![MIT License](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](http://choosealicense.com/licenses/mit/)
+
+### [English](README.md) | Chinese
+
+### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki tutorial](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [Training tutorial](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)
+
+## Features
+๐ **Chinese** Supports Mandarin, tested with multiple Chinese datasets: aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell, etc.
+
+๐คฉ **PyTorch** Built on PyTorch, tested on version 1.9.0 (latest as of August 2021) with GPU Tesla T4 and GTX 2060
+
+๐ **Windows + Linux** Runs on both Windows and Linux (the community has also run it successfully on Apple Silicon M1)
+
+๐คฉ **Easy & Awesome** Good results with only a downloaded or newly trained synthesizer; the pretrained encoder/vocoder are reused, and a real-time HiFi-GAN is available as the vocoder
+
+๐ **Webserver Ready** Serve your training results for remote calls
+
+
+## Getting Started
+### 1. Installation Requirements
+#### 1.1 General setup
+> Follow the original repository to check whether your environment is ready.
+Running the toolbox (demo_toolbox.py) requires **Python 3.7 or higher**.
+
+* Install [PyTorch](https://pytorch.org/get-started/locally/).
+> If pip installation fails with `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)`, your Python version may be too old; 3.9 installs successfully
+* Install [ffmpeg](https://ffmpeg.org/download.html#get-packages).
+* Run `pip install -r requirements.txt` to install the remaining required packages.
+* Install webrtcvad: `pip install webrtcvad-wheels`.
+
+or
+- Install the dependencies with `conda` or `mamba`
+
+```conda env create -n env_name -f env.yml```
+
+```mamba env create -n env_name -f env.yml```
+
+This creates a new environment and installs the required dependencies. Afterwards, switch to it with `conda activate env_name` and you are done.
+> env.yml only contains the runtime dependencies and does not include monotonic-align for now; see the official docs if you want the GPU build of PyTorch.
+
+#### 1.2 Setup on Apple Silicon (M1) Macs (Inference Time)
+> The following builds an x86-64 environment and uses the original `demo_toolbox.py`; it is a workaround for quick use without changing the code.
+>
+> To train on the M1 chip, either modify the code as needed (since `PyQt5`, which `demo_toolbox.py` depends on, does not support M1) or try `web.py`.
+
+* Install `PyQt5`, see [this link](https://stackoverflow.com/a/68038451/20455983)
+    * Open Terminal under Rosetta, see [this link](https://dev.to/courier/tips-and-tricks-to-setup-your-apple-m1-for-development-547g)
+    * Create the project virtual environment with the system Python
+    ```
+    /usr/bin/python3 -m venv /PathToMockingBird/venv
+    source /PathToMockingBird/venv/bin/activate
+    ```
+* Upgrade pip and install `PyQt5`
+    ```
+    pip install --upgrade pip
+    pip install pyqt5
+    ```
+* Install `pyworld` and `ctc-segmentation`
+> These two packages have no wheels for a direct `pip install`, and building from C fails because `Python.h` cannot be found
+    * Install `pyworld`
+        * `brew install python` installing Python via brew also installs `Python.h`
+        * `export CPLUS_INCLUDE_PATH=/opt/homebrew/Frameworks/Python.framework/Headers` on M1, brew puts `Python.h` in this path; add it to the environment variable
+        * `pip install pyworld`
+
+    * Install `ctc-segmentation`
+> Since the above approach did not work for it, clone the source from [github](https://github.com/lumaku/ctc-segmentation) and build it manually
+        * `git clone https://github.com/lumaku/ctc-segmentation.git` clone to any location
+        * `cd ctc-segmentation`
+        * `source /PathToMockingBird/venv/bin/activate` activate the MockingBird virtual environment if it is not already active
+        * `cythonize -3 ctc_segmentation/ctc_segmentation_dyn.pyx`
+        * `/usr/bin/arch -x86_64 python setup.py build` be sure to build explicitly for the x86-64 architecture
+        * `/usr/bin/arch -x86_64 python setup.py install --optimize=1 --skip-build` install for x86-64
+
+* Install the remaining dependencies
+    * `/usr/bin/arch -x86_64 pip install torch torchvision torchaudio` install `PyTorch` with pip, explicitly for x86
+    * `pip install ffmpeg` install ffmpeg
+    * `pip install -r requirements.txt`
+
+* Run
+> Referring to [this link](https://youtrack.jetbrains.com/issue/PY-46290/Allow-running-Python-under-Rosetta-2-in-PyCharm-for-Apple-Silicon),
+make the project run under the x86 architecture
+    * `vim /PathToMockingBird/venv/bin/pythonM1`
+    * write the following code:
+    ```
+    #!/usr/bin/env zsh
+    mydir=${0:a:h}
+    /usr/bin/arch -x86_64 $mydir/python "$@"
+    ```
+    * `chmod +x pythonM1` make it executable
+    * If you use PyCharm, point the interpreter to `pythonM1`; otherwise run from the command line: `/PathToMockingBird/venv/bin/pythonM1 demo_toolbox.py`
+
+### 2. Prepare the pretrained models
+Consider training your own models or downloading models trained by the community:
+> A [Zhihu column](https://www.zhihu.com/column/c_1425605280340504576) has been created and is updated from time to time with tips and tricks; questions are welcome there too
+#### 2.1 Train the encoder with your datasets (optional)
+
+* Preprocess the audio and the mel spectrograms:
+`python encoder_preprocess.py <datasets_root>`
+Use `-d {dataset}` to specify the datasets; librispeech_other, voxceleb1, and aidatatang_200zh are supported; separate multiple datasets with commas.
+* Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
+> Training the encoder uses visdom. You can disable it with `-no_visdom`, but the visualization is better with it. Start the visdom server by running "visdom" in a separate command line/process.
+
+#### 2.2 Train the synthesizer with your datasets (pick either this or 2.3)
+* Download the dataset and unzip it; make sure you can access all the audio files (such as .wav) in the *train* folder
+* Preprocess the audio and the mel spectrograms:
+`python pre.py <datasets_root> -d {dataset} -n {number}`
+Optional arguments:
+    * `-d {dataset}` specifies the dataset; aidatatang_200zh, magicdata, aishell3, and data_aishell are supported; defaults to aidatatang_200zh when omitted
+    * `-n {number}` specifies the number of parallel workers; 10 is fine on a CPU 11770k with 32GB
+> If you downloaded `aidatatang_200zh` to drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`
+
+* Train the synthesizer:
+`python ./control/cli/synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
+
+* When the attention line appears and the loss meets your needs in the training folder *synthesizer/saved_models/*, go to the `Launch` step.
+
+#### 2.3 Use a synthesizer pretrained by the community (pick either this or 2.2)
+> When you really have no hardware or don't want to tune slowly, you can use models contributed by the community (further sharing is welcome):
+
+| Author | Download link | Preview | Info |
+| --- | ----------- | ----- | ----- |
+| Author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu link](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps, trained on a mix of 3 open-source datasets
+| Author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu link](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code: om7f | | 25k steps, trained on a mix of 3 open-source datasets; switch to tag v0.0.1 to use
+|@FawenYo | https://yisiou-my.sharepoint.com/:u:/g/personal/lawrence_cheng_fawenyo_onmicrosoft_com/EWFWDHzee-NNg9TWdKckCc4BC7bK2j9cCbOWn0-_tK0nOg?e=n0gGgC | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, Taiwanese accent; switch to tag v0.0.1 to use
+|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps; note: fix per [issue](https://github.com/babysor/MockingBird/issues/37) and switch to tag v0.0.1 to use
+
+#### 2.4 Train the vocoder (optional)
+The vocoder has little effect on the output quality; 3 are already bundled. If you want to train your own, refer to the commands below.
+* Preprocess the data:
+`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
+> Replace `<datasets_root>` with your dataset directory and `<synthesizer_model_path>` with the directory of your best synthesizer model, e.g. *sythensizer\saved_models\xxx*
+
+
+* Train the wavernn vocoder:
+`python ./control/cli/vocoder_train.py <trainid> <datasets_root>`
+> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
+
+* Train the hifigan vocoder:
+`python ./control/cli/vocoder_train.py <trainid> <datasets_root> hifigan`
+> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
+* Train the fregan vocoder:
+`python ./control/cli/vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
+> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
+* To switch GAN vocoder training to multi-GPU mode, change the "num_gpus" parameter in the .json file under the GAN folder
+### 3. Launch the program or toolbox
+You can try the following commands:
+
+### 3.1 Launch the web program (v2):
+`python web.py`
+Once it is running, open the address in a browser, `http://localhost:8080` by default
+> * Only newly recorded audio (16 kHz) is supported; recordings over 4 MB are not; the best length is 5-15 seconds
+
+### 3.2 Launch the toolbox:
+`python demo_toolbox.py -d <datasets_root>`
+> Please specify a path to an available dataset. Supported datasets are loaded automatically for debugging, and the path is also used as the directory for manually recorded audio.
+
+<img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">
+
+### 4. Bonus: Voice Conversion (PPG based)
+Want to imitate a voice actor like ๆฏๅๆฟ and then dub lines in Pikachu's voice? Based on PPG-VC, this project now adds two extra models (PPG extractor + PPG2Mel) to provide voice conversion. (The documentation, especially the training part, is incomplete and is being filled in.)
+#### 4.0 Prepare the environment
+* Make sure the environment above is installed, then run `pip install espnet` to install the remaining required packages.
+* Download the following models. Link: https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
+  Code: gh41
+    * A dedicated 24K-sample-rate vocoder (hifigan) to *vocoder\saved_models\xxx*
+    * The pretrained ppg feature encoder (ppg_extractor) to *ppg_extractor\saved_models\xxx*
+    * The pretrained PPG2Mel to *ppg2mel\saved_models\xxx*
+
+#### 4.1 Train the PPG2Mel model with your datasets (optional)
+
+* Download the aidatatang_200zh dataset and unzip it; make sure you can access all the audio files (such as .wav) in the *train* folder
+* Preprocess the audio and the mel spectrograms:
+`python ./control/cli/pre4ppg.py <datasets_root> -d {dataset} -n {number}`
+Optional arguments:
+    * `-d {dataset}` specifies the dataset; only aidatatang_200zh is supported; defaults to aidatatang_200zh when omitted
+    * `-n {number}` specifies the number of parallel workers; a CPU 11700k with 8 workers needs 12 to 18 hours; to be optimized
+> If you downloaded `aidatatang_200zh` to drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`
+
+* Train the synthesizer; note: first download `ppg2mel.yaml` from the previous step and change the paths inside it to point at the pretrained folders:
+`python ./control/cli/ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
+* To continue from a previous run, specify a pretrained model file with the `--load .\ppg2mel\saved_models\<old_pt_file>` argument.
+
+#### 4.2 Launch the toolbox in VC mode
+You can try the following command:
+`python demo_toolbox.py -vc -d <datasets_root>`
+> Please specify a path to an available dataset. Supported datasets are loaded automatically for debugging, and the path is also used as the directory for manually recorded audio.
+<img width="971" alt="ๅพฎไฟกๅพ็_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">
+
+## Reference papers
+> This repository started as a fork of [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning), which only supported English; thanks to the author.
+
+| URL | Designation | Title | Implementation source |
+| --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
+| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
+| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
+|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
+|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
+|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
+|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |
+
+## Frequently Asked Questions (FAQ)
+#### 1. Where can the datasets be downloaded?
+| Dataset | OpenSLR address | Other sources (Google Drive, Baidu Pan, etc.) |
+| --- | ----------- | ---------------|
+| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
+| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
+| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
+| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
+> After unzipping aidatatang_200zh, you also need to unzip all the archives under `aidatatang_200zh\corpus\train`
+
+#### 2. What does `<datasets_root>` mean?
+If the dataset path is `D:\data\aidatatang_200zh`, then `<datasets_root>` is `D:\data`
+
+#### 3. Out of GPU memory during training
+When training the synthesizer, reduce the batch_size parameter in `synthesizer/hparams.py`
+```
+//Before
+tts_schedule = [(2,  1e-3,  20_000,  12),   # Progressive training schedule
+                (2,  5e-4,  40_000,  12),   # (r, lr, step, batch_size)
+                (2,  2e-4,  80_000,  12),   #
+                (2,  1e-4, 160_000,  12),   # r = reduction factor (# of mel frames
+                (2,  3e-5, 320_000,  12),   #     synthesized for each decoder iteration)
+                (2,  1e-5, 640_000,  12)],  # lr = learning rate
+//After
+tts_schedule = [(2,  1e-3,  20_000,  8),    # Progressive training schedule
+                (2,  5e-4,  40_000,  8),    # (r, lr, step, batch_size)
+                (2,  2e-4,  80_000,  8),    #
+                (2,  1e-4, 160_000,  8),    # r = reduction factor (# of mel frames
+                (2,  3e-5, 320_000,  8),    #     synthesized for each decoder iteration)
+                (2,  1e-5, 640_000,  8)],   # lr = learning rate
+```
+
+Vocoder: when preprocessing the dataset, reduce the batch_size parameter in `synthesizer/hparams.py`
+```
+//Before
+### Data Preprocessing
+        max_mel_frames = 900,
+        rescale = True,
+        rescaling_max = 0.9,
+        synthesis_batch_size = 16,                  # For vocoder preprocessing and inference.
+//After
+### Data Preprocessing
+        max_mel_frames = 900,
+        rescale = True,
+        rescaling_max = 0.9,
+        synthesis_batch_size = 8,                   # For vocoder preprocessing and inference.
+```
+
+Vocoder: when training the vocoder, reduce the batch_size parameter in `vocoder/wavernn/hparams.py`
+```
+//Before
+# Training
+voc_batch_size = 100
+voc_lr = 1e-4
+voc_gen_at_checkpoint = 5
+voc_pad = 2
+
+//After
+# Training
+voc_batch_size = 6
+voc_lr = 1e-4
+voc_gen_at_checkpoint = 5
+voc_pad = 2
+```
+
+#### 4. Getting `RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).`
+See issue [#37](https://github.com/babysor/MockingBird/issues/37)
+
+#### 5. How can CPU and GPU utilization be improved?
+Adjust the batch_size parameter as appropriate
+
+#### 6. Getting `The page file is too small to complete the operation`
+Refer to this [article](https://blog.csdn.net/qq_17755303/article/details/112564030) and increase the virtual memory to 100G (102400); for example, if the files are on drive D, change drive D's virtual memory
+
+#### 7. When is training considered finished?
+First, the attention pattern must appear; second, the loss must be low enough, which depends on your hardware and dataset. Taking the author's run as a reference, the attention appeared around step 18k and the loss dropped below 0.4 after step 50k
+![attention_step_20500_sample_1](https://user-images.githubusercontent.com/7423248/128587252-f669f05a-f411-4811-8784-222156ea5e9d.png)
+
+![step-135500-mel-spectrogram_sample_1](https://user-images.githubusercontent.com/7423248/128587255-4945faa0-5517-46ea-b173-928eff999330.png)
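Stringing together the commands already listed in the README above, a minimal end-to-end sketch of the Mandarin workflow looks like this; `<datasets_root>` is the same placeholder the README uses, and the run name `mandarin` is taken from its example.

```shell
# 1. Preprocess audio and mel spectrograms (defaults to aidatatang_200zh)
python pre.py <datasets_root> -d aidatatang_200zh -n 10

# 2. Train the synthesizer until the attention line appears and the loss is low enough
python ./control/cli/synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer

# 3. Launch the web server (http://localhost:8080) or the toolbox
python web.py
python demo_toolbox.py -d <datasets_root>
```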
README-LINUX-CN.md
ADDED
@@ -0,0 +1,223 @@
+## Real-Time Voice Cloning - Chinese/Mandarin
+![mockingbird](https://user-images.githubusercontent.com/12797292/131216767-6eb251d6-14fc-4951-8324-2722f0cd4c63.jpg)
+
+[![MIT License](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](http://choosealicense.com/licenses/mit/)
+
+### [English](README.md) | Chinese
+
+### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki tutorial](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [Training tutorial](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)
+
+## Features
+๐ **Chinese** Supports Mandarin, tested with multiple Chinese datasets: aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell, etc.
+
+๐คฉ **Easy & Awesome** Good results with only a downloaded or newly trained synthesizer; the pretrained encoder/vocoder are reused, and a real-time HiFi-GAN is available as the vocoder
+
+๐ **Webserver Ready** Serve your training results for remote calls.
+
+๐คฉ **Thanks to everyone's support, this project is starting a new round of updates**
+
+## 1. Quick Start
+### 1.1 Recommended environment
+- Ubuntu 18.04
+- Cuda 11.7 && CuDNN 8.5.0
+- Python 3.8 or 3.9
+- Pytorch 2.0.1 <post cuda-11.7>
+### 1.2 Environment setup
+```shell
+# Before downloading, it is recommended to switch to a domestic (China) mirror source
+
+conda create -n sound python=3.9
+
+conda activate sound
+
+git clone https://github.com/babysor/MockingBird.git
+
+cd MockingBird
+
+pip install -r requirements.txt
+
+pip install webrtcvad-wheels
+
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
+```
+
+### 1.3 Model preparation
+> When you really have no hardware or don't want to tune slowly, you can use models contributed by the community (further sharing is welcome):
+
+| Author | Download link | Preview | Info |
+| --- | ----------- | ----- | ----- |
+| Author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu link](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps, trained on a mix of 3 open-source datasets
+| Author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu link](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code: om7f | | 25k steps, trained on a mix of 3 open-source datasets; switch to tag v0.0.1 to use
+|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu link](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) code: 1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, Taiwanese accent; switch to tag v0.0.1 to use
+|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps; note: fix per [issue](https://github.com/babysor/MockingBird/issues/37) and switch to tag v0.0.1 to use
+
+### 1.4 File layout preparation
+Lay the files out as shown below; the algorithm automatically picks up the .pt model files under synthesizer.
+```
+# Taking the first model, pretrained-11-7-21_75k.pt, as an example
+
+โโโ data
+    โโโ ckpt
+        โโโ synthesizer
+            โโโ pretrained-11-7-21_75k.pt
+```
+### 1.5 Run
+```
+python web.py
+```
+
+## 2. Model training
+### 2.1 Data preparation
+#### 2.1.1 Data download
+``` shell
+# aidatatang_200zh
+
+wget https://openslr.elda.org/resources/62/aidatatang_200zh.tgz
+```
+``` shell
+# MAGICDATA
+
+wget https://openslr.magicdatatech.com/resources/68/train_set.tar.gz
+
+wget https://openslr.magicdatatech.com/resources/68/dev_set.tar.gz
+
+wget https://openslr.magicdatatech.com/resources/68/test_set.tar.gz
+```
+``` shell
+# AISHELL-3
+
+wget https://openslr.elda.org/resources/93/data_aishell3.tgz
+```
+```shell
+# Aishell
+
+wget https://openslr.elda.org/resources/33/data_aishell.tgz
+```
+#### 2.1.2 Batch extraction
+```shell
+# This command extracts all the archives in the current directory
+
+for gz in *.gz; do tar -zxvf $gz; done
+```
+### 2.2 Encoder model training
+#### 2.2.1 Data preprocessing:
+First add the following at the top of `pre.py`:
+```python
+import torch
+torch.multiprocessing.set_start_method('spawn', force=True)
+```
+Preprocess the data with the following command:
+```shell
+python pre.py <datasets_root> \
+    -d <datasets_name>
+```
+where `<datasets_root>` is the path to the raw dataset and `<datasets_name>` is the dataset name.
+
+`librispeech_other`, `voxceleb1`, and `aidatatang_200zh` are supported; separate multiple datasets with commas.
+
+### 2.2.2 Encoder model training:
+Hyperparameter file path: `models/encoder/hparams.py`
+```shell
+python encoder_train.py <name> \
+    <datasets_root>/SV2TTS/encoder
+```
+where `<name>` is the name given to the files produced by training; change it as you like.
+
+`<datasets_root>` is the dataset path processed in `Step 2.1.1`.
+#### 2.2.3 Enable training data visualization for the encoder (optional)
+```shell
+visdom
+```
+
+### 2.3 Synthesizer model training
+#### 2.3.1 Data preprocessing:
+```shell
+python pre.py <datasets_root> \
+    -d <datasets_name> \
+    -o <datasets_path> \
+    -n <number>
+```
+`<datasets_root>` is the raw dataset path; when your `aidatatang_200zh` path is `/data/aidatatang_200zh/corpus/train`, `<datasets_root>` is `/data/`.
+
+`<datasets_name>` is the dataset name.
+
+`<datasets_path>` is the path where the processed dataset is saved.
+
+`<number>` is the number of processes used during preprocessing; adjust it to your CPU.
+
+#### 2.3.2 Preprocessing additional data:
+```shell
+python pre.py <datasets_root> \
+    -d <datasets_name> \
+    -o <datasets_path> \
+    -n <number> \
+    -s
+```
+When adding a new dataset, pass `-s` so the new data is appended instead of overwriting the existing data.
+#### 2.3.3 Synthesizer model training:
+Hyperparameter file path: `models/synthesizer/hparams.py`; you also need to move `MockingBird/control/cli/synthesizer_train.py` to `MockingBird/synthesizer_train.py`.
+```shell
+python synthesizer_train.py <name> <datasets_path> \
+    -m <out_dir>
+```
+where `<name>` is the name given to the files produced by training; change it as you like.
+
+`<datasets_path>` is the dataset path processed in `Step 2.2.1`.
+
+`<out_dir>` is the path where all data produced during training is saved.
+
+### 2.4 Vocoder model training
+The vocoder model has little effect on the generated results; 3 vocoders are already bundled.
+#### 2.4.1 Data preprocessing
+```shell
+python vocoder_preprocess.py <datasets_root> \
+    -m <synthesizer_model_path>
+```
+
+where `<datasets_root>` is your dataset path.
+
+`<synthesizer_model_path>` is the synthesizer model path.
+
+#### 2.4.2 wavernn vocoder training:
+```
+python vocoder_train.py <name> <datasets_root>
+```
+#### 2.4.3 hifigan vocoder training:
+```
+python vocoder_train.py <name> <datasets_root> hifigan
+```
+#### 2.4.4 fregan vocoder training:
+```
+python vocoder_train.py <name> <datasets_root> \
+    --config config.json fregan
+```
+To switch GAN vocoder training to multi-GPU mode, change the `num_gpus` parameter in the `.json` file under the `GAN` folder.
+
+
+
+
+## 3. Acknowledgements
+### 3.1 Project acknowledgements
+This repository started as a fork of [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning), which only supported English; thanks to the author.
+### 3.2 Paper acknowledgements
+| URL | Designation | Title | Implementation source |
+| --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
+| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
+| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
+|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
+|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
+|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
+|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |
+
+### 3.3 Developer acknowledgements
+
+As practitioners in the AI field, we are happy not only to develop algorithm projects with general practical value, but also to share those projects and the joy gained while developing them.
+
+Therefore, your use is the greatest recognition of our project, and when you run into problems while using it, feel free to leave a message in the issues at any time; your corrections are very meaningful for the continued improvement of the project.
+
+To express our gratitude, we will list each developer and their contributions in this project.
+
+- ------------------------------------------------ Developer contributions ---------------------------------------------------------------------------------
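As a concrete sketch of sections 2.3.1 and 2.3.2 above: preprocess one dataset, then append a second one with `-s` instead of overwriting. The `/data` and `/data/SV2TTS` paths and the worker count are illustrative values only.

```shell
# First dataset: creates the processed output under <datasets_path>
python pre.py /data/ -d aidatatang_200zh -o /data/SV2TTS -n 8

# Additional dataset: pass -s so the new data is appended rather than overwriting
python pre.py /data/ -d magicdata -o /data/SV2TTS -n 8 -s
```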
archived_untest_files/demo_cli.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.encoder.params_model import model_embedding_size as speaker_embedding_size
|
2 |
+
from utils.argutils import print_args
|
3 |
+
from utils.modelutils import check_model_paths
|
4 |
+
from models.synthesizer.inference import Synthesizer
|
5 |
+
from models.encoder import inference as encoder
|
6 |
+
from models.vocoder import inference as vocoder
|
7 |
+
from pathlib import Path
|
8 |
+
import numpy as np
|
9 |
+
import soundfile as sf
|
10 |
+
import librosa
|
11 |
+
import argparse
|
12 |
+
import torch
|
13 |
+
import sys
|
14 |
+
import os
|
15 |
+
from audioread.exceptions import NoBackendError
|
16 |
+
|
17 |
+
if __name__ == '__main__':
|
18 |
+
## Info & args
|
19 |
+
parser = argparse.ArgumentParser(
|
20 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
21 |
+
)
|
22 |
+
parser.add_argument("-e", "--enc_model_fpath", type=Path,
|
23 |
+
default="encoder/saved_models/pretrained.pt",
|
24 |
+
help="Path to a saved encoder")
|
25 |
+
parser.add_argument("-s", "--syn_model_fpath", type=Path,
|
26 |
+
default="synthesizer/saved_models/pretrained/pretrained.pt",
|
27 |
+
help="Path to a saved synthesizer")
|
28 |
+
parser.add_argument("-v", "--voc_model_fpath", type=Path,
|
29 |
+
default="vocoder/saved_models/pretrained/pretrained.pt",
|
30 |
+
help="Path to a saved vocoder")
|
31 |
+
parser.add_argument("--cpu", action="store_true", help=\
|
32 |
+
"If True, processing is done on CPU, even when a GPU is available.")
|
33 |
+
parser.add_argument("--no_sound", action="store_true", help=\
|
34 |
+
"If True, audio won't be played.")
|
35 |
+
parser.add_argument("--seed", type=int, default=None, help=\
|
36 |
+
"Optional random number seed value to make toolbox deterministic.")
|
37 |
+
parser.add_argument("--no_mp3_support", action="store_true", help=\
|
38 |
+
"If True, disallows loading mp3 files to prevent audioread errors when ffmpeg is not installed.")
|
39 |
+
args = parser.parse_args()
|
40 |
+
print_args(args, parser)
|
41 |
+
if not args.no_sound:
|
42 |
+
import sounddevice as sd
|
43 |
+
|
44 |
+
if args.cpu:
|
45 |
+
# Hide GPUs from Pytorch to force CPU processing
|
46 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
47 |
+
|
48 |
+
if not args.no_mp3_support:
|
49 |
+
try:
|
50 |
+
librosa.load("samples/1320_00000.mp3")
|
51 |
+
except NoBackendError:
|
52 |
+
print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
|
53 |
+
"Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
|
54 |
+
exit(-1)
|
55 |
+
|
56 |
+
print("Running a test of your configuration...\n")
|
57 |
+
|
58 |
+
if torch.cuda.is_available():
|
59 |
+
device_id = torch.cuda.current_device()
|
60 |
+
gpu_properties = torch.cuda.get_device_properties(device_id)
|
61 |
+
## Print some environment information (for debugging purposes)
|
62 |
+
print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
|
63 |
+
"%.1fGb total memory.\n" %
|
64 |
+
(torch.cuda.device_count(),
|
65 |
+
device_id,
|
66 |
+
gpu_properties.name,
|
67 |
+
gpu_properties.major,
|
68 |
+
gpu_properties.minor,
|
69 |
+
gpu_properties.total_memory / 1e9))
|
70 |
+
else:
|
71 |
+
print("Using CPU for inference.\n")
|
72 |
+
|
73 |
+
## Remind the user to download pretrained models if needed
|
74 |
+
check_model_paths(encoder_path=args.enc_model_fpath,
|
75 |
+
synthesizer_path=args.syn_model_fpath,
|
76 |
+
vocoder_path=args.voc_model_fpath)
|
77 |
+
|
78 |
+
## Load the models one by one.
|
79 |
+
print("Preparing the encoder, the synthesizer and the vocoder...")
|
80 |
+
encoder.load_model(args.enc_model_fpath)
|
81 |
+
synthesizer = Synthesizer(args.syn_model_fpath)
|
82 |
+
vocoder.load_model(args.voc_model_fpath)
|
83 |
+
|
84 |
+
|
85 |
+
## Run a test
|
86 |
+
print("Testing your configuration with small inputs.")
|
87 |
+
# Forward an audio waveform of zeroes that lasts 1 second. Notice how we can get the encoder's
|
88 |
+
# sampling rate, which may differ.
|
89 |
+
# If you're unfamiliar with digital audio, know that it is encoded as an array of floats
|
90 |
+
# (or sometimes integers, but mostly floats in this projects) ranging from -1 to 1.
|
91 |
+
# The sampling rate is the number of values (samples) recorded per second, it is set to
|
92 |
+
# 16000 for the encoder. Creating an array of length <sampling_rate> will always correspond
|
93 |
+
# to an audio of 1 second.
|
94 |
+
print("\tTesting the encoder...")
|
95 |
+
encoder.embed_utterance(np.zeros(encoder.sampling_rate))
|
96 |
+
|
97 |
+
# Create a dummy embedding. You would normally use the embedding that encoder.embed_utterance
|
98 |
+
# returns, but here we're going to make one ourselves just for the sake of showing that it's
|
99 |
+
# possible.
|
100 |
+
embed = np.random.rand(speaker_embedding_size)
|
101 |
+
    # Embeddings are L2-normalized (this isn't important here, but if you want to make your own
    # embeddings it will be).
    embed /= np.linalg.norm(embed)
    # The synthesizer can handle multiple inputs with batching. Let's create another embedding to
    # illustrate that
    embeds = [embed, np.zeros(speaker_embedding_size)]
    texts = ["test 1", "test 2"]
    print("\tTesting the synthesizer... (loading the model will output a lot of text)")
    mels = synthesizer.synthesize_spectrograms(texts, embeds)

    # The vocoder synthesizes one waveform at a time, but it's more efficient for long ones. We
    # can concatenate the mel spectrograms into a single one.
    mel = np.concatenate(mels, axis=1)
    # The vocoder can take a callback function to display the generation. More on that later. For
    # now we'll simply hide it like this:
    no_action = lambda *args: None
    print("\tTesting the vocoder...")
    # For the sake of making this test short, we'll pass a short target length. The target length
    # is the length of the wav segments that are processed in parallel. E.g. for audio sampled
    # at 16000 Hertz, a target length of 8000 means that the target audio will be cut in chunks of
    # 0.5 seconds which will all be generated together. The parameters here are absurdly short, and
    # that has a detrimental effect on the quality of the audio. The default parameters are
    # recommended in general.
    vocoder.infer_waveform(mel, target=200, overlap=50, progress_callback=no_action)

    print("All tests passed! You can now synthesize speech.\n\n")


    ## Interactive speech generation
    print("This is a GUI-less example of interface to SV2TTS. The purpose of this script is to "
          "show how you can interface this project easily with your own. See the source code for "
          "an explanation of what is happening.\n")

    print("Interactive generation loop")
    num_generated = 0
    while True:
        try:
            # Get the reference audio filepath
            message = "Reference voice: enter an audio filepath of a voice to be cloned (mp3, " \
                      "wav, m4a, flac, ...):\n"
            in_fpath = Path(input(message).replace("\"", "").replace("\'", ""))

            if in_fpath.suffix.lower() == ".mp3" and args.no_mp3_support:
                print("Can't use mp3 files, please try again:")
                continue
            ## Computing the embedding
            # First, we load the wav using the function that the speaker encoder provides. This is
            # important: there is preprocessing that must be applied.

            # The following two methods are equivalent:
            # - Directly load from the filepath:
            preprocessed_wav = encoder.preprocess_wav(in_fpath)
            # - If the wav is already loaded:
            original_wav, sampling_rate = librosa.load(str(in_fpath))
            preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
            print("Loaded file successfully")

            # Then we derive the embedding. There are many functions and parameters that the
            # speaker encoder interfaces. These are mostly for in-depth research. You will typically
            # only use this function (with its default parameters):
            embed = encoder.embed_utterance(preprocessed_wav)
            print("Created the embedding")


            ## Generating the spectrogram
            text = input("Write a sentence (+-20 words) to be synthesized:\n")

            # If seed is specified, reset torch seed and force synthesizer reload
            if args.seed is not None:
                torch.manual_seed(args.seed)
                synthesizer = Synthesizer(args.syn_model_fpath)

            # The synthesizer works in batch, so you need to put your data in a list or numpy array
            texts = [text]
            embeds = [embed]
            # If you know what the attention layer alignments are, you can retrieve them here by
            # passing return_alignments=True
            specs = synthesizer.synthesize_spectrograms(texts, embeds)
            spec = specs[0]
            print("Created the mel spectrogram")


            ## Generating the waveform
            print("Synthesizing the waveform:")

            # If seed is specified, reset torch seed and reload vocoder
            if args.seed is not None:
                torch.manual_seed(args.seed)
                vocoder.load_model(args.voc_model_fpath)

            # Synthesizing the waveform is fairly straightforward. Remember that the longer the
            # spectrogram, the more time-efficient the vocoder.
            generated_wav = vocoder.infer_waveform(spec)


            ## Post-generation
            # There's a bug with sounddevice that makes the audio cut one second earlier, so we
            # pad it.
            generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")

            # Trim excess silences to compensate for gaps in spectrograms (issue #53)
            generated_wav = encoder.preprocess_wav(generated_wav)

            # Play the audio (non-blocking)
            if not args.no_sound:
                try:
                    sd.stop()
                    sd.play(generated_wav, synthesizer.sample_rate)
                except sd.PortAudioError as e:
                    print("\nCaught exception: %s" % repr(e))
                    print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
                except:
                    raise

            # Save it on the disk
            filename = "demo_output_%02d.wav" % num_generated
            print(generated_wav.dtype)
            sf.write(filename, generated_wav.astype(np.float32), synthesizer.sample_rate)
            num_generated += 1
            print("\nSaved output as %s\n\n" % filename)


        except Exception as e:
            print("Caught exception: %s" % repr(e))
            print("Restarting\n")
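The interactive loop above boils down to a three-stage pipeline: embed a reference voice, synthesize a mel spectrogram conditioned on that embedding, then vocode it into a waveform. Below is a minimal sketch of that flow using the encoder, Synthesizer and HifiGan vocoder modules that the GUI code in this commit imports; the checkpoint and reference paths are placeholders, not guaranteed to exist in your checkout.

import numpy as np
import soundfile as sf
from pathlib import Path

from models.encoder import inference as encoder           # speaker encoder
from models.synthesizer.inference import Synthesizer      # text-to-mel
from models.vocoder.hifigan import inference as gan_vocoder  # mel-to-wav

# Placeholder checkpoint/reference paths -- substitute your own files.
encoder.load_model(Path("data/ckpt/encoder/pretrained.pt"))
synthesizer = Synthesizer(Path("data/ckpt/synthesizer/pretrained.pt"))
gan_vocoder.load_model(Path("data/ckpt/vocoder/pretrained/pretrained.pt"))

# 1) Embed the reference voice (preprocessing is mandatory).
wav = encoder.preprocess_wav(Path("reference.wav"))
embed = encoder.embed_utterance(wav)

# 2) Synthesize a mel spectrogram for the text (inputs are batched as lists).
spec = synthesizer.synthesize_spectrograms(["test sentence"], [embed])[0]

# 3) Vocode the spectrogram and save the result.
generated, sample_rate = gan_vocoder.infer_waveform(spec)
sf.write("cloned.wav", np.asarray(generated, dtype=np.float32), sample_rate)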
control/__init__.py
ADDED
File without changes
control/cli/__init__.py
ADDED
File without changes
control/cli/encoder_preprocess.py
ADDED
@@ -0,0 +1,64 @@
import argparse
from pathlib import Path

from models.encoder.preprocess import (preprocess_aidatatang_200zh,
                                       preprocess_librispeech, preprocess_voxceleb1,
                                       preprocess_voxceleb2)
from utils.argutils import print_args

if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        pass

    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
                    "writes them to the disk. This will allow you to train the encoder. The "
                    "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh.",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms. If left out, "
        "defaults to <datasets_root>/SV2TTS/encoder/")
    parser.add_argument("-d", "--datasets", type=str,
                        default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
        "Comma-separated list of the names of the datasets you want to preprocess. Only the train "
        "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
        "voxceleb2, aidatatang_200zh.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to skip existing output files with the same name. Useful if this script was "
        "interrupted.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    args = parser.parse_args()

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.")
    del args.no_trim

    # Process the arguments
    args.datasets = args.datasets.split(",")
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Preprocess the datasets
    print_args(args, parser)
    preprocess_func = {
        "librispeech_other": preprocess_librispeech,
        "voxceleb1": preprocess_voxceleb1,
        "voxceleb2": preprocess_voxceleb2,
        "aidatatang_200zh": preprocess_aidatatang_200zh,
    }
    args = vars(args)
    for dataset in args.pop("datasets"):
        print("Preprocessing %s" % dataset)
        preprocess_func[dataset](**args)
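Several of the CLI scripts in this commit share the same trick for optional output directories: declare the argument with default=argparse.SUPPRESS so the attribute is simply absent unless the user passes it, then derive a default from another argument. A minimal standalone sketch of that pattern (the script and paths here are illustrative, not part of the repo):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="Illustration of the SUPPRESS/hasattr default pattern.")
parser.add_argument("datasets_root", type=Path)
# With default=argparse.SUPPRESS the attribute does not exist when the flag is omitted.
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS)
args = parser.parse_args()

if not hasattr(args, "out_dir"):
    # Derive the default from another argument -- impossible with a plain static default.
    args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
args.out_dir.mkdir(exist_ok=True, parents=True)
print(args.out_dir)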
control/cli/encoder_train.py
ADDED
@@ -0,0 +1,47 @@
from utils.argutils import print_args
from models.encoder.train import train
from pathlib import Path
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help= \
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("clean_data_root", type=Path, help= \
        "Path to the output directory of encoder_preprocess.py. If you left the default "
        "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
    parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
        "Path to the output directory that will contain the saved model weights, as well as "
        "backups of those weights and plots generated during training.")
    parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
        "Number of steps between updates of the loss and the plots.")
    parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
        "Number of steps between updates of the umap projection. Set to 0 to never update the "
        "projections.")
    parser.add_argument("-s", "--save_every", type=int, default=500, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model.")
    parser.add_argument("--visdom_server", type=str, default="http://localhost")
    parser.add_argument("--no_visdom", action="store_true", help= \
        "Disable visdom.")
    args = parser.parse_args()

    # Process the arguments
    args.models_dir.mkdir(exist_ok=True)

    # Run the training
    print_args(args, parser)
    train(**vars(args))
control/cli/ppg2mel_train.py
ADDED
@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.hparams import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
        'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    ###

    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot-accessible
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> OneShot VC training finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
control/cli/pre4ppg.py
ADDED
@@ -0,0 +1,49 @@
from pathlib import Path
import argparse

from models.ppg2mel.preprocess import preprocess_dataset

recognized_datasets = [
    "aidatatang_200zh",
    "aidatatang_200zh_s",  # sample
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, to be used by the "
                    "ppg2mel model for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your datasets.")
    parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process, allowed values: aidatatang_200zh, aidatatang_200zh_s.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
    parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
        "Number of processes in parallel.")
    # parser.add_argument("-s", "--skip_existing", action="store_true", help=\
    #     "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
    #     "interrupted. ")
    # parser.add_argument("--hparams", type=str, default="", help=\
    #     "Hyperparameter overrides as a comma-separated list of name-value pairs")
    # parser.add_argument("--no_trim", action="store_true", help=\
    #     "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
        "Path to your trained ppg encoder model.")
    parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
        "Path to your trained speaker encoder model.")
    args = parser.parse_args()

    assert args.dataset in recognized_datasets, 'Dataset is not supported, file an issue to propose a new one'

    # Create directories
    assert args.datasets_root.exists()
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    args.out_dir.mkdir(exist_ok=True, parents=True)

    preprocess_dataset(**vars(args))
control/cli/synthesizer_train.py
ADDED
@@ -0,0 +1,40 @@
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train
from utils.argutils import print_args
import argparse

def new_train():
    parser = argparse.ArgumentParser()
    parser.add_argument("run_id", type=str, help= \
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
        "the wavs and the embeds.")
    parser.add_argument("-m", "--models_dir", type=str, default="data/ckpt/synthesizer/", help=\
        "Path to the output directory that will contain the saved model weights and the logs.")
    parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-l", "--log_every", type=int, default=200, help= \
        "Number of steps between summaries of the training info in tensorboard")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model and restart from scratch.")
    parser.add_argument("--hparams", default="",
                        help="Hyperparameter overrides as a comma-separated list of name=value "
                             "pairs")
    args, _ = parser.parse_known_args()
    print_args(args, parser)

    args.hparams = hparams.parse(args.hparams)

    # Run the training
    train(**vars(args))


if __name__ == "__main__":
    new_train()
control/cli/train_ppg2mel.py
ADDED
@@ -0,0 +1,66 @@
import sys
import torch
import argparse
import numpy as np
from utils.hparams import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
        'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    ###
    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot-accessible
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> OneShot VC training finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
control/cli/vocoder_preprocess.py
ADDED
@@ -0,0 +1,59 @@
from models.synthesizer.synthesize import run_synthesis
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
import argparse
import os


if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        pass

    parser = argparse.ArgumentParser(
        description="Creates ground-truth aligned (GTA) spectrograms from the vocoder.",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=str, help=\
        "Path to the directory containing your SV2TTS directory. If you specify both --in_dir and "
        "--out_dir, this argument won't be used.")
    parser.add_argument("-m", "--model_dir", type=str,
                        default="synthesizer/saved_models/mandarin/", help=\
        "Path to the pretrained model directory.")
    parser.add_argument("-i", "--in_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the synthesizer directory that contains the mel spectrograms, the wavs and the "
        "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("-o", "--out_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the output vocoder directory that will contain the ground truth aligned mel "
        "spectrograms. Defaults to <datasets_root>/SV2TTS/vocoder/.")
    parser.add_argument("--hparams", default="",
                        help="Hyperparameter overrides as a comma-separated list of name=value "
                             "pairs")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("--cpu", action="store_true", help=\
        "If True, processing is done on CPU, even when a GPU is available.")
    args = parser.parse_args()
    print_args(args, parser)
    modified_hp = hparams.parse(args.hparams)

    if not hasattr(args, "in_dir"):
        args.in_dir = os.path.join(args.datasets_root, "SV2TTS", "synthesizer")
    if not hasattr(args, "out_dir"):
        args.out_dir = os.path.join(args.datasets_root, "SV2TTS", "vocoder")

    if args.cpu:
        # Hide GPUs from Pytorch to force CPU processing
        os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.")
    del args.no_trim

    run_synthesis(args.in_dir, args.out_dir, args.model_dir, modified_hp)
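The --cpu branch above works by clearing CUDA_VISIBLE_DEVICES, which only takes effect if it happens before the first CUDA initialization in the process. A small self-contained illustration of that behaviour (the prints are just for demonstration):

import os

# Must be set before the first CUDA initialization; changing it after a CUDA
# context has been created has no effect on that process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import torch  # imported after the environment variable is set

# With no visible devices, is_available() reports False and tensors stay on CPU.
print(torch.cuda.is_available())   # False
print(torch.zeros(2, 2).device)    # cpu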
control/cli/vocoder_train.py
ADDED
@@ -0,0 +1,92 @@
from utils.argutils import print_args
from models.vocoder.wavernn.train import train
from models.vocoder.hifigan.train import train as train_hifigan
from models.vocoder.fregan.train import train as train_fregan
from utils.util import AttrDict
from pathlib import Path
import argparse
import json
import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the vocoder from the synthesizer audios and the GTA synthesized mels, "
                    "or ground truth mels.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help= \
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("datasets_root", type=str, help= \
        "Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
        "will take priority over this argument.")
    parser.add_argument("vocoder_type", type=str, default="wavernn", help= \
        "Vocoder type to train. Defaults to wavernn. "
        "Supported choices: wavernn, hifigan, fregan.")
    parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
        "the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("--voc_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the vocoder directory that contains the GTA synthesized mel spectrograms. "
        "Defaults to <datasets_root>/SV2TTS/vocoder/. Unused if --ground_truth is passed.")
    parser.add_argument("-m", "--models_dir", type=str, default="vocoder/saved_models/", help=\
        "Path to the directory that will contain the saved model weights, as well as backups "
        "of those weights and wavs generated during training.")
    parser.add_argument("-g", "--ground_truth", action="store_true", help= \
        "Train on ground truth spectrograms (<datasets_root>/SV2TTS/synthesizer/mels).")
    parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model and restart from scratch.")
    parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
    args = parser.parse_args()

    if not hasattr(args, "syn_dir"):
        args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
    args.syn_dir = Path(args.syn_dir)
    if not hasattr(args, "voc_dir"):
        args.voc_dir = Path(args.datasets_root, "SV2TTS", "vocoder")
    args.voc_dir = Path(args.voc_dir)
    del args.datasets_root
    args.models_dir = Path(args.models_dir)
    args.models_dir.mkdir(exist_ok=True)

    print_args(args, parser)

    # Process the arguments
    if args.vocoder_type == "wavernn":
        # Run the wavernn training
        delattr(args, 'vocoder_type')
        delattr(args, 'config')
        train(**vars(args))
    elif args.vocoder_type == "hifigan":
        with open(args.config) as f:
            json_config = json.load(f)
        h = AttrDict(json_config)
        if h.num_gpus > 1:
            h.num_gpus = torch.cuda.device_count()
            h.batch_size = int(h.batch_size / h.num_gpus)
            print('Batch size per GPU :', h.batch_size)
            mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
        else:
            train_hifigan(0, args, h)
    elif args.vocoder_type == "fregan":
        with Path('vocoder/fregan/config.json').open() as f:
            json_config = json.load(f)
        h = AttrDict(json_config)
        if h.num_gpus > 1:
            h.num_gpus = torch.cuda.device_count()
            h.batch_size = int(h.batch_size / h.num_gpus)
            print('Batch size per GPU :', h.batch_size)
            mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
        else:
            train_fregan(0, args, h)
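The hifigan and fregan branches above launch one training process per GPU with torch.multiprocessing.spawn, splitting the configured batch size across processes. A minimal sketch of that dispatch pattern with a dummy worker; the AttrDict here is a stand-in for the repo's utils.util.AttrDict, not the same class:

import torch
import torch.multiprocessing as mp


class AttrDict(dict):
    # Minimal stand-in: a dict whose keys are also attributes.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def worker(rank, cfg):
    # mp.spawn calls worker(i, *args) with i = 0 .. nprocs-1, one process per GPU.
    print(f"process {rank} of {cfg.num_gpus}, batch size per process {cfg.batch_size}")


if __name__ == "__main__":
    cfg = AttrDict(num_gpus=max(torch.cuda.device_count(), 1), batch_size=16)
    if cfg.num_gpus > 1:
        # Split the global batch across processes, then spawn one worker per GPU.
        cfg.batch_size = int(cfg.batch_size / cfg.num_gpus)
        mp.spawn(worker, nprocs=cfg.num_gpus, args=(cfg,))
    else:
        worker(0, cfg)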
control/mkgui/__init__.py
ADDED
File without changes
control/mkgui/app.py
ADDED
@@ -0,0 +1,151 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from models.encoder import inference as encoder
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from control.mkgui.base.components.types import FileContent
from models.vocoder.hifigan import inference as gan_vocoder
from models.synthesizer.inference import Synthesizer
from typing import Any, Tuple
import matplotlib.pyplot as plt

# Constants
AUDIO_SAMPLES_DIR = f"data{os.sep}samples{os.sep}"
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
VOC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}vocoder"
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
if not os.path.isdir("wavs"):
    os.makedirs("wavs")

# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist. Please move the model files to that location and retry!")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoder models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")


class Input(BaseModel):
    message: str = Field(
        ..., example="Welcome to the toolbox, Chinese input is now supported!", alias="Text content"
    )
    local_audio_file: audio_input_selection = Field(
        ..., alias="Select audio (local wav)",
        description="Select a local audio file."
    )
    record_audio_file: FileContent = Field(default=None, alias="Record audio",
        description="Record audio.", is_recorder=True, mime_type="audio/wav")
    upload_audio_file: FileContent = Field(default=None, alias="Upload audio",
        description="Drag and drop or click to upload.", mime_type="audio/wav")
    encoder: encoders = Field(
        ..., alias="Encoder model",
        description="Select the speaker encoder model file."
    )
    synthesizer: synthesizers = Field(
        ..., alias="Synthesizer model",
        description="Select the synthesizer model file."
    )
    vocoder: vocoders = Field(
        ..., alias="Vocoder model",
        description="Select the vocoder model file (only HifiGan is currently supported)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Result Audio)")
        streamlit_app.pyplot(fig)


def synthesize(input: Input) -> Output:
    """synthesize"""
    # load models
    encoder.load_model(Path(input.encoder.value))
    current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # load file
    if input.record_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.record_audio_file.as_bytes())
            f.seek(0)
        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    elif input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, wav)  # Make sure we get the correct wav

    source_spec = Synthesizer.make_spectrogram(wav)

    # preprocess
    encoder_wav = encoder.preprocess_wav(wav, sample_rate)
    embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    # Load input text
    texts = filter(None, input.message.split("\n"))
    punctuation = '！，。、,'  # punctuation used to split/clean the text
    processed_texts = []
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
    texts = processed_texts

    # synthesize and vocode
    embeds = [embed] * len(texts)
    specs = current_synt.synthesize_spectrograms(texts, embeds)
    spec = np.concatenate(specs, axis=1)
    sample_rate = Synthesizer.sample_rate
    wav, sample_rate = gan_vocoder.infer_waveform(spec)

    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()
    return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
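synthesize() splits the input message on Chinese and ASCII punctuation so that each clause is synthesized separately, and the resulting spectrograms are concatenated before vocoding. A standalone sketch of that text-splitting step; the punctuation set is reconstructed from the mojibake in the listing above, and the sample sentence is purely illustrative:

import re

def split_on_punctuation(message: str, punctuation: str = '！，。、,') -> list:
    """Split multi-line input into clauses, dropping punctuation and empty pieces."""
    clauses = []
    for line in filter(None, message.split("\n")):
        for piece in re.sub(r'[{}]+'.format(punctuation), '\n', line).split('\n'):
            if piece:
                clauses.append(piece.strip())
    return clauses

print(split_on_punctuation("你好，欢迎使用工具箱。第二句, third clause"))
# -> ['你好', '欢迎使用工具箱', '第二句', 'third clause']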
control/mkgui/app_vc.py
ADDED
@@ -0,0 +1,166 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Tuple

import librosa
import matplotlib.pyplot as plt
import torch
from pydantic import BaseModel, Field
from scipy.io.wavfile import write

import models.ppg2mel as Convertor
import models.ppg_extractor as Extractor
from control.mkgui.base.components.types import FileContent
from models.encoder import inference as speacker_encoder
from models.synthesizer.inference import Synthesizer
from models.vocoder.hifigan import inference as gan_vocoder

# Constants
AUDIO_SAMPLES_DIR = f'data{os.sep}samples{os.sep}'
EXT_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg_extractor'
CONV_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg2mel'
VOC_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}vocoder'
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'

# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-load models
if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoder models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")

class Input(BaseModel):
    local_audio_file: audio_input_selection = Field(
        ..., alias="Input audio (local wav)",
        description="Select a local audio file."
    )
    upload_audio_file: FileContent = Field(default=None, alias="Upload audio",
        description="Drag and drop or click to upload.", mime_type="audio/wav")
    local_audio_file_target: audio_input_selection = Field(
        ..., alias="Target audio (local wav)",
        description="Select a local audio file."
    )
    upload_audio_file_target: FileContent = Field(default=None, alias="Upload target audio",
        description="Drag and drop or click to upload.", mime_type="audio/wav")
    extractor: extractors = Field(
        ..., alias="Encoder model",
        description="Select the speech encoder (PPG extractor) model file."
    )
    convertor: convertors = Field(
        ..., alias="Convertor model",
        description="Select the voice conversion model file."
    )
    vocoder: vocoders = Field(
        ..., alias="Vocoder model",
        description="Select the vocoder model file (only HifiGan is currently supported)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, target, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(target.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Target Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Result Audio)")
        streamlit_app.pyplot(fig)

def convert(input: Input) -> Output:
    """convert"""
    # load models
    extractor = Extractor.load_model(Path(input.extractor.value))
    convertor = Convertor.load_model(Path(input.convertor.value))
    # current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # load file
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        src_wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, src_wav)  # Make sure we get the correct wav

    if input.upload_audio_file_target is not None:
        with open(TEMP_TARGET_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file_target.as_bytes())
            f.seek(0)
        ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
    else:
        ref_wav, _ = librosa.load(input.local_audio_file_target.value)
        write(TEMP_TARGET_AUDIO, sample_rate, ref_wav)  # Make sure we get the correct wav

    ppg = extractor.extract_from_wav(src_wav)
    # Import necessary dependencies of Voice Conversion
    from utils.f0_utils import (compute_f0, compute_mean_std, f02lf0,
                                get_converted_lf0uv)
    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
    speacker_encoder.load_model(Path(f"data{os.sep}ckpt{os.sep}encoder{os.sep}pretrained_bak_5805000.pt"))
    embed = speacker_encoder.embed_utterance(ref_wav)
    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
    min_len = min(ppg.shape[1], len(lf0_uv))
    ppg = ppg[:, :min_len]
    lf0_uv = lf0_uv[:min_len]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, mel_pred, att_ws = convertor.inference(
        ppg,
        logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
        spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
    )
    mel_pred = mel_pred.transpose(0, 1)
    breaks = [mel_pred.shape[1]]
    mel_pred = mel_pred.detach().cpu().numpy()

    # synthesize and vocode
    wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)

    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_TARGET_AUDIO, "rb") as f:
        target_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()

    return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
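Before calling the convertor, convert() has to reconcile two time axes: the PPG matrix and the converted lf0/uv contour come from different feature extractors and rarely have the same frame count, so both are truncated to the shorter one. A toy illustration of that alignment step with random arrays standing in for the real features (shapes are made up):

import numpy as np

# Stand-ins for the real features: a PPG batch of shape (1, frames, dims) and an
# lf0/uv contour with its own frame count.
ppg = np.random.rand(1, 507, 144)
lf0_uv = np.random.rand(512, 2)

# Truncate both to the shorter sequence so the convertor sees aligned frames.
min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
print(ppg.shape, lf0_uv.shape)  # (1, 507, 144) (507, 2)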
control/mkgui/base/__init__.py
ADDED
@@ -0,0 +1,2 @@

from .core import Opyrator
control/mkgui/base/api/__init__.py
ADDED
@@ -0,0 +1 @@
from .fastapi_app import create_api
control/mkgui/base/api/fastapi_utils.py
ADDED
@@ -0,0 +1,102 @@
"""Collection of utilities for FastAPI apps."""

import inspect
from typing import Any, Type

from fastapi import FastAPI, Form
from pydantic import BaseModel


def as_form(cls: Type[BaseModel]) -> Any:
    """Adds an as_form class method to decorated models.

    The as_form class method can be used with FastAPI endpoints.
    """
    new_params = [
        inspect.Parameter(
            field.alias,
            inspect.Parameter.POSITIONAL_ONLY,
            default=(Form(field.default) if not field.required else Form(...)),
        )
        for field in cls.__fields__.values()
    ]

    async def _as_form(**data):  # type: ignore
        return cls(**data)

    sig = inspect.signature(_as_form)
    sig = sig.replace(parameters=new_params)
    _as_form.__signature__ = sig  # type: ignore
    setattr(cls, "as_form", _as_form)
    return cls


def patch_fastapi(app: FastAPI) -> None:
    """Patch function to allow relative url resolution.

    This patch is required to make fastapi fully functional with a relative url path.
    This code snippet can be copy-pasted to any Fastapi application.
    """
    from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
    from starlette.requests import Request
    from starlette.responses import HTMLResponse

    async def redoc_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        redoc_ui = get_redoc_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Redoc UI",
        )

        return HTMLResponse(redoc_ui.body.decode("utf-8"))

    async def swagger_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        swagger_ui = get_swagger_ui_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Swagger UI",
            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        )

        # insert request interceptor to have all requests run on a relative path
        request_interceptor = (
            "requestInterceptor: (e) => {"
            "\n\t\t\tvar url = window.location.origin + window.location.pathname"
            '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
            "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);"  # noqa: W605
            "\n\t\t\te.contextUrl = url"
            "\n\t\t\te.url = url"
            "\n\t\t\treturn e;}"
        )

        return HTMLResponse(
            swagger_ui.body.decode("utf-8").replace(
                "dom_id: '#swagger-ui',",
                "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
            )
        )

    # remove old docs route and add our patched route
    routes_new = []
    for app_route in app.routes:
        if app_route.path == "/docs":  # type: ignore
            continue

        if app_route.path == "/redoc":  # type: ignore
            continue

        routes_new.append(app_route)

    app.router.routes = routes_new

    assert app.docs_url is not None
    app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
    assert app.redoc_url is not None
    app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)

    # Make graphql relative
    from starlette import graphql

    graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
        "({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
    )
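The as_form decorator above rewrites a Pydantic model's signature so FastAPI can populate it from HTML form fields instead of a JSON body. A minimal usage sketch, assuming the module is importable at the path shown; the model and endpoint are illustrative, not part of the repo:

from fastapi import Depends, FastAPI
from pydantic import BaseModel

from control.mkgui.base.api.fastapi_utils import as_form


@as_form
class LoginForm(BaseModel):
    # Each field becomes a Form(...) parameter of the generated as_form() signature.
    username: str
    password: str


app = FastAPI()


@app.post("/login")
async def login(form: LoginForm = Depends(LoginForm.as_form)):  # type: ignore
    # FastAPI fills the model from application/x-www-form-urlencoded data.
    return {"user": form.username}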
control/mkgui/base/components/__init__.py
ADDED
File without changes
control/mkgui/base/components/outputs.py
ADDED
@@ -0,0 +1,43 @@
from typing import List

from pydantic import BaseModel


class ScoredLabel(BaseModel):
    label: str
    score: float


class ClassificationOutput(BaseModel):
    __root__: List[ScoredLabel]

    def __iter__(self):  # type: ignore
        return iter(self.__root__)

    def __getitem__(self, item):  # type: ignore
        return self.__root__[item]

    def render_output_ui(self, streamlit) -> None:  # type: ignore
        import plotly.express as px

        sorted_predictions = sorted(
            [prediction.dict() for prediction in self.__root__],
            key=lambda k: k["score"],
        )

        num_labels = len(sorted_predictions)
        if len(sorted_predictions) > 10:
            num_labels = streamlit.slider(
                "Maximum labels to show: ",
                min_value=1,
                max_value=len(sorted_predictions),
                value=len(sorted_predictions),
            )
        fig = px.bar(
            sorted_predictions[len(sorted_predictions) - num_labels :],
            x="score",
            y="label",
            orientation="h",
        )
        streamlit.plotly_chart(fig, use_container_width=True)
        # fig.show()
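ClassificationOutput wraps a bare list through Pydantic's custom root type (__root__), which is why it defines __iter__ and __getitem__ itself. A short construction example, assuming the module is importable at the path shown:

from control.mkgui.base.components.outputs import ClassificationOutput, ScoredLabel

out = ClassificationOutput(__root__=[
    ScoredLabel(label="speech", score=0.93),
    ScoredLabel(label="noise", score=0.07),
])

# The custom root type behaves like the wrapped list.
for prediction in out:
    print(prediction.label, prediction.score)
print(out[0].label)  # "speech"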
control/mkgui/base/components/types.py
ADDED
@@ -0,0 +1,46 @@
import base64
from typing import Any, Dict, overload


class FileContent(str):
    def as_bytes(self) -> bytes:
        return base64.b64decode(self, validate=True)

    def as_str(self) -> str:
        return self.as_bytes().decode()

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(format="byte")

    @classmethod
    def __get_validators__(cls) -> Any:  # type: ignore
        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> "FileContent":
        if isinstance(value, FileContent):
            return value
        elif isinstance(value, str):
            return FileContent(value)
        elif isinstance(value, (bytes, bytearray, memoryview)):
            return FileContent(base64.b64encode(value).decode())
        else:
            raise Exception("Wrong type")

# # Currently unusable, because the browser offers no way to select a folder
# class DirectoryContent(FileContent):
#     @classmethod
#     def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
#         field_schema.update(format="path")

#     @classmethod
#     def validate(cls, value: Any) -> "DirectoryContent":
#         if isinstance(value, DirectoryContent):
#             return value
#         elif isinstance(value, str):
#             return DirectoryContent(value)
#         elif isinstance(value, (bytes, bytearray, memoryview)):
#             return DirectoryContent(base64.b64encode(value).decode())
#         else:
#             raise Exception("Wrong type")
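FileContent stores uploaded data as a base64-encoded str subclass: raw bytes are encoded when the value is validated and decoded again through as_bytes(). A quick round-trip example, assuming the module is importable at the path shown; the byte string is a placeholder, not a real wav header:

from control.mkgui.base.components.types import FileContent

raw = b"RIFF....WAVEfmt "          # pretend this is the start of a wav upload
fc = FileContent.validate(raw)     # bytes are base64-encoded into the string subclass

print(isinstance(fc, str))         # True -- it serializes like a plain string
print(fc.as_bytes() == raw)        # True -- decoding recovers the original bytes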
control/mkgui/base/core.py
ADDED
@@ -0,0 +1,203 @@
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints

from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as


def name_to_title(name: str) -> str:
    """Converts a camelCase or snake_case name to title case."""
    # If camelCase -> convert to snake case
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
    # Convert to title case
    return name.replace("_", " ").strip().title()


def is_compatible_type(type: Type) -> bool:
    """Returns `True` if the type is opyrator-compatible."""
    try:
        if issubclass(type, BaseModel):
            return True
    except Exception:
        pass

    try:
        # valid list type
        if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
            return True
    except Exception:
        pass

    return False


def get_input_type(func: Callable) -> Type:
    """Returns the input type of a given function (callable).

    Args:
        func: The function for which to get the input type.

    Raises:
        ValueError: If the function does not have a valid input type annotation.
    """
    type_hints = get_type_hints(func)

    if "input" not in type_hints:
        raise ValueError(
            "The callable MUST have a parameter with the name `input` with typing annotation. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    input_type = type_hints["input"]

    if not is_compatible_type(input_type):
        raise ValueError(
            "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    # TODO: return warning if more than one input parameters

    return input_type


def get_output_type(func: Callable) -> Type:
    """Returns the output type of a given function (callable).

    Args:
        func: The function for which to get the output type.

    Raises:
        ValueError: If the function does not have a valid output type annotation.
    """
    type_hints = get_type_hints(func)
    if "return" not in type_hints:
        raise ValueError(
            "The return type of the callable MUST be annotated with type hints."
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    output_type = type_hints["return"]

    if not is_compatible_type(output_type):
        raise ValueError(
            "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    return output_type


def get_callable(import_string: str) -> Callable:
    """Import a callable from a string."""
    callable_seperator = ":"
    if callable_seperator not in import_string:
        # Use dot as seperator
        callable_seperator = "."

    if callable_seperator not in import_string:
        raise ValueError("The callable path MUST specify the function. ")

    mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
    mod = importlib.import_module(mod_name)
    return getattr(mod, callable_name)


class Opyrator:
    def __init__(self, func: Union[Callable, str]) -> None:
        if isinstance(func, str):
            # Try to load the function from a string notation
            self.function = get_callable(func)
        else:
            self.function = func

        self._action = "Execute"
        self._input_type = None
        self._output_type = None

        if not callable(self.function):
            raise ValueError("The provided function parameter is not a callable.")

        if inspect.isclass(self.function):
            raise ValueError(
                "The provided callable is an uninitialized Class. This is not allowed."
            )

        if inspect.isfunction(self.function):
            # The provided callable is a function
            self._input_type = get_input_type(self.function)
            self._output_type = get_output_type(self.function)

            try:
                # Get name
                self._name = name_to_title(self.function.__name__)
            except Exception:
                pass

            try:
                # Get description from function
                doc_string = inspect.getdoc(self.function)
                if doc_string:
                    self._action = doc_string
            except Exception:
                pass
        elif hasattr(self.function, "__call__"):
            # The provided callable is a callable class instance
            self._input_type = get_input_type(self.function.__call__)  # type: ignore
            self._output_type = get_output_type(self.function.__call__)  # type: ignore

            try:
                # Get name
                self._name = name_to_title(type(self.function).__name__)
            except Exception:
                pass

            try:
                # Get action from docstring
                doc_string = inspect.getdoc(self.function.__call__)  # type: ignore
                if doc_string:
                    self._action = doc_string

                if (
                    not self._action
                    or self._action == "Call"
                ):
                    # Get docstring from class instead of __call__ function
                    doc_string = inspect.getdoc(self.function)
                    if doc_string:
                        self._action = doc_string
            except Exception:
                pass
        else:
            raise ValueError("Unknown callable type.")

    @property
    def name(self) -> str:
        return self._name

    @property
    def action(self) -> str:
        return self._action

    @property
    def input_type(self) -> Any:
        return self._input_type

    @property
    def output_type(self) -> Any:
        return self._output_type

    def __call__(self, input: Any, **kwargs: Any) -> Any:

        input_obj = input

        if isinstance(input, str):
            # Allow json input
            input_obj = parse_raw_as(self.input_type, input)

        if isinstance(input, dict):
            # Allow dict input
            input_obj = parse_obj_as(self.input_type, input)

        return self.function(input_obj, **kwargs)
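
The `Opyrator` wrapper above only needs a callable whose `input` parameter and return value are Pydantic models. A minimal sketch of a compatible callable and how the wrapper behaves (the `EchoInput`/`EchoOutput` names are illustrative and not part of this upload):

from pydantic import BaseModel

from control.mkgui.base.core import Opyrator


class EchoInput(BaseModel):   # hypothetical model, for illustration only
    text: str


class EchoOutput(BaseModel):  # hypothetical model, for illustration only
    text: str


def echo(input: EchoInput) -> EchoOutput:
    """Echo the provided text."""
    return EchoOutput(text=input.text)


op = Opyrator(echo)
print(op.name)             # "Echo" (derived via name_to_title)
print(op.action)           # the docstring becomes the action label
print(op({"text": "hi"}))  # dict input is parsed into EchoInput before the call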
control/mkgui/base/ui/__init__.py
ADDED
@@ -0,0 +1 @@
from .streamlit_ui import render_streamlit_ui
control/mkgui/base/ui/schema_utils.py
ADDED
@@ -0,0 +1,135 @@
from typing import Dict


def resolve_reference(reference: str, references: Dict) -> Dict:
    return references[reference.split("/")[-1]]


def get_single_reference_item(property: Dict, references: Dict) -> Dict:
    # Ref can either be directly in the properties or the first element of allOf
    reference = property.get("$ref")
    if reference is None:
        reference = property["allOf"][0]["$ref"]
    return resolve_reference(reference, references)


def is_single_string_property(property: Dict) -> bool:
    return property.get("type") == "string"


def is_single_datetime_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    return property.get("format") in ["date-time", "time", "date"]


def is_single_boolean_property(property: Dict) -> bool:
    return property.get("type") == "boolean"


def is_single_number_property(property: Dict) -> bool:
    return property.get("type") in ["integer", "number"]


def is_single_file_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    # TODO: binary?
    return property.get("format") == "byte"

def is_single_autio_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    # TODO: binary?
    return property.get("format") == "bytes"


def is_single_directory_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    return property.get("format") == "path"

def is_multi_enum_property(property: Dict, references: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("uniqueItems") is not True:
        # Only relevant if it is a set or other datastructures with unique items
        return False

    try:
        _ = resolve_reference(property["items"]["$ref"], references)["enum"]
        return True
    except Exception:
        return False


def is_single_enum_property(property: Dict, references: Dict) -> bool:
    try:
        _ = get_single_reference_item(property, references)["enum"]
        return True
    except Exception:
        return False


def is_single_dict_property(property: Dict) -> bool:
    if property.get("type") != "object":
        return False
    return "additionalProperties" in property


def is_single_reference(property: Dict) -> bool:
    if property.get("type") is not None:
        return False

    return bool(property.get("$ref"))


def is_multi_file_property(property: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("items") is None:
        return False

    try:
        # TODO: binary
        return property["items"]["format"] == "byte"
    except Exception:
        return False


def is_single_object(property: Dict, references: Dict) -> bool:
    try:
        object_reference = get_single_reference_item(property, references)
        if object_reference["type"] != "object":
            return False
        return "properties" in object_reference
    except Exception:
        return False


def is_property_list(property: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("items") is None:
        return False

    try:
        return property["items"]["type"] in ["string", "number", "integer"]
    except Exception:
        return False


def is_object_list_property(property: Dict, references: Dict) -> bool:
    if property.get("type") != "array":
        return False

    try:
        object_reference = resolve_reference(property["items"]["$ref"], references)
        if object_reference["type"] != "object":
            return False
        return "properties" in object_reference
    except Exception:
        return False
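
These predicates operate on the JSON schema that Pydantic (v1 style, matching the `.schema()` calls used elsewhere in this upload) emits for a model, so they can be exercised directly against `Model.schema()`. A small hedged sketch, assuming the GUI dependencies are installed so the `ui` package imports cleanly; the `Color`/`Example` models are illustrative only:

from enum import Enum
from typing import List

from pydantic import BaseModel

from control.mkgui.base.ui import schema_utils


class Color(str, Enum):       # hypothetical example model
    RED = "red"
    BLUE = "blue"


class Example(BaseModel):     # hypothetical example model
    name: str
    count: int
    color: Color
    tags: List[str]


schema = Example.schema(by_alias=True)
props = schema["properties"]
refs = schema.get("definitions", {})

print(schema_utils.is_single_string_property(props["name"]))       # True
print(schema_utils.is_single_number_property(props["count"]))      # True
print(schema_utils.is_single_enum_property(props["color"], refs))  # True
print(schema_utils.is_property_list(props["tags"]))                # True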
control/mkgui/base/ui/streamlit_ui.py
ADDED
@@ -0,0 +1,933 @@
import datetime
import inspect
import mimetypes
import sys
from os import getcwd, unlink, path
from platform import system
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Type
from PIL import Image

import pandas as pd
import streamlit as st
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pydantic import BaseModel, ValidationError, parse_obj_as

from control.mkgui.base import Opyrator
from control.mkgui.base.core import name_to_title
from . import schema_utils
from .streamlit_utils import CUSTOM_STREAMLIT_CSS

STREAMLIT_RUNNER_SNIPPET = """
from control.mkgui.base.ui import render_streamlit_ui

import streamlit as st

# TODO: Make it configurable
# Page config can only be setup once
st.set_page_config(
    page_title="MockingBird",
    page_icon="๐ง",
    layout="wide")

render_streamlit_ui()
"""

# with st.spinner("Loading MockingBird GUI. Please wait..."):
#     opyrator = Opyrator("{opyrator_path}")


def launch_ui(port: int = 8501) -> None:
    with NamedTemporaryFile(
        suffix=".py", mode="w", encoding="utf-8", delete=False
    ) as f:
        f.write(STREAMLIT_RUNNER_SNIPPET)
        f.seek(0)

    import subprocess

    python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
    if system() == "Windows":
        python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
        subprocess.run(
            f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
            shell=True,
        )

    subprocess.run(
        f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
        shell=True,
    )

    f.close()
    unlink(f.name)


def function_has_named_arg(func: Callable, parameter: str) -> bool:
    try:
        sig = inspect.signature(func)
        for param in sig.parameters.values():
            if param.name == "input":
                return True
    except Exception:
        return False
    return False


def has_output_ui_renderer(data_item: BaseModel) -> bool:
    return hasattr(data_item, "render_output_ui")


def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
    return hasattr(input_class, "render_input_ui")


def is_compatible_audio(mime_type: str) -> bool:
    return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]


def is_compatible_image(mime_type: str) -> bool:
    return mime_type in ["image/png", "image/jpeg"]


def is_compatible_video(mime_type: str) -> bool:
    return mime_type in ["video/mp4"]


class InputUI:
    def __init__(self, session_state, input_class: Type[BaseModel]):
        self._session_state = session_state
        self._input_class = input_class

        self._schema_properties = input_class.schema(by_alias=True).get(
            "properties", {}
        )
        self._schema_references = input_class.schema(by_alias=True).get(
            "definitions", {}
        )

    def render_ui(self, streamlit_app_root) -> None:
        if has_input_ui_renderer(self._input_class):
            # The input model has a rendering function
            # The rendering also returns the current state of input data
            self._session_state.input_data = self._input_class.render_input_ui(  # type: ignore
                st, self._session_state.input_data
            )
            return

        # print(self._schema_properties)
        for property_key in self._schema_properties.keys():
            property = self._schema_properties[property_key]

            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)

            try:
                if "input_data" in self._session_state:
                    self._store_value(
                        property_key,
                        self._render_property(streamlit_app_root, property_key, property),
                    )
            except Exception as e:
                print("Exception!", e)
                pass

    def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
        streamlit_kwargs = {
            "label": property.get("title"),
            "key": key,
        }

        if property.get("description"):
            streamlit_kwargs["help"] = property.get("description")
        return streamlit_kwargs

    def _store_value(self, key: str, value: Any) -> None:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # add value to this element
                data_element[key_element] = value
                return
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]

    def _get_value(self, key: str) -> Any:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # read value from this element
                if key_element not in data_element:
                    return None
                return data_element[key_element]
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]
        return None

    def _render_single_datetime_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("format") == "time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.time.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.time_input(**streamlit_kwargs)
        elif property.get("format") == "date":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.date.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.date_input(**streamlit_kwargs)
        elif property.get("format") == "date-time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.datetime.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            with streamlit_app.container():
                streamlit_app.subheader(streamlit_kwargs.get("label"))
                if streamlit_kwargs.get("description"):
                    streamlit_app.text(streamlit_kwargs.get("description"))
                selected_date = None
                selected_time = None
                date_col, time_col = streamlit_app.columns(2)
                with date_col:
                    date_kwargs = {"label": "Date", "key": key + "-date-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            date_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).date()
                        except Exception:
                            pass
                    selected_date = streamlit_app.date_input(**date_kwargs)

                with time_col:
                    time_kwargs = {"label": "Time", "key": key + "-time-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            time_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).time()
                        except Exception:
                            pass
                    selected_time = streamlit_app.time_input(**time_kwargs)
            return datetime.datetime.combine(selected_date, selected_time)
        else:
            streamlit_app.warning(
                "Date format is not supported: " + str(property.get("format"))
            )

    def _render_single_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        if "is_recorder" in property:
            from audio_recorder_streamlit import audio_recorder
            audio_bytes = audio_recorder()
            if audio_bytes:
                streamlit_app.audio(audio_bytes, format="audio/wav")
            return audio_bytes

        uploaded_file = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=False, type=file_extension
        )
        if uploaded_file is None:
            return None

        bytes = uploaded_file.getvalue()
        if property.get("mime_type"):
            if is_compatible_audio(property["mime_type"]):
                # Show audio
                streamlit_app.audio(bytes, format=property.get("mime_type"))
            if is_compatible_image(property["mime_type"]):
                # Show image
                streamlit_app.image(bytes)
            if is_compatible_video(property["mime_type"]):
                # Show video
                streamlit_app.video(bytes, format=property.get("mime_type"))
        return bytes

    def _render_single_audio_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        from audio_recorder_streamlit import audio_recorder
        audio_bytes = audio_recorder()
        if audio_bytes:
            streamlit_app.audio(audio_bytes, format="audio/wav")
        return audio_bytes

        # file_extension = None
        # if "mime_type" in property:
        #     file_extension = mimetypes.guess_extension(property["mime_type"])

        # uploaded_file = streamlit_app.file_uploader(
        #     **streamlit_kwargs, accept_multiple_files=False, type=file_extension
        # )
        # if uploaded_file is None:
        #     return None

        # bytes = uploaded_file.getvalue()
        # if property.get("mime_type"):
        #     if is_compatible_audio(property["mime_type"]):
        #         # Show audio
        #         streamlit_app.audio(bytes, format=property.get("mime_type"))
        #     if is_compatible_image(property["mime_type"]):
        #         # Show image
        #         streamlit_app.image(bytes)
        #     if is_compatible_video(property["mime_type"]):
        #         # Show video
        #         streamlit_app.video(bytes, format=property.get("mime_type"))
        # return bytes

    def _render_single_string_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        elif property.get("example"):
            # TODO: also use example for other property types
            # Use example as value if it is provided
            streamlit_kwargs["value"] = property.get("example")

        if property.get("maxLength") is not None:
            streamlit_kwargs["max_chars"] = property.get("maxLength")

        if (
            property.get("format")
            or (
                property.get("maxLength") is not None
                and int(property.get("maxLength")) < 140  # type: ignore
            )
            or property.get("writeOnly")
        ):
            # If any format is set, use single text input
            # If max chars is set to less than 140, use single text input
            # If write only -> password field
            if property.get("writeOnly"):
                streamlit_kwargs["type"] = "password"
            return streamlit_app.text_input(**streamlit_kwargs)
        else:
            # Otherwise use multiline text area
            return streamlit_app.text_area(**streamlit_kwargs)

    def _render_multi_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        # TODO: how to select defaults
        return streamlit_app.multiselect(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )

        if property.get("default") is not None:
            try:
                streamlit_kwargs["index"] = reference_item["enum"].index(
                    property.get("default")
                )
            except Exception:
                # Use default selection
                pass

        return streamlit_app.selectbox(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_dict_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_dict = self._get_value(key)
        if not current_dict:
            current_dict = {}

        key_col, value_col = streamlit_app.columns(2)

        with key_col:
            updated_key = streamlit_app.text_input(
                "Key", value="", key=key + "-new-key"
            )

        with value_col:
            # TODO: also add boolean?
            value_kwargs = {"label": "Value", "key": key + "-new-value"}
            if property["additionalProperties"].get("type") == "integer":
                value_kwargs["value"] = 0  # type: ignore
                updated_value = streamlit_app.number_input(**value_kwargs)
            elif property["additionalProperties"].get("type") == "number":
                value_kwargs["value"] = 0.0  # type: ignore
                value_kwargs["format"] = "%f"
                updated_value = streamlit_app.number_input(**value_kwargs)
            else:
                value_kwargs["value"] = ""
                updated_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_dict = {}

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and updated_key
                ):
                    current_dict[updated_key] = updated_value

        streamlit_app.write(current_dict)

        return current_dict

    def _render_single_reference(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_property(streamlit_app, key, reference_item)

    def _render_multi_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_files = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=True, type=file_extension
        )
        uploaded_files_bytes = []
        if uploaded_files:
            for uploaded_file in uploaded_files:
                uploaded_files_bytes.append(uploaded_file.read())
        return uploaded_files_bytes

    def _render_single_boolean_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        return streamlit_app.checkbox(**streamlit_kwargs)

    def _render_single_number_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        number_transform = int
        if property.get("type") == "number":
            number_transform = float  # type: ignore
            streamlit_kwargs["format"] = "%f"

        if "multipleOf" in property:
            # Set stepcount based on multiple of parameter
            streamlit_kwargs["step"] = number_transform(property["multipleOf"])
        elif number_transform == int:
            # Set step size to 1 as default
            streamlit_kwargs["step"] = 1
        elif number_transform == float:
            # Set step size to 0.01 as default
            # TODO: adapt to default value
            streamlit_kwargs["step"] = 0.01

        if "minimum" in property:
            streamlit_kwargs["min_value"] = number_transform(property["minimum"])
        if "exclusiveMinimum" in property:
            streamlit_kwargs["min_value"] = number_transform(
                property["exclusiveMinimum"] + streamlit_kwargs["step"]
            )
        if "maximum" in property:
            streamlit_kwargs["max_value"] = number_transform(property["maximum"])

        if "exclusiveMaximum" in property:
            streamlit_kwargs["max_value"] = number_transform(
                property["exclusiveMaximum"] - streamlit_kwargs["step"]
            )

        if property.get("default") is not None:
            streamlit_kwargs["value"] = number_transform(property.get("default"))  # type: ignore
        else:
            if "min_value" in streamlit_kwargs:
                streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
            elif number_transform == int:
                streamlit_kwargs["value"] = 0
            else:
                # Set default value to step
                streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])

        if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
            # TODO: Only if less than X steps
            return streamlit_app.slider(**streamlit_kwargs)
        else:
            return streamlit_app.number_input(**streamlit_kwargs)

    def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
        properties = property["properties"]
        object_inputs = {}
        for property_key in properties:
            property = properties[property_key]
            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)
            # construct full key based on key parts -> required later to get the value
            full_key = key + "." + property_key
            object_inputs[property_key] = self._render_property(
                streamlit_app, full_key, property
            )
        return object_inputs

    def _render_single_object_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # Add title and subheader
        title = property.get("title")
        streamlit_app.subheader(title)
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        object_reference = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_object_input(streamlit_app, key, object_reference)

    def _render_property_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        value_kwargs = {"label": "Value", "key": key + "-new-value"}
        if property["items"]["type"] == "integer":
            value_kwargs["value"] = 0  # type: ignore
            new_value = streamlit_app.number_input(**value_kwargs)
        elif property["items"]["type"] == "number":
            value_kwargs["value"] = 0.0  # type: ignore
            value_kwargs["format"] = "%f"
            new_value = streamlit_app.number_input(**value_kwargs)
        else:
            value_kwargs["value"] = ""
            new_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and new_value is not None
                ):
                    current_list.append(new_value)

        streamlit_app.write(current_list)

        return current_list

    def _render_object_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # TODO: support max_items, and min_items properties

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        object_reference = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        input_data = self._render_object_input(streamlit_app, key, object_reference)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and input_data
                ):
                    current_list.append(input_data)

        streamlit_app.write(current_list)
        return current_list

    def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
        if schema_utils.is_single_enum_property(property, self._schema_references):
            return self._render_single_enum_input(streamlit_app, key, property)

        if schema_utils.is_multi_enum_property(property, self._schema_references):
            return self._render_multi_enum_input(streamlit_app, key, property)

        if schema_utils.is_single_file_property(property):
            return self._render_single_file_input(streamlit_app, key, property)

        if schema_utils.is_multi_file_property(property):
            return self._render_multi_file_input(streamlit_app, key, property)

        if schema_utils.is_single_datetime_property(property):
            return self._render_single_datetime_input(streamlit_app, key, property)

        if schema_utils.is_single_boolean_property(property):
            return self._render_single_boolean_input(streamlit_app, key, property)

        if schema_utils.is_single_dict_property(property):
            return self._render_single_dict_input(streamlit_app, key, property)

        if schema_utils.is_single_number_property(property):
            return self._render_single_number_input(streamlit_app, key, property)

        if schema_utils.is_single_string_property(property):
            return self._render_single_string_input(streamlit_app, key, property)

        if schema_utils.is_single_object(property, self._schema_references):
            return self._render_single_object_input(streamlit_app, key, property)

        if schema_utils.is_object_list_property(property, self._schema_references):
            return self._render_object_list_input(streamlit_app, key, property)

        if schema_utils.is_property_list(property):
            return self._render_property_list_input(streamlit_app, key, property)

        if schema_utils.is_single_reference(property):
            return self._render_single_reference(streamlit_app, key, property)

        streamlit_app.warning(
            "The type of the following property is currently not supported: "
            + str(property.get("title"))
        )
        raise Exception("Unsupported property")


class OutputUI:
    def __init__(self, output_data: Any, input_data: Any):
        self._output_data = output_data
        self._input_data = input_data

    def render_ui(self, streamlit_app) -> None:
        try:
            if isinstance(self._output_data, BaseModel):
                self._render_single_output(streamlit_app, self._output_data)
                return
            if type(self._output_data) == list:
                self._render_list_output(streamlit_app, self._output_data)
                return
        except Exception as ex:
            streamlit_app.exception(ex)
            # Fallback to JSON rendering
            streamlit_app.json(jsonable_encoder(self._output_data))

    def _render_single_text_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            streamlit.code(str(value), language="plain")

    def _render_single_file_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            # TODO: Detect if it is a FileContent instance
            # TODO: detect if it is base64
            file_extension = ""
            if "mime_type" in property_schema:
                mime_type = property_schema["mime_type"]
                file_extension = mimetypes.guess_extension(mime_type) or ""

                if is_compatible_audio(mime_type):
                    streamlit.audio(value.as_bytes(), format=mime_type)
                    return

                if is_compatible_image(mime_type):
                    streamlit.image(value.as_bytes())
                    return

                if is_compatible_video(mime_type):
                    streamlit.video(value.as_bytes(), format=mime_type)
                    return

            filename = (
                (property_schema["title"] + file_extension)
                .lower()
                .strip()
                .replace(" ", "-")
            )
            streamlit.markdown(
                f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
                unsafe_allow_html=True,
            )

    def _render_single_complex_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))

        streamlit.json(jsonable_encoder(value))

    def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
        try:
            if has_output_ui_renderer(output_data):
                if function_has_named_arg(output_data.render_output_ui, "input"):  # type: ignore
                    # render method also requests the input data
                    output_data.render_output_ui(streamlit, input=self._input_data)  # type: ignore
                else:
                    output_data.render_output_ui(streamlit)  # type: ignore
                return
        except Exception:
            # Use default auto-generation methods if the custom rendering throws an exception
            logger.exception(
                "Failed to execute custom render_output_ui function. Using auto-generation instead"
            )

        model_schema = output_data.schema(by_alias=False)
        model_properties = model_schema.get("properties")
        definitions = model_schema.get("definitions")

        if model_properties:
            for property_key in output_data.__dict__:
                property_schema = model_properties.get(property_key)
                if not property_schema.get("title"):
                    # Set property key as fallback title
                    property_schema["title"] = property_key

                output_property_value = output_data.__dict__[property_key]

                if has_output_ui_renderer(output_property_value):
                    output_property_value.render_output_ui(streamlit)  # type: ignore
                    continue

                if isinstance(output_property_value, BaseModel):
                    # Render output recursively
                    streamlit.subheader(property_schema.get("title"))
                    if property_schema.get("description"):
                        streamlit.markdown(property_schema.get("description"))
                    self._render_single_output(streamlit, output_property_value)
                    continue

                if property_schema:
                    if schema_utils.is_single_file_property(property_schema):
                        self._render_single_file_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue

                    if (
                        schema_utils.is_single_string_property(property_schema)
                        or schema_utils.is_single_number_property(property_schema)
                        or schema_utils.is_single_datetime_property(property_schema)
                        or schema_utils.is_single_boolean_property(property_schema)
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue
                    if definitions and schema_utils.is_single_enum_property(
                        property_schema, definitions
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value.value
                        )
                        continue

                # TODO: render dict as table

                self._render_single_complex_property(
                    streamlit, property_schema, output_property_value
                )
        return

    def _render_list_output(self, streamlit: st, output_data: List) -> None:
        try:
            data_items: List = []
            for data_item in output_data:
                if has_output_ui_renderer(data_item):
                    # Render using the render function
                    data_item.render_output_ui(streamlit)  # type: ignore
                    continue
                data_items.append(data_item.dict())
            # Try to show as dataframe
            streamlit.table(pd.DataFrame(data_items))
        except Exception:
            # Fallback to JSON
            streamlit.json(jsonable_encoder(output_data))


def getOpyrator(mode: str) -> Opyrator:
    if mode == None or mode.startswith('VC'):
        from control.mkgui.app_vc import convert
        return Opyrator(convert)
    if mode == None or mode.startswith('预处理'):
        from control.mkgui.preprocess import preprocess
        return Opyrator(preprocess)
    if mode == None or mode.startswith('模型训练'):
        from control.mkgui.train import train
        return Opyrator(train)
    if mode == None or mode.startswith('模型训练(VC)'):
        from control.mkgui.train_vc import train_vc
        return Opyrator(train_vc)
    from control.mkgui.app import synthesize
    return Opyrator(synthesize)

def render_streamlit_ui() -> None:
    # init
    session_state = st.session_state
    session_state.input_data = {}
    # Add custom css settings
    st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)

    with st.spinner("Loading MockingBird GUI. Please wait..."):
        session_state.mode = st.sidebar.selectbox(
            '模式选择',
            ("AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
        )
        if "mode" in session_state:
            mode = session_state.mode
        else:
            mode = ""
        opyrator = getOpyrator(mode)
        title = opyrator.name + mode

        col1, col2, _ = st.columns(3)
        col2.title(title)
        col2.markdown("欢迎使用MockingBird Web 2")

        image = Image.open(path.join('control', 'mkgui', 'static', 'mb.png'))
        col1.image(image)

        st.markdown("---")
        left, right = st.columns([0.4, 0.6])

    with left:
        st.header("Control 控制")
        # if session_state.mode in ["AI拟音", "VC拟音"] :
        #     from audiorecorder import audiorecorder
        #     audio = audiorecorder("Click to record", "Recording...")
        #     if len(audio) > 0:
        #         # To play audio in frontend:
        #         st.audio(audio.tobytes())

        InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
        execute_selected = st.button(opyrator.action)
        if execute_selected:
            with st.spinner("Executing operation. Please wait..."):
                try:
                    input_data_obj = parse_obj_as(
                        opyrator.input_type, session_state.input_data
                    )
                    session_state.output_data = opyrator(input=input_data_obj)
                    session_state.latest_operation_input = input_data_obj  # should this really be saved as additional session object?
                except ValidationError as ex:
                    st.error(ex)
                else:
                    # st.success("Operation executed successfully.")
                    pass

    with right:
        st.header("Result 结果")
        if 'output_data' in session_state:
            OutputUI(
                session_state.output_data, session_state.latest_operation_input
            ).render_ui(st)
            if st.button("Clear"):
                # Clear all state
                for key in st.session_state.keys():
                    del st.session_state[key]
                session_state.input_data = {}
                st.experimental_rerun()
        else:
            # placeholder
            st.caption("请使用左侧控制板进行输入并运行获得结果")
control/mkgui/base/ui/streamlit_utils.py
ADDED
@@ -0,0 +1,13 @@
CUSTOM_STREAMLIT_CSS = """
div[data-testid="stBlock"] button {
    width: 100% !important;
    margin-bottom: 20px !important;
    border-color: #bfbfbf !important;
}
section[data-testid="stSidebar"] div {
    max-width: 10rem;
}
pre code {
    white-space: pre-wrap;
}
"""
control/mkgui/preprocess.py
ADDED
@@ -0,0 +1,96 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple


# Constants
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="目标模型",
    )
    dataset: Dataset = Field(
        Dataset.AIDATATANG_200ZH, title="数据集选择",
    )
    datasets_root: str = Field(
        ..., alias="数据集根目录", description="输入数据集根目录（相对/绝对）",
        format=True,
        example="..\\trainning_data\\"
    )
    output_root: str = Field(
        ..., alias="输出根目录", description="输出结果根目录（相对/绝对）",
        format=True,
        example="..\\trainning_data\\"
    )
    n_processes: int = Field(
        2, alias="处理线程数", description="根据CPU线程数来设置",
        le=32, ge=1
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        sr, count = self.__root__
        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")

def preprocess(input: Input) -> Output:
    """Preprocess(预处理)"""
    finished = 0
    if input.model == Model.VC_PPG2MEL:
        from models.ppg2mel.preprocess import preprocess_dataset
        finished = preprocess_dataset(
            datasets_root=Path(input.datasets_root),
            dataset=input.dataset,
            out_dir=Path(input.output_root),
            n_processes=input.n_processes,
            ppg_encoder_model_fpath=Path(input.extractor.value),
            speaker_encoder_model=Path(input.encoder.value)
        )
    # TODO: pass useful return code
    return Output(__root__=(input.dataset, finished))
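
The `Field(alias=..., description=...)` metadata above is what the auto-generated form renders as labels and tooltips, and the `extractors`/`encoders` Enums are built from whatever `*.pt` checkpoints exist under `data/ckpt/`. A hedged sketch of calling the operation without the GUI; paths and alias spellings follow the reconstruction above and the module only imports if both checkpoint folders contain at least one `*.pt` file:

from pydantic import parse_obj_as

from control.mkgui.preprocess import Input, preprocess, extractors, encoders

payload = {
    "数据集根目录": "../trainning_data/",  # alias of datasets_root (illustrative path)
    "输出根目录": "../trainning_data/",    # alias of output_root (illustrative path)
    "处理线程数": 4,                       # alias of n_processes
    "特征提取模型": list(extractors)[0],   # first discovered extractor checkpoint
    "语音编码模型": list(encoders)[0],     # first discovered encoder checkpoint
}
result = preprocess(parse_obj_as(Input, payload))
print(result.__root__)  # (dataset name, number of processed items)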
control/mkgui/static/mb.png
ADDED
control/mkgui/train.py
ADDED
@@ -0,0 +1,106 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train as synt_train

# Constants
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"


# EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
# CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
# ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"

# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    DEFAULT = "default"

class Input(BaseModel):
    model: Model = Field(
        Model.DEFAULT, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录（相对/绝对）,不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )
    input_root: str = Field(
        ..., alias="输入目录", description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
    )
    run_id: str = Field(
        "", alias="新模型名/运行ID", description="使用新ID进行重新训练，否则选择下面的模型进行继续训练",
    )
    synthesizer: synthesizers = Field(
        ..., alias="已有合成模型",
        description="选择语音合成模型文件."
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”，则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”，输出更多详情",
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    save_every: int = Field(
        1000, alias="更新间隔", description="每隔n步则更新一次模型",
    )
    backup_every: int = Field(
        10000, alias="保存间隔", description="每隔n步则保存一次模型",
    )
    log_every: int = Field(
        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        streamlit_app.subheader(f"Training started with code: {self.__root__}")

def train(input: Input) -> Output:
    """Train(训练)"""

    print(">>> Start training ...")
    force_restart = len(input.run_id) > 0
    if not force_restart:
        input.run_id = Path(input.synthesizer.value).name.split('.')[0]

    synt_train(
        input.run_id,
        input.input_root,
        f"data{os.sep}ckpt{os.sep}synthesizer",
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams
    )
    return Output(__root__=0)
control/mkgui/train_vc.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Any, Tuple
|
6 |
+
import numpy as np
|
7 |
+
from utils.hparams import HpsYaml
|
8 |
+
from utils.util import AttrDict
|
9 |
+
import torch
|
10 |
+
|
11 |
+
# Constants
|
12 |
+
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
|
13 |
+
CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
|
14 |
+
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
|
15 |
+
|
16 |
+
|
17 |
+
if os.path.isdir(EXT_MODELS_DIRT):
|
18 |
+
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
19 |
+
print("Loaded extractor models: " + str(len(extractors)))
|
20 |
+
else:
|
21 |
+
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
22 |
+
|
23 |
+
if os.path.isdir(CONV_MODELS_DIRT):
|
24 |
+
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
25 |
+
print("Loaded convertor models: " + str(len(convertors)))
|
26 |
+
else:
|
27 |
+
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
28 |
+
|
29 |
+
if os.path.isdir(ENC_MODELS_DIRT):
|
30 |
+
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
31 |
+
print("Loaded encoders models: " + str(len(encoders)))
|
32 |
+
else:
|
33 |
+
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
34 |
+
|
35 |
+
class Model(str, Enum):
|
36 |
+
VC_PPG2MEL = "ppg2mel"
|
37 |
+
|
38 |
+
class Dataset(str, Enum):
|
39 |
+
AIDATATANG_200ZH = "aidatatang_200zh"
|
40 |
+
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
41 |
+
|
42 |
+
class Input(BaseModel):
|
43 |
+
# def render_input_ui(st, input) -> Dict:
|
44 |
+
# input["selected_dataset"] = st.selectbox(
|
45 |
+
# '้ๆฉๆฐๆฎ้',
|
46 |
+
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
47 |
+
# )
|
48 |
+
# return input
|
49 |
+
model: Model = Field(
|
50 |
+
Model.VC_PPG2MEL, title="ๆจกๅ็ฑปๅ",
|
51 |
+
)
|
52 |
+
# datasets_root: str = Field(
|
53 |
+
# ..., alias="้ขๅค็ๆฐๆฎๆ น็ฎๅฝ", description="่พๅ
ฅ็ฎๅฝ๏ผ็ธๅฏน/็ปๅฏน๏ผ,ไธ้็จไบppg2melๆจกๅ",
|
54 |
+
# format=True,
|
55 |
+
# example="..\\trainning_data\\"
|
56 |
+
# )
|
57 |
+
output_root: str = Field(
|
58 |
+
..., alias="่พๅบ็ฎๅฝ(ๅฏ้)", description="ๅปบ่ฎฎไธๅกซ๏ผไฟๆ้ป่ฎค",
|
59 |
+
format=True,
|
60 |
+
example=""
|
61 |
+
)
|
62 |
+
continue_mode: bool = Field(
|
63 |
+
True, alias="็ปง็ปญ่ฎญ็ปๆจกๅผ", description="้ๆฉโๆฏโ๏ผๅไปไธ้ข้ๆฉ็ๆจกๅไธญ็ปง็ปญ่ฎญ็ป",
|
64 |
+
)
|
65 |
+
gpu: bool = Field(
|
66 |
+
True, alias="GPU่ฎญ็ป", description="้ๆฉโๆฏโ๏ผๅไฝฟ็จGPU่ฎญ็ป",
|
67 |
+
)
|
68 |
+
verbose: bool = Field(
|
69 |
+
True, alias="ๆๅฐ่ฏฆๆ
", description="้ๆฉโๆฏโ๏ผ่พๅบๆดๅค่ฏฆๆ
",
|
70 |
+
)
|
71 |
+
# TODO: Move to hiden fields by default
|
72 |
+
convertor: convertors = Field(
|
73 |
+
..., alias="่ฝฌๆขๆจกๅ",
|
74 |
+
description="้ๆฉ่ฏญ้ณ่ฝฌๆขๆจกๅๆไปถ."
|
75 |
+
)
|
76 |
+
extractor: extractors = Field(
|
77 |
+
..., alias="็นๅพๆๅๆจกๅ",
|
78 |
+
description="้ๆฉPPG็นๅพๆๅๆจกๅๆไปถ."
|
79 |
+
)
|
80 |
+
encoder: encoders = Field(
|
81 |
+
..., alias="่ฏญ้ณ็ผ็ ๆจกๅ",
|
82 |
+
description="้ๆฉ่ฏญ้ณ็ผ็ ๆจกๅๆไปถ."
|
83 |
+
)
|
84 |
+
njobs: int = Field(
|
85 |
+
8, alias="่ฟ็จๆฐ", description="้็จไบppg2mel",
|
86 |
+
)
|
87 |
+
seed: int = Field(
|
88 |
+
default=0, alias="ๅๅง้ๆบๆฐ", description="้็จไบppg2mel",
|
89 |
+
)
|
90 |
+
model_name: str = Field(
|
91 |
+
..., alias="ๆฐๆจกๅๅ", description="ไป
ๅจ้ๆฐ่ฎญ็ปๆถ็ๆ,้ไธญ็ปง็ปญ่ฎญ็ปๆถๆ ๆ",
|
92 |
+
example="test"
|
93 |
+
)
|
94 |
+
model_config: str = Field(
|
95 |
+
..., alias="ๆฐๆจกๅ้
็ฝฎ", description="ไป
ๅจ้ๆฐ่ฎญ็ปๆถ็ๆ,้ไธญ็ปง็ปญ่ฎญ็ปๆถๆ ๆ",
|
96 |
+
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
|
97 |
+
)
|
98 |
+
|
99 |
+
class AudioEntity(BaseModel):
|
100 |
+
content: bytes
|
101 |
+
mel: Any
|
102 |
+
|
103 |
+
class Output(BaseModel):
|
104 |
+
__root__: Tuple[str, int]
|
105 |
+
|
106 |
+
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
107 |
+
"""Custom output UI.
|
108 |
+
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
109 |
+
"""
|
110 |
+
sr, count = self.__root__
|
111 |
+
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
112 |
+
|
113 |
+
def train_vc(input: Input) -> Output:
|
114 |
+
"""Train VC(่ฎญ็ป VC)"""
|
115 |
+
|
116 |
+
print(">>> OneShot VC training ...")
|
117 |
+
params = AttrDict()
|
118 |
+
params.update({
|
119 |
+
"gpu": input.gpu,
|
120 |
+
"cpu": not input.gpu,
|
121 |
+
"njobs": input.njobs,
|
122 |
+
"seed": input.seed,
|
123 |
+
"verbose": input.verbose,
|
124 |
+
"load": input.convertor.value,
|
125 |
+
"warm_start": False,
|
126 |
+
})
|
127 |
+
if input.continue_mode:
|
128 |
+
# trace old model and config
|
129 |
+
p = Path(input.convertor.value)
|
130 |
+
params.name = p.parent.name
|
131 |
+
# search a config file
|
132 |
+
model_config_fpaths = list(p.parent.rglob("*.yaml"))
|
133 |
+
if len(model_config_fpaths) == 0:
|
134 |
+
raise "No model yaml config found for convertor"
|
135 |
+
config = HpsYaml(model_config_fpaths[0])
|
136 |
+
params.ckpdir = p.parent.parent
|
137 |
+
params.config = model_config_fpaths[0]
|
138 |
+
params.logdir = os.path.join(p.parent, "log")
|
139 |
+
else:
|
140 |
+
# Make the config dict dot visitable
|
141 |
+
config = HpsYaml(input.config)
|
142 |
+
np.random.seed(input.seed)
|
143 |
+
torch.manual_seed(input.seed)
|
144 |
+
if torch.cuda.is_available():
|
145 |
+
torch.cuda.manual_seed_all(input.seed)
|
146 |
+
mode = "train"
|
147 |
+
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
148 |
+
solver = Solver(config, params, mode)
|
149 |
+
solver.load_data()
|
150 |
+
solver.set_model()
|
151 |
+
solver.exec()
|
152 |
+
print(">>> Oneshot VC train finished!")
|
153 |
+
|
154 |
+
# TODO: pass useful return code
|
155 |
+
return Output(__root__=(input.dataset, 0))
|
control/toolbox/__init__.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from control.toolbox.ui import UI
|
2 |
+
from models.encoder import inference as encoder
|
3 |
+
from models.synthesizer.inference import Synthesizer
|
4 |
+
from models.vocoder.wavernn import inference as rnn_vocoder
|
5 |
+
from models.vocoder.hifigan import inference as gan_vocoder
|
6 |
+
from models.vocoder.fregan import inference as fgan_vocoder
|
7 |
+
from pathlib import Path
|
8 |
+
from time import perf_counter as timer
|
9 |
+
from control.toolbox.utterance import Utterance
|
10 |
+
import numpy as np
|
11 |
+
import traceback
|
12 |
+
import sys
|
13 |
+
import torch
|
14 |
+
import re
|
15 |
+
|
16 |
+
# ้ป่ฎคไฝฟ็จwavernn
|
17 |
+
vocoder = rnn_vocoder
|
18 |
+
|
19 |
+
# Use this directory structure for your datasets, or modify it to fit your needs
|
20 |
+
recognized_datasets = [
|
21 |
+
"LibriSpeech/dev-clean",
|
22 |
+
"LibriSpeech/dev-other",
|
23 |
+
"LibriSpeech/test-clean",
|
24 |
+
"LibriSpeech/test-other",
|
25 |
+
"LibriSpeech/train-clean-100",
|
26 |
+
"LibriSpeech/train-clean-360",
|
27 |
+
"LibriSpeech/train-other-500",
|
28 |
+
"LibriTTS/dev-clean",
|
29 |
+
"LibriTTS/dev-other",
|
30 |
+
"LibriTTS/test-clean",
|
31 |
+
"LibriTTS/test-other",
|
32 |
+
"LibriTTS/train-clean-100",
|
33 |
+
"LibriTTS/train-clean-360",
|
34 |
+
"LibriTTS/train-other-500",
|
35 |
+
"LJSpeech-1.1",
|
36 |
+
"VoxCeleb1/wav",
|
37 |
+
"VoxCeleb1/test_wav",
|
38 |
+
"VoxCeleb2/dev/aac",
|
39 |
+
"VoxCeleb2/test/aac",
|
40 |
+
"VCTK-Corpus/wav48",
|
41 |
+
"aidatatang_200zh/corpus/test",
|
42 |
+
"aidatatang_200zh/corpus/train",
|
43 |
+
"aishell3/test/wav",
|
44 |
+
"magicdata/train",
|
45 |
+
]
|
46 |
+
|
47 |
+
#Maximum of generated wavs to keep on memory
|
48 |
+
MAX_WAVES = 15
|
49 |
+
|
50 |
+
class Toolbox:
|
51 |
+
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
|
52 |
+
self.no_mp3_support = no_mp3_support
|
53 |
+
self.vc_mode = vc_mode
|
54 |
+
sys.excepthook = self.excepthook
|
55 |
+
self.datasets_root = datasets_root
|
56 |
+
self.utterances = set()
|
57 |
+
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
58 |
+
|
59 |
+
self.synthesizer = None # type: Synthesizer
|
60 |
+
|
61 |
+
# for ppg-based voice conversion
|
62 |
+
self.extractor = None
|
63 |
+
self.convertor = None # ppg2mel
|
64 |
+
|
65 |
+
self.current_wav = None
|
66 |
+
self.waves_list = []
|
67 |
+
self.waves_count = 0
|
68 |
+
self.waves_namelist = []
|
69 |
+
|
70 |
+
# Check for webrtcvad (enables removal of silences in vocoder output)
|
71 |
+
try:
|
72 |
+
import webrtcvad
|
73 |
+
self.trim_silences = True
|
74 |
+
except:
|
75 |
+
self.trim_silences = False
|
76 |
+
|
77 |
+
# Initialize the events and the interface
|
78 |
+
self.ui = UI(vc_mode)
|
79 |
+
self.style_idx = 0
|
80 |
+
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
|
81 |
+
self.setup_events()
|
82 |
+
self.ui.start()
|
83 |
+
|
84 |
+
def excepthook(self, exc_type, exc_value, exc_tb):
|
85 |
+
traceback.print_exception(exc_type, exc_value, exc_tb)
|
86 |
+
self.ui.log("Exception: %s" % exc_value)
|
87 |
+
|
88 |
+
def setup_events(self):
|
89 |
+
# Dataset, speaker and utterance selection
|
90 |
+
self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
|
91 |
+
random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
|
92 |
+
recognized_datasets,
|
93 |
+
level)
|
94 |
+
self.ui.random_dataset_button.clicked.connect(random_func(0))
|
95 |
+
self.ui.random_speaker_button.clicked.connect(random_func(1))
|
96 |
+
self.ui.random_utterance_button.clicked.connect(random_func(2))
|
97 |
+
self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
|
98 |
+
self.ui.speaker_box.currentIndexChanged.connect(random_func(2))
|
99 |
+
|
100 |
+
# Model selection
|
101 |
+
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
102 |
+
def func():
|
103 |
+
self.synthesizer = None
|
104 |
+
if self.vc_mode:
|
105 |
+
self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
|
106 |
+
else:
|
107 |
+
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
108 |
+
|
109 |
+
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
110 |
+
|
111 |
+
# Utterance selection
|
112 |
+
func = lambda: self.load_from_browser(self.ui.browse_file())
|
113 |
+
self.ui.browser_browse_button.clicked.connect(func)
|
114 |
+
func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
|
115 |
+
self.ui.utterance_history.currentIndexChanged.connect(func)
|
116 |
+
func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
|
117 |
+
self.ui.play_button.clicked.connect(func)
|
118 |
+
self.ui.stop_button.clicked.connect(self.ui.stop)
|
119 |
+
self.ui.record_button.clicked.connect(self.record)
|
120 |
+
|
121 |
+
# Source Utterance selection
|
122 |
+
if self.vc_mode:
|
123 |
+
func = lambda: self.load_soruce_button(self.ui.selected_utterance)
|
124 |
+
self.ui.load_soruce_button.clicked.connect(func)
|
125 |
+
|
126 |
+
#Audio
|
127 |
+
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
128 |
+
|
129 |
+
#Wav playback & save
|
130 |
+
func = lambda: self.replay_last_wav()
|
131 |
+
self.ui.replay_wav_button.clicked.connect(func)
|
132 |
+
func = lambda: self.export_current_wave()
|
133 |
+
self.ui.export_wav_button.clicked.connect(func)
|
134 |
+
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
135 |
+
|
136 |
+
# Generation
|
137 |
+
self.ui.vocode_button.clicked.connect(self.vocode)
|
138 |
+
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
139 |
+
|
140 |
+
if self.vc_mode:
|
141 |
+
func = lambda: self.convert() or self.vocode()
|
142 |
+
self.ui.convert_button.clicked.connect(func)
|
143 |
+
else:
|
144 |
+
func = lambda: self.synthesize() or self.vocode()
|
145 |
+
self.ui.generate_button.clicked.connect(func)
|
146 |
+
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
147 |
+
|
148 |
+
# UMAP legend
|
149 |
+
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
150 |
+
|
151 |
+
def set_current_wav(self, index):
|
152 |
+
self.current_wav = self.waves_list[index]
|
153 |
+
|
154 |
+
def export_current_wave(self):
|
155 |
+
self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)
|
156 |
+
|
157 |
+
def replay_last_wav(self):
|
158 |
+
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
159 |
+
|
160 |
+
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
|
161 |
+
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
162 |
+
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
|
163 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
164 |
+
|
165 |
+
def load_from_browser(self, fpath=None):
|
166 |
+
if fpath is None:
|
167 |
+
fpath = Path(self.datasets_root,
|
168 |
+
self.ui.current_dataset_name,
|
169 |
+
self.ui.current_speaker_name,
|
170 |
+
self.ui.current_utterance_name)
|
171 |
+
name = str(fpath.relative_to(self.datasets_root))
|
172 |
+
speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
|
173 |
+
|
174 |
+
# Select the next utterance
|
175 |
+
if self.ui.auto_next_checkbox.isChecked():
|
176 |
+
self.ui.browser_select_next()
|
177 |
+
elif fpath == "":
|
178 |
+
return
|
179 |
+
else:
|
180 |
+
name = fpath.name
|
181 |
+
speaker_name = fpath.parent.name
|
182 |
+
|
183 |
+
if fpath.suffix.lower() == ".mp3" and self.no_mp3_support:
|
184 |
+
self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used")
|
185 |
+
return
|
186 |
+
|
187 |
+
# Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
|
188 |
+
# playback, so as to have a fair comparison with the generated audio
|
189 |
+
wav = Synthesizer.load_preprocess_wav(fpath)
|
190 |
+
self.ui.log("Loaded %s" % name)
|
191 |
+
|
192 |
+
self.add_real_utterance(wav, name, speaker_name)
|
193 |
+
|
194 |
+
def load_soruce_button(self, utterance: Utterance):
|
195 |
+
self.selected_source_utterance = utterance
|
196 |
+
|
197 |
+
def record(self):
|
198 |
+
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
199 |
+
if wav is None:
|
200 |
+
return
|
201 |
+
self.ui.play(wav, encoder.sampling_rate)
|
202 |
+
|
203 |
+
speaker_name = "user01"
|
204 |
+
name = speaker_name + "_rec_%05d" % np.random.randint(100000)
|
205 |
+
self.add_real_utterance(wav, name, speaker_name)
|
206 |
+
|
207 |
+
def add_real_utterance(self, wav, name, speaker_name):
|
208 |
+
# Compute the mel spectrogram
|
209 |
+
spec = Synthesizer.make_spectrogram(wav)
|
210 |
+
self.ui.draw_spec(spec, "current")
|
211 |
+
|
212 |
+
# Compute the embedding
|
213 |
+
if not encoder.is_loaded():
|
214 |
+
self.init_encoder()
|
215 |
+
encoder_wav = encoder.preprocess_wav(wav)
|
216 |
+
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
217 |
+
|
218 |
+
# Add the utterance
|
219 |
+
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
220 |
+
self.utterances.add(utterance)
|
221 |
+
self.ui.register_utterance(utterance, self.vc_mode)
|
222 |
+
|
223 |
+
# Plot it
|
224 |
+
self.ui.draw_embed(embed, name, "current")
|
225 |
+
self.ui.draw_umap_projections(self.utterances)
|
226 |
+
|
227 |
+
def clear_utterances(self):
|
228 |
+
self.utterances.clear()
|
229 |
+
self.ui.draw_umap_projections(self.utterances)
|
230 |
+
|
231 |
+
def synthesize(self):
|
232 |
+
self.ui.log("Generating the mel spectrogram...")
|
233 |
+
self.ui.set_loading(1)
|
234 |
+
|
235 |
+
# Update the synthesizer random seed
|
236 |
+
if self.ui.random_seed_checkbox.isChecked():
|
237 |
+
seed = int(self.ui.seed_textbox.text())
|
238 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
239 |
+
else:
|
240 |
+
seed = None
|
241 |
+
|
242 |
+
if seed is not None:
|
243 |
+
torch.manual_seed(seed)
|
244 |
+
|
245 |
+
# Synthesize the spectrogram
|
246 |
+
if self.synthesizer is None or seed is not None:
|
247 |
+
self.init_synthesizer()
|
248 |
+
|
249 |
+
texts = self.ui.text_prompt.toPlainText().split("\n")
|
250 |
+
punctuation = '๏ผ๏ผใใ,' # punctuate and split/clean text
|
251 |
+
processed_texts = []
|
252 |
+
for text in texts:
|
253 |
+
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
254 |
+
if processed_text:
|
255 |
+
processed_texts.append(processed_text.strip())
|
256 |
+
texts = processed_texts
|
257 |
+
embed = self.ui.selected_utterance.embed
|
258 |
+
embeds = [embed] * len(texts)
|
259 |
+
min_token = int(self.ui.token_slider.value())
|
260 |
+
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
|
261 |
+
breaks = [spec.shape[1] for spec in specs]
|
262 |
+
spec = np.concatenate(specs, axis=1)
|
263 |
+
|
264 |
+
self.ui.draw_spec(spec, "generated")
|
265 |
+
self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
|
266 |
+
self.ui.set_loading(0)
|
267 |
+
|
268 |
+
def vocode(self):
|
269 |
+
speaker_name, spec, breaks, _ = self.current_generated
|
270 |
+
assert spec is not None
|
271 |
+
|
272 |
+
# Initialize the vocoder model and make it determinstic, if user provides a seed
|
273 |
+
if self.ui.random_seed_checkbox.isChecked():
|
274 |
+
seed = int(self.ui.seed_textbox.text())
|
275 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
276 |
+
else:
|
277 |
+
seed = None
|
278 |
+
|
279 |
+
if seed is not None:
|
280 |
+
torch.manual_seed(seed)
|
281 |
+
|
282 |
+
# Synthesize the waveform
|
283 |
+
if not vocoder.is_loaded() or seed is not None:
|
284 |
+
self.init_vocoder()
|
285 |
+
|
286 |
+
def vocoder_progress(i, seq_len, b_size, gen_rate):
|
287 |
+
real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
|
288 |
+
line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
|
289 |
+
% (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
|
290 |
+
self.ui.log(line, "overwrite")
|
291 |
+
self.ui.set_loading(i, seq_len)
|
292 |
+
if self.ui.current_vocoder_fpath is not None:
|
293 |
+
self.ui.log("")
|
294 |
+
wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
295 |
+
else:
|
296 |
+
self.ui.log("Waveform generation with Griffin-Lim... ")
|
297 |
+
wav = Synthesizer.griffin_lim(spec)
|
298 |
+
self.ui.set_loading(0)
|
299 |
+
self.ui.log(" Done!", "append")
|
300 |
+
|
301 |
+
# Add breaks
|
302 |
+
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
303 |
+
b_starts = np.concatenate(([0], b_ends[:-1]))
|
304 |
+
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
305 |
+
breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
|
306 |
+
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
307 |
+
|
308 |
+
# Trim excessive silences
|
309 |
+
if self.ui.trim_silences_checkbox.isChecked():
|
310 |
+
wav = encoder.preprocess_wav(wav)
|
311 |
+
|
312 |
+
# Play it
|
313 |
+
wav = wav / np.abs(wav).max() * 0.97
|
314 |
+
self.ui.play(wav, sample_rate)
|
315 |
+
|
316 |
+
# Name it (history displayed in combobox)
|
317 |
+
# TODO better naming for the combobox items?
|
318 |
+
wav_name = str(self.waves_count + 1)
|
319 |
+
|
320 |
+
#Update waves combobox
|
321 |
+
self.waves_count += 1
|
322 |
+
if self.waves_count > MAX_WAVES:
|
323 |
+
self.waves_list.pop()
|
324 |
+
self.waves_namelist.pop()
|
325 |
+
self.waves_list.insert(0, wav)
|
326 |
+
self.waves_namelist.insert(0, wav_name)
|
327 |
+
|
328 |
+
self.ui.waves_cb.disconnect()
|
329 |
+
self.ui.waves_cb_model.setStringList(self.waves_namelist)
|
330 |
+
self.ui.waves_cb.setCurrentIndex(0)
|
331 |
+
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
332 |
+
|
333 |
+
# Update current wav
|
334 |
+
self.set_current_wav(0)
|
335 |
+
|
336 |
+
#Enable replay and save buttons:
|
337 |
+
self.ui.replay_wav_button.setDisabled(False)
|
338 |
+
self.ui.export_wav_button.setDisabled(False)
|
339 |
+
|
340 |
+
# Compute the embedding
|
341 |
+
# TODO: this is problematic with different sampling rates, gotta fix it
|
342 |
+
if not encoder.is_loaded():
|
343 |
+
self.init_encoder()
|
344 |
+
encoder_wav = encoder.preprocess_wav(wav)
|
345 |
+
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
346 |
+
|
347 |
+
# Add the utterance
|
348 |
+
name = speaker_name + "_gen_%05d" % np.random.randint(100000)
|
349 |
+
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
|
350 |
+
self.utterances.add(utterance)
|
351 |
+
|
352 |
+
# Plot it
|
353 |
+
self.ui.draw_embed(embed, name, "generated")
|
354 |
+
self.ui.draw_umap_projections(self.utterances)
|
355 |
+
|
356 |
+
def convert(self):
|
357 |
+
self.ui.log("Extract PPG and Converting...")
|
358 |
+
self.ui.set_loading(1)
|
359 |
+
|
360 |
+
# Init
|
361 |
+
if self.convertor is None:
|
362 |
+
self.init_convertor()
|
363 |
+
if self.extractor is None:
|
364 |
+
self.init_extractor()
|
365 |
+
|
366 |
+
src_wav = self.selected_source_utterance.wav
|
367 |
+
|
368 |
+
# Compute the ppg
|
369 |
+
if not self.extractor is None:
|
370 |
+
ppg = self.extractor.extract_from_wav(src_wav)
|
371 |
+
|
372 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
373 |
+
ref_wav = self.ui.selected_utterance.wav
|
374 |
+
# Import necessary dependency of Voice Conversion
|
375 |
+
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
376 |
+
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
377 |
+
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
378 |
+
min_len = min(ppg.shape[1], len(lf0_uv))
|
379 |
+
ppg = ppg[:, :min_len]
|
380 |
+
lf0_uv = lf0_uv[:min_len]
|
381 |
+
_, mel_pred, att_ws = self.convertor.inference(
|
382 |
+
ppg,
|
383 |
+
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
384 |
+
spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
|
385 |
+
)
|
386 |
+
mel_pred= mel_pred.transpose(0, 1)
|
387 |
+
breaks = [mel_pred.shape[1]]
|
388 |
+
mel_pred= mel_pred.detach().cpu().numpy()
|
389 |
+
self.ui.draw_spec(mel_pred, "generated")
|
390 |
+
self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
|
391 |
+
self.ui.set_loading(0)
|
392 |
+
|
393 |
+
def init_extractor(self):
|
394 |
+
if self.ui.current_extractor_fpath is None:
|
395 |
+
return
|
396 |
+
model_fpath = self.ui.current_extractor_fpath
|
397 |
+
self.ui.log("Loading the extractor %s... " % model_fpath)
|
398 |
+
self.ui.set_loading(1)
|
399 |
+
start = timer()
|
400 |
+
import models.ppg_extractor as extractor
|
401 |
+
self.extractor = extractor.load_model(model_fpath)
|
402 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
403 |
+
self.ui.set_loading(0)
|
404 |
+
|
405 |
+
def init_convertor(self):
|
406 |
+
if self.ui.current_convertor_fpath is None:
|
407 |
+
return
|
408 |
+
model_fpath = self.ui.current_convertor_fpath
|
409 |
+
self.ui.log("Loading the convertor %s... " % model_fpath)
|
410 |
+
self.ui.set_loading(1)
|
411 |
+
start = timer()
|
412 |
+
import models.ppg2mel as convertor
|
413 |
+
self.convertor = convertor.load_model( model_fpath)
|
414 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
415 |
+
self.ui.set_loading(0)
|
416 |
+
|
417 |
+
def init_encoder(self):
|
418 |
+
model_fpath = self.ui.current_encoder_fpath
|
419 |
+
|
420 |
+
self.ui.log("Loading the encoder %s... " % model_fpath)
|
421 |
+
self.ui.set_loading(1)
|
422 |
+
start = timer()
|
423 |
+
encoder.load_model(model_fpath)
|
424 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
425 |
+
self.ui.set_loading(0)
|
426 |
+
|
427 |
+
def init_synthesizer(self):
|
428 |
+
model_fpath = self.ui.current_synthesizer_fpath
|
429 |
+
|
430 |
+
self.ui.log("Loading the synthesizer %s... " % model_fpath)
|
431 |
+
self.ui.set_loading(1)
|
432 |
+
start = timer()
|
433 |
+
self.synthesizer = Synthesizer(model_fpath)
|
434 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
435 |
+
self.ui.set_loading(0)
|
436 |
+
|
437 |
+
def init_vocoder(self):
|
438 |
+
|
439 |
+
global vocoder
|
440 |
+
model_fpath = self.ui.current_vocoder_fpath
|
441 |
+
# Case of Griffin-lim
|
442 |
+
if model_fpath is None:
|
443 |
+
return
|
444 |
+
# Sekect vocoder based on model name
|
445 |
+
model_config_fpath = None
|
446 |
+
if model_fpath.name is not None and model_fpath.name.find("hifigan") > -1:
|
447 |
+
vocoder = gan_vocoder
|
448 |
+
self.ui.log("set hifigan as vocoder")
|
449 |
+
# search a config file
|
450 |
+
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
|
451 |
+
if self.vc_mode and self.ui.current_extractor_fpath is None:
|
452 |
+
return
|
453 |
+
if len(model_config_fpaths) > 0:
|
454 |
+
model_config_fpath = model_config_fpaths[0]
|
455 |
+
elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1:
|
456 |
+
vocoder = fgan_vocoder
|
457 |
+
self.ui.log("set fregan as vocoder")
|
458 |
+
# search a config file
|
459 |
+
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
|
460 |
+
if self.vc_mode and self.ui.current_extractor_fpath is None:
|
461 |
+
return
|
462 |
+
if len(model_config_fpaths) > 0:
|
463 |
+
model_config_fpath = model_config_fpaths[0]
|
464 |
+
else:
|
465 |
+
vocoder = rnn_vocoder
|
466 |
+
self.ui.log("set wavernn as vocoder")
|
467 |
+
|
468 |
+
self.ui.log("Loading the vocoder %s... " % model_fpath)
|
469 |
+
self.ui.set_loading(1)
|
470 |
+
start = timer()
|
471 |
+
vocoder.load_model(model_fpath, model_config_fpath)
|
472 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
473 |
+
self.ui.set_loading(0)
|
474 |
+
|
475 |
+
def update_seed_textbox(self):
|
476 |
+
self.ui.update_seed_textbox()
|
control/toolbox/assets/mb.png
ADDED
control/toolbox/ui.py
ADDED
@@ -0,0 +1,700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PyQt5.QtCore import Qt, QStringListModel
|
2 |
+
from PyQt5 import QtGui
|
3 |
+
from PyQt5.QtWidgets import *
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
6 |
+
from models.encoder.inference import plot_embedding_as_heatmap
|
7 |
+
from control.toolbox.utterance import Utterance
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import List, Set
|
10 |
+
import sounddevice as sd
|
11 |
+
import soundfile as sf
|
12 |
+
import numpy as np
|
13 |
+
# from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP
|
14 |
+
from time import sleep
|
15 |
+
import umap
|
16 |
+
import sys
|
17 |
+
from warnings import filterwarnings, warn
|
18 |
+
filterwarnings("ignore")
|
19 |
+
|
20 |
+
|
21 |
+
colormap = np.array([
|
22 |
+
[0, 127, 70],
|
23 |
+
[255, 0, 0],
|
24 |
+
[255, 217, 38],
|
25 |
+
[0, 135, 255],
|
26 |
+
[165, 0, 165],
|
27 |
+
[255, 167, 255],
|
28 |
+
[97, 142, 151],
|
29 |
+
[0, 255, 255],
|
30 |
+
[255, 96, 38],
|
31 |
+
[142, 76, 0],
|
32 |
+
[33, 0, 127],
|
33 |
+
[0, 0, 0],
|
34 |
+
[183, 183, 183],
|
35 |
+
[76, 255, 0],
|
36 |
+
], dtype=float) / 255
|
37 |
+
|
38 |
+
default_text = \
|
39 |
+
"ๆฌข่ฟไฝฟ็จๅทฅๅ
ท็ฎฑ, ็ฐๅทฒๆฏๆไธญๆ่พๅ
ฅ๏ผ"
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
class UI(QDialog):
|
44 |
+
min_umap_points = 4
|
45 |
+
max_log_lines = 5
|
46 |
+
max_saved_utterances = 20
|
47 |
+
|
48 |
+
def draw_utterance(self, utterance: Utterance, which):
|
49 |
+
self.draw_spec(utterance.spec, which)
|
50 |
+
self.draw_embed(utterance.embed, utterance.name, which)
|
51 |
+
|
52 |
+
def draw_embed(self, embed, name, which):
|
53 |
+
embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
|
54 |
+
embed_ax.figure.suptitle("" if embed is None else name)
|
55 |
+
|
56 |
+
## Embedding
|
57 |
+
# Clear the plot
|
58 |
+
if len(embed_ax.images) > 0:
|
59 |
+
embed_ax.images[0].colorbar.remove()
|
60 |
+
embed_ax.clear()
|
61 |
+
|
62 |
+
# Draw the embed
|
63 |
+
if embed is not None:
|
64 |
+
plot_embedding_as_heatmap(embed, embed_ax)
|
65 |
+
embed_ax.set_title("embedding")
|
66 |
+
embed_ax.set_aspect("equal", "datalim")
|
67 |
+
embed_ax.set_xticks([])
|
68 |
+
embed_ax.set_yticks([])
|
69 |
+
embed_ax.figure.canvas.draw()
|
70 |
+
|
71 |
+
def draw_spec(self, spec, which):
|
72 |
+
_, spec_ax = self.current_ax if which == "current" else self.gen_ax
|
73 |
+
|
74 |
+
## Spectrogram
|
75 |
+
# Draw the spectrogram
|
76 |
+
spec_ax.clear()
|
77 |
+
if spec is not None:
|
78 |
+
im = spec_ax.imshow(spec, aspect="auto", interpolation="none")
|
79 |
+
# spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal",
|
80 |
+
# spec_ax=spec_ax)
|
81 |
+
spec_ax.set_title("mel spectrogram")
|
82 |
+
|
83 |
+
spec_ax.set_xticks([])
|
84 |
+
spec_ax.set_yticks([])
|
85 |
+
spec_ax.figure.canvas.draw()
|
86 |
+
if which != "current":
|
87 |
+
self.vocode_button.setDisabled(spec is None)
|
88 |
+
|
89 |
+
def draw_umap_projections(self, utterances: Set[Utterance]):
|
90 |
+
self.umap_ax.clear()
|
91 |
+
|
92 |
+
speakers = np.unique([u.speaker_name for u in utterances])
|
93 |
+
colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
|
94 |
+
embeds = [u.embed for u in utterances]
|
95 |
+
|
96 |
+
# Display a message if there aren't enough points
|
97 |
+
if len(utterances) < self.min_umap_points:
|
98 |
+
self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
|
99 |
+
(self.min_umap_points - len(utterances)),
|
100 |
+
horizontalalignment='center', fontsize=15)
|
101 |
+
self.umap_ax.set_title("")
|
102 |
+
|
103 |
+
# Compute the projections
|
104 |
+
else:
|
105 |
+
if not self.umap_hot:
|
106 |
+
self.log(
|
107 |
+
"Drawing UMAP projections for the first time, this will take a few seconds.")
|
108 |
+
self.umap_hot = True
|
109 |
+
|
110 |
+
reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
|
111 |
+
# reducer = TSNE()
|
112 |
+
projections = reducer.fit_transform(embeds)
|
113 |
+
|
114 |
+
speakers_done = set()
|
115 |
+
for projection, utterance in zip(projections, utterances):
|
116 |
+
color = colors[utterance.speaker_name]
|
117 |
+
mark = "x" if "_gen_" in utterance.name else "o"
|
118 |
+
label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
|
119 |
+
speakers_done.add(utterance.speaker_name)
|
120 |
+
self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
|
121 |
+
label=label)
|
122 |
+
# self.umap_ax.set_title("UMAP projections")
|
123 |
+
self.umap_ax.legend(prop={'size': 10})
|
124 |
+
|
125 |
+
# Draw the plot
|
126 |
+
self.umap_ax.set_aspect("equal", "datalim")
|
127 |
+
self.umap_ax.set_xticks([])
|
128 |
+
self.umap_ax.set_yticks([])
|
129 |
+
self.umap_ax.figure.canvas.draw()
|
130 |
+
|
131 |
+
def save_audio_file(self, wav, sample_rate):
|
132 |
+
dialog = QFileDialog()
|
133 |
+
dialog.setDefaultSuffix(".wav")
|
134 |
+
fpath, _ = dialog.getSaveFileName(
|
135 |
+
parent=self,
|
136 |
+
caption="Select a path to save the audio file",
|
137 |
+
filter="Audio Files (*.flac *.wav)"
|
138 |
+
)
|
139 |
+
if fpath:
|
140 |
+
#Default format is wav
|
141 |
+
if Path(fpath).suffix == "":
|
142 |
+
fpath += ".wav"
|
143 |
+
sf.write(fpath, wav, sample_rate)
|
144 |
+
|
145 |
+
def setup_audio_devices(self, sample_rate):
|
146 |
+
input_devices = []
|
147 |
+
output_devices = []
|
148 |
+
for device in sd.query_devices():
|
149 |
+
# Check if valid input
|
150 |
+
try:
|
151 |
+
sd.check_input_settings(device=device["name"], samplerate=sample_rate)
|
152 |
+
input_devices.append(device["name"])
|
153 |
+
except:
|
154 |
+
pass
|
155 |
+
|
156 |
+
# Check if valid output
|
157 |
+
try:
|
158 |
+
sd.check_output_settings(device=device["name"], samplerate=sample_rate)
|
159 |
+
output_devices.append(device["name"])
|
160 |
+
except Exception as e:
|
161 |
+
# Log a warning only if the device is not an input
|
162 |
+
if not device["name"] in input_devices:
|
163 |
+
warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))
|
164 |
+
|
165 |
+
if len(input_devices) == 0:
|
166 |
+
self.log("No audio input device detected. Recording may not work.")
|
167 |
+
self.audio_in_device = None
|
168 |
+
else:
|
169 |
+
self.audio_in_device = input_devices[0]
|
170 |
+
|
171 |
+
if len(output_devices) == 0:
|
172 |
+
self.log("No supported output audio devices were found! Audio output may not work.")
|
173 |
+
self.audio_out_devices_cb.addItems(["None"])
|
174 |
+
self.audio_out_devices_cb.setDisabled(True)
|
175 |
+
else:
|
176 |
+
self.audio_out_devices_cb.clear()
|
177 |
+
self.audio_out_devices_cb.addItems(output_devices)
|
178 |
+
self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)
|
179 |
+
|
180 |
+
self.set_audio_device()
|
181 |
+
|
182 |
+
def set_audio_device(self):
|
183 |
+
|
184 |
+
output_device = self.audio_out_devices_cb.currentText()
|
185 |
+
if output_device == "None":
|
186 |
+
output_device = None
|
187 |
+
|
188 |
+
# If None, sounddevice queries portaudio
|
189 |
+
sd.default.device = (self.audio_in_device, output_device)
|
190 |
+
|
191 |
+
def play(self, wav, sample_rate):
|
192 |
+
try:
|
193 |
+
sd.stop()
|
194 |
+
sd.play(wav, sample_rate)
|
195 |
+
except Exception as e:
|
196 |
+
print(e)
|
197 |
+
self.log("Error in audio playback. Try selecting a different audio output device.")
|
198 |
+
self.log("Your device must be connected before you start the toolbox.")
|
199 |
+
|
200 |
+
def stop(self):
|
201 |
+
sd.stop()
|
202 |
+
|
203 |
+
def record_one(self, sample_rate, duration):
|
204 |
+
self.record_button.setText("Recording...")
|
205 |
+
self.record_button.setDisabled(True)
|
206 |
+
|
207 |
+
self.log("Recording %d seconds of audio" % duration)
|
208 |
+
sd.stop()
|
209 |
+
try:
|
210 |
+
wav = sd.rec(duration * sample_rate, sample_rate, 1)
|
211 |
+
except Exception as e:
|
212 |
+
print(e)
|
213 |
+
self.log("Could not record anything. Is your recording device enabled?")
|
214 |
+
self.log("Your device must be connected before you start the toolbox.")
|
215 |
+
return None
|
216 |
+
|
217 |
+
for i in np.arange(0, duration, 0.1):
|
218 |
+
self.set_loading(i, duration)
|
219 |
+
sleep(0.1)
|
220 |
+
self.set_loading(duration, duration)
|
221 |
+
sd.wait()
|
222 |
+
|
223 |
+
self.log("Done recording.")
|
224 |
+
self.record_button.setText("Record")
|
225 |
+
self.record_button.setDisabled(False)
|
226 |
+
|
227 |
+
return wav.squeeze()
|
228 |
+
|
229 |
+
@property
|
230 |
+
def current_dataset_name(self):
|
231 |
+
return self.dataset_box.currentText()
|
232 |
+
|
233 |
+
@property
|
234 |
+
def current_speaker_name(self):
|
235 |
+
return self.speaker_box.currentText()
|
236 |
+
|
237 |
+
@property
|
238 |
+
def current_utterance_name(self):
|
239 |
+
return self.utterance_box.currentText()
|
240 |
+
|
241 |
+
def browse_file(self):
|
242 |
+
fpath = QFileDialog().getOpenFileName(
|
243 |
+
parent=self,
|
244 |
+
caption="Select an audio file",
|
245 |
+
filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
|
246 |
+
)
|
247 |
+
return Path(fpath[0]) if fpath[0] != "" else ""
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def repopulate_box(box, items, random=False):
|
251 |
+
"""
|
252 |
+
Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
|
253 |
+
data to the items
|
254 |
+
"""
|
255 |
+
box.blockSignals(True)
|
256 |
+
box.clear()
|
257 |
+
for item in items:
|
258 |
+
item = list(item) if isinstance(item, tuple) else [item]
|
259 |
+
box.addItem(str(item[0]), *item[1:])
|
260 |
+
if len(items) > 0:
|
261 |
+
box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
|
262 |
+
box.setDisabled(len(items) == 0)
|
263 |
+
box.blockSignals(False)
|
264 |
+
|
265 |
+
def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
|
266 |
+
random=True):
|
267 |
+
# Select a random dataset
|
268 |
+
if level <= 0:
|
269 |
+
if datasets_root is not None:
|
270 |
+
datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
|
271 |
+
datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
|
272 |
+
self.browser_load_button.setDisabled(len(datasets) == 0)
|
273 |
+
if datasets_root is None or len(datasets) == 0:
|
274 |
+
msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
|
275 |
+
if datasets_root is None else "o not have any of the recognized datasets" \
|
276 |
+
" in %s \n" \
|
277 |
+
"Please note use 'E:\datasets' as root path " \
|
278 |
+
"instead of 'E:\datasets\aidatatang_200zh\corpus\test' as an example " % datasets_root)
|
279 |
+
self.log(msg)
|
280 |
+
msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
|
281 |
+
"can still use the toolbox by recording samples yourself." % \
|
282 |
+
("\n\t".join(recognized_datasets))
|
283 |
+
print(msg, file=sys.stderr)
|
284 |
+
|
285 |
+
self.random_utterance_button.setDisabled(True)
|
286 |
+
self.random_speaker_button.setDisabled(True)
|
287 |
+
self.random_dataset_button.setDisabled(True)
|
288 |
+
self.utterance_box.setDisabled(True)
|
289 |
+
self.speaker_box.setDisabled(True)
|
290 |
+
self.dataset_box.setDisabled(True)
|
291 |
+
self.browser_load_button.setDisabled(True)
|
292 |
+
self.auto_next_checkbox.setDisabled(True)
|
293 |
+
return
|
294 |
+
self.repopulate_box(self.dataset_box, datasets, random)
|
295 |
+
|
296 |
+
# Select a random speaker
|
297 |
+
if level <= 1:
|
298 |
+
speakers_root = datasets_root.joinpath(self.current_dataset_name)
|
299 |
+
speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
|
300 |
+
self.repopulate_box(self.speaker_box, speaker_names, random)
|
301 |
+
|
302 |
+
# Select a random utterance
|
303 |
+
if level <= 2:
|
304 |
+
utterances_root = datasets_root.joinpath(
|
305 |
+
self.current_dataset_name,
|
306 |
+
self.current_speaker_name
|
307 |
+
)
|
308 |
+
utterances = []
|
309 |
+
for extension in ['mp3', 'flac', 'wav', 'm4a']:
|
310 |
+
utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
|
311 |
+
utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
|
312 |
+
self.repopulate_box(self.utterance_box, utterances, random)
|
313 |
+
|
314 |
+
def browser_select_next(self):
|
315 |
+
index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
|
316 |
+
self.utterance_box.setCurrentIndex(index)
|
317 |
+
|
318 |
+
@property
|
319 |
+
def current_encoder_fpath(self):
|
320 |
+
return self.encoder_box.itemData(self.encoder_box.currentIndex())
|
321 |
+
|
322 |
+
@property
|
323 |
+
def current_synthesizer_fpath(self):
|
324 |
+
return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())
|
325 |
+
|
326 |
+
@property
|
327 |
+
def current_vocoder_fpath(self):
|
328 |
+
return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
|
329 |
+
|
330 |
+
@property
|
331 |
+
def current_extractor_fpath(self):
|
332 |
+
return self.extractor_box.itemData(self.extractor_box.currentIndex())
|
333 |
+
|
334 |
+
@property
|
335 |
+
def current_convertor_fpath(self):
|
336 |
+
return self.convertor_box.itemData(self.convertor_box.currentIndex())
|
337 |
+
|
338 |
+
def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
|
339 |
+
vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool):
|
340 |
+
# Encoder
|
341 |
+
encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
|
342 |
+
if len(encoder_fpaths) == 0:
|
343 |
+
raise Exception("No encoder models found in %s" % encoder_models_dir)
|
344 |
+
self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])
|
345 |
+
|
346 |
+
if vc_mode:
|
347 |
+
# Extractor
|
348 |
+
extractor_fpaths = list(extractor_models_dir.glob("*.pt"))
|
349 |
+
if len(extractor_fpaths) == 0:
|
350 |
+
self.log("No extractor models found in %s" % extractor_fpaths)
|
351 |
+
self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths])
|
352 |
+
|
353 |
+
# Convertor
|
354 |
+
convertor_fpaths = list(convertor_models_dir.glob("*.pth"))
|
355 |
+
if len(convertor_fpaths) == 0:
|
356 |
+
self.log("No convertor models found in %s" % convertor_fpaths)
|
357 |
+
self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths])
|
358 |
+
else:
|
359 |
+
# Synthesizer
|
360 |
+
synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
|
361 |
+
if len(synthesizer_fpaths) == 0:
|
362 |
+
raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
|
363 |
+
self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
|
364 |
+
|
365 |
+
# Vocoder
|
366 |
+
vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
|
367 |
+
vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
|
368 |
+
self.repopulate_box(self.vocoder_box, vocoder_items)
|
369 |
+
|
370 |
+
@property
|
371 |
+
def selected_utterance(self):
|
372 |
+
return self.utterance_history.itemData(self.utterance_history.currentIndex())
|
373 |
+
|
374 |
+
def register_utterance(self, utterance: Utterance, vc_mode):
|
375 |
+
self.utterance_history.blockSignals(True)
|
376 |
+
self.utterance_history.insertItem(0, utterance.name, utterance)
|
377 |
+
self.utterance_history.setCurrentIndex(0)
|
378 |
+
self.utterance_history.blockSignals(False)
|
379 |
+
|
380 |
+
if len(self.utterance_history) > self.max_saved_utterances:
|
381 |
+
self.utterance_history.removeItem(self.max_saved_utterances)
|
382 |
+
|
383 |
+
self.play_button.setDisabled(False)
|
384 |
+
if vc_mode:
|
385 |
+
self.convert_button.setDisabled(False)
|
386 |
+
else:
|
387 |
+
self.generate_button.setDisabled(False)
|
388 |
+
self.synthesize_button.setDisabled(False)
|
389 |
+
|
390 |
+
def log(self, line, mode="newline"):
|
391 |
+
if mode == "newline":
|
392 |
+
self.logs.append(line)
|
393 |
+
if len(self.logs) > self.max_log_lines:
|
394 |
+
del self.logs[0]
|
395 |
+
elif mode == "append":
|
396 |
+
self.logs[-1] += line
|
397 |
+
elif mode == "overwrite":
|
398 |
+
self.logs[-1] = line
|
399 |
+
log_text = '\n'.join(self.logs)
|
400 |
+
|
401 |
+
self.log_window.setText(log_text)
|
402 |
+
self.app.processEvents()
|
403 |
+
|
404 |
+
def set_loading(self, value, maximum=1):
|
405 |
+
self.loading_bar.setValue(int(value * 100))
|
406 |
+
self.loading_bar.setMaximum(int(maximum * 100))
|
407 |
+
self.loading_bar.setTextVisible(value != 0)
|
408 |
+
self.app.processEvents()
|
409 |
+
|
410 |
+
def populate_gen_options(self, seed, trim_silences):
|
411 |
+
if seed is not None:
|
412 |
+
self.random_seed_checkbox.setChecked(True)
|
413 |
+
self.seed_textbox.setText(str(seed))
|
414 |
+
self.seed_textbox.setEnabled(True)
|
415 |
+
else:
|
416 |
+
self.random_seed_checkbox.setChecked(False)
|
417 |
+
self.seed_textbox.setText(str(0))
|
418 |
+
self.seed_textbox.setEnabled(False)
|
419 |
+
|
420 |
+
if not trim_silences:
|
421 |
+
self.trim_silences_checkbox.setChecked(False)
|
422 |
+
self.trim_silences_checkbox.setDisabled(True)
|
423 |
+
|
424 |
+
def update_seed_textbox(self):
|
425 |
+
if self.random_seed_checkbox.isChecked():
|
426 |
+
self.seed_textbox.setEnabled(True)
|
427 |
+
else:
|
428 |
+
self.seed_textbox.setEnabled(False)
|
429 |
+
|
430 |
+
def reset_interface(self, vc_mode):
|
431 |
+
self.draw_embed(None, None, "current")
|
432 |
+
self.draw_embed(None, None, "generated")
|
433 |
+
self.draw_spec(None, "current")
|
434 |
+
self.draw_spec(None, "generated")
|
435 |
+
self.draw_umap_projections(set())
|
436 |
+
self.set_loading(0)
|
437 |
+
self.play_button.setDisabled(True)
|
438 |
+
if vc_mode:
|
439 |
+
self.convert_button.setDisabled(True)
|
440 |
+
else:
|
441 |
+
self.generate_button.setDisabled(True)
|
442 |
+
self.synthesize_button.setDisabled(True)
|
443 |
+
self.vocode_button.setDisabled(True)
|
444 |
+
self.replay_wav_button.setDisabled(True)
|
445 |
+
self.export_wav_button.setDisabled(True)
|
446 |
+
[self.log("") for _ in range(self.max_log_lines)]
|
447 |
+
|
448 |
+
def __init__(self, vc_mode):
|
449 |
+
## Initialize the application
|
450 |
+
self.app = QApplication(sys.argv)
|
451 |
+
super().__init__(None)
|
452 |
+
self.setWindowTitle("MockingBird GUI")
|
453 |
+
self.setWindowIcon(QtGui.QIcon('toolbox\\assets\\mb.png'))
|
454 |
+
self.setWindowFlag(Qt.WindowMinimizeButtonHint, True)
|
455 |
+
self.setWindowFlag(Qt.WindowMaximizeButtonHint, True)
|
456 |
+
|
457 |
+
|
458 |
+
## Main layouts
|
459 |
+
# Root
|
460 |
+
root_layout = QGridLayout()
|
461 |
+
self.setLayout(root_layout)
|
462 |
+
|
463 |
+
# Browser
|
464 |
+
browser_layout = QGridLayout()
|
465 |
+
root_layout.addLayout(browser_layout, 0, 0, 1, 8)
|
466 |
+
|
467 |
+
# Generation
|
468 |
+
gen_layout = QVBoxLayout()
|
469 |
+
root_layout.addLayout(gen_layout, 0, 8)
|
470 |
+
|
471 |
+
# Visualizations
|
472 |
+
vis_layout = QVBoxLayout()
|
473 |
+
root_layout.addLayout(vis_layout, 1, 0, 2, 8)
|
474 |
+
|
475 |
+
# Output
|
476 |
+
output_layout = QGridLayout()
|
477 |
+
vis_layout.addLayout(output_layout, 0)
|
478 |
+
|
479 |
+
# Projections
|
480 |
+
self.projections_layout = QVBoxLayout()
|
481 |
+
root_layout.addLayout(self.projections_layout, 1, 8, 2, 2)
|
482 |
+
|
483 |
+
## Projections
|
484 |
+
# UMap
|
485 |
+
        fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
        fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
        self.projections_layout.addWidget(FigureCanvas(fig))
        self.umap_hot = False
        self.clear_button = QPushButton("Clear")
        self.projections_layout.addWidget(self.clear_button)


        ## Browser
        # Dataset, speaker and utterance selection
        i = 0

        source_groupbox = QGroupBox('Source')
        source_layout = QGridLayout()
        source_groupbox.setLayout(source_layout)
        browser_layout.addWidget(source_groupbox, i, 0, 1, 5)

        self.dataset_box = QComboBox()
        source_layout.addWidget(QLabel("Dataset:"), i, 0)
        source_layout.addWidget(self.dataset_box, i, 1)
        self.random_dataset_button = QPushButton("Random")
        source_layout.addWidget(self.random_dataset_button, i, 2)
        i += 1
        self.speaker_box = QComboBox()
        source_layout.addWidget(QLabel("Speaker:"), i, 0)
        source_layout.addWidget(self.speaker_box, i, 1)
        self.random_speaker_button = QPushButton("Random")
        source_layout.addWidget(self.random_speaker_button, i, 2)
        i += 1
        self.utterance_box = QComboBox()
        source_layout.addWidget(QLabel("Utterance:"), i, 0)
        source_layout.addWidget(self.utterance_box, i, 1)
        self.random_utterance_button = QPushButton("Random")
        source_layout.addWidget(self.random_utterance_button, i, 2)

        i += 1
        source_layout.addWidget(QLabel("<b>Use:</b>"), i, 0)
        self.browser_load_button = QPushButton("Load Above")
        source_layout.addWidget(self.browser_load_button, i, 1, 1, 2)
        self.auto_next_checkbox = QCheckBox("Auto select next")
        self.auto_next_checkbox.setChecked(True)
        source_layout.addWidget(self.auto_next_checkbox, i+1, 1)
        self.browser_browse_button = QPushButton("Browse (local file)")
        source_layout.addWidget(self.browser_browse_button, i, 3)
        self.record_button = QPushButton("Record")
        source_layout.addWidget(self.record_button, i+1, 3)

        i += 2
        # Utterance box
        browser_layout.addWidget(QLabel("<b>Current:</b>"), i, 0)
        self.utterance_history = QComboBox()
        browser_layout.addWidget(self.utterance_history, i, 1)
        self.play_button = QPushButton("Play")
        browser_layout.addWidget(self.play_button, i, 2)
        self.stop_button = QPushButton("Stop")
        browser_layout.addWidget(self.stop_button, i, 3)
        if vc_mode:
            self.load_soruce_button = QPushButton("Select (use as the source voice to convert)")
            browser_layout.addWidget(self.load_soruce_button, i, 4)

        i += 1
        model_groupbox = QGroupBox('Models')
        model_layout = QHBoxLayout()
        model_groupbox.setLayout(model_layout)
        browser_layout.addWidget(model_groupbox, i, 0, 2, 5)

        # Model and audio output selection
        self.encoder_box = QComboBox()
        model_layout.addWidget(QLabel("Encoder:"))
        model_layout.addWidget(self.encoder_box)
        self.synthesizer_box = QComboBox()
        if vc_mode:
            self.extractor_box = QComboBox()
            model_layout.addWidget(QLabel("Extractor:"))
            model_layout.addWidget(self.extractor_box)
            self.convertor_box = QComboBox()
            model_layout.addWidget(QLabel("Convertor:"))
            model_layout.addWidget(self.convertor_box)
        else:
            model_layout.addWidget(QLabel("Synthesizer:"))
            model_layout.addWidget(self.synthesizer_box)
        self.vocoder_box = QComboBox()
        model_layout.addWidget(QLabel("Vocoder:"))
        model_layout.addWidget(self.vocoder_box)

        # Replay & save audio
        i = 0
        output_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
        self.waves_cb = QComboBox()
        self.waves_cb_model = QStringListModel()
        self.waves_cb.setModel(self.waves_cb_model)
        self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
        output_layout.addWidget(self.waves_cb, i, 1)
        self.replay_wav_button = QPushButton("Replay")
        self.replay_wav_button.setToolTip("Replay last generated vocoder output")
        output_layout.addWidget(self.replay_wav_button, i, 2)
        self.export_wav_button = QPushButton("Export")
        self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
        output_layout.addWidget(self.export_wav_button, i, 3)
        self.audio_out_devices_cb = QComboBox()
        i += 1
        output_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 0)
        output_layout.addWidget(self.audio_out_devices_cb, i, 1)

        ## Embed & spectrograms
        vis_layout.addStretch()
        # TODO: add spectrograms for source
        gridspec_kw = {"width_ratios": [1, 4]}
        fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                            gridspec_kw=gridspec_kw)
        fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
        vis_layout.addWidget(FigureCanvas(fig))

        fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                        gridspec_kw=gridspec_kw)
        fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
        vis_layout.addWidget(FigureCanvas(fig))

        for ax in self.current_ax.tolist() + self.gen_ax.tolist():
            ax.set_facecolor("#F0F0F0")
            for side in ["top", "right", "bottom", "left"]:
                ax.spines[side].set_visible(False)

        ## Generation
        self.text_prompt = QPlainTextEdit(default_text)
        gen_layout.addWidget(self.text_prompt, stretch=1)

        if vc_mode:
            layout = QHBoxLayout()
            self.convert_button = QPushButton("Extract and Convert")
            layout.addWidget(self.convert_button)
            gen_layout.addLayout(layout)
        else:
            self.generate_button = QPushButton("Synthesize and vocode")
            gen_layout.addWidget(self.generate_button)
            layout = QHBoxLayout()
            self.synthesize_button = QPushButton("Synthesize only")
            layout.addWidget(self.synthesize_button)

            self.vocode_button = QPushButton("Vocode only")
            layout.addWidget(self.vocode_button)
            gen_layout.addLayout(layout)


        layout_seed = QGridLayout()
        self.random_seed_checkbox = QCheckBox("Random seed:")
        self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
        layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
        self.seed_textbox = QLineEdit()
        self.seed_textbox.setMaximumWidth(80)
        layout_seed.addWidget(self.seed_textbox, 0, 1)
        self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
        self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
                                               " This feature requires `webrtcvad` to be installed.")
        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
        self.style_slider = QSlider(Qt.Horizontal)
        self.style_slider.setTickInterval(1)
        self.style_slider.setFocusPolicy(Qt.NoFocus)
        self.style_slider.setSingleStep(1)
        self.style_slider.setRange(-1, 9)
        self.style_value_label = QLabel("-1")
        self.style_slider.setValue(-1)
        layout_seed.addWidget(QLabel("Style:"), 1, 0)

        self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
        layout_seed.addWidget(self.style_value_label, 1, 1)
        layout_seed.addWidget(self.style_slider, 1, 3)

        self.token_slider = QSlider(Qt.Horizontal)
        self.token_slider.setTickInterval(1)
        self.token_slider.setFocusPolicy(Qt.NoFocus)
        self.token_slider.setSingleStep(1)
        self.token_slider.setRange(3, 9)
        self.token_value_label = QLabel("5")
        self.token_slider.setValue(4)
        layout_seed.addWidget(QLabel("Accuracy:"), 2, 0)

        self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
        layout_seed.addWidget(self.token_value_label, 2, 1)
        layout_seed.addWidget(self.token_slider, 2, 3)

        self.length_slider = QSlider(Qt.Horizontal)
        self.length_slider.setTickInterval(1)
        self.length_slider.setFocusPolicy(Qt.NoFocus)
        self.length_slider.setSingleStep(1)
        self.length_slider.setRange(1, 10)
        self.length_value_label = QLabel("2")
        self.length_slider.setValue(2)
        layout_seed.addWidget(QLabel("MaxLength:"), 3, 0)

        self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s))
        layout_seed.addWidget(self.length_value_label, 3, 1)
        layout_seed.addWidget(self.length_slider, 3, 3)

        gen_layout.addLayout(layout_seed)

        self.loading_bar = QProgressBar()
        gen_layout.addWidget(self.loading_bar)

        self.log_window = QLabel()
        self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
        gen_layout.addWidget(self.log_window)
        self.logs = []
        gen_layout.addStretch()


        ## Set the size of the window and of the elements
        max_size = QDesktopWidget().availableGeometry(self).size() * 0.5
        self.resize(max_size)

        ## Finalize the display
        self.reset_interface(vc_mode)
        self.show()

    def start(self):
        self.app.exec_()
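Note on the slider rows above (Style / Accuracy / MaxLength): each one wires the slider's valueChanged signal to a QLabel so the current value stays visible next to the control. Below is a minimal, self-contained sketch of that same pattern, assuming PyQt5 as the toolkit; the widget and variable names here are illustrative and not part of the uploaded files.

import sys
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QHBoxLayout, QLabel, QSlider, QWidget

app = QApplication(sys.argv)
window = QWidget()
row = QHBoxLayout(window)

# Same range and default as the Accuracy slider above (illustrative values).
slider = QSlider(Qt.Horizontal)
slider.setRange(3, 9)
slider.setSingleStep(1)
slider.setValue(4)

value_label = QLabel(str(slider.value()))
# QLabel.setNum is a slot that accepts the int emitted by valueChanged,
# so the label always mirrors the slider position.
slider.valueChanged.connect(value_label.setNum)

row.addWidget(slider)
row.addWidget(value_label)
window.show()
sys.exit(app.exec_())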
control/toolbox/utterance.py
ADDED
@@ -0,0 +1,5 @@
from collections import namedtuple

Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
Utterance.__eq__ = lambda x, y: x.name == y.name
Utterance.__hash__ = lambda x: hash(x.name)
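Note on control/toolbox/utterance.py above: equality and hashing are overridden to key on the utterance name only, so two records describing the same utterance collapse to a single entry in sets and dicts even if their wav/spec/embed payloads differ. A small usage sketch (the field values are made up for illustration):

from collections import namedtuple

Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
Utterance.__eq__ = lambda x, y: x.name == y.name
Utterance.__hash__ = lambda x: hash(x.name)

# Two records with the same name but different payloads compare equal
# and collapse to one entry when stored in a set.
a = Utterance("sample_001", "speaker_a", wav=[0.0], spec=None, embed=None, partial_embeds=None, synth=None)
b = Utterance("sample_001", "speaker_a", wav=[0.5], spec=None, embed=None, partial_embeds=None, synth=None)
print(a == b)        # True  - compared by name only
print(len({a, b}))   # 1     - hashed by name only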
data/ckpt/encoder/pretrained.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57715adc6f36047166ab06e37b904240aee2f4d10fc88f78ed91510cf4b38666
size 17095158
data/ckpt/vocoder/pretrained/config_16k.json
ADDED
@@ -0,0 +1,31 @@
{
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [5,5,4,2],
    "upsample_kernel_sizes": [10,10,8,4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

    "segment_size": 6400,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 200,
    "win_size": 800,

    "sampling_rate": 16000,

    "fmin": 0,
    "fmax": 7600,
    "fmax_for_loss": null,

    "num_workers": 4
}
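Note on config_16k.json above: the generator's upsample_rates multiply out to hop_size (5 * 5 * 4 * 2 = 200), which is what lets the vocoder expand one 80-bin mel frame into exactly 200 waveform samples at 16 kHz, i.e. 12.5 ms of audio per frame. A quick sanity-check sketch (the path assumes the layout of this upload; adjust it if the file lives elsewhere):

import json
from pathlib import Path

cfg = json.loads(Path("data/ckpt/vocoder/pretrained/config_16k.json").read_text())

# Each mel frame must expand to hop_size samples, so the product of the
# generator's upsample rates has to equal hop_size.
total_upsample = 1
for rate in cfg["upsample_rates"]:
    total_upsample *= rate
assert total_upsample == cfg["hop_size"]  # 5 * 5 * 4 * 2 == 200

# Frame duration implied by the config: 200 / 16000 s = 12.5 ms.
print(1000 * cfg["hop_size"] / cfg["sampling_rate"], "ms per mel frame")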
data/ckpt/vocoder/pretrained/g_hifigan.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c5b29830f9b42c481c108cb0b89d56f380928d4d46e1d30d65c92340ddc694e
size 51985448
data/ckpt/vocoder/pretrained/pretrained.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169
size 53845290
data/samples/T0055G0013S0005.wav
ADDED
Binary file (121 kB)