Spaces:
Running
Running
Update to V1.4
Browse files- app.py +7 -7
- fish_speech/configs/firefly_gan_vq.yaml +2 -3
- fish_speech/configs/text2semantic_finetune.yaml +1 -1
- fish_speech/i18n/README.md +27 -0
- fish_speech/i18n/__init__.py +3 -0
- fish_speech/i18n/core.py +40 -0
- fish_speech/i18n/locale/en_US.json +122 -0
- fish_speech/i18n/locale/es_ES.json +122 -0
- fish_speech/i18n/locale/ja_JP.json +123 -0
- fish_speech/i18n/locale/pt_BR.json +133 -0
- fish_speech/i18n/locale/zh_CN.json +122 -0
- fish_speech/i18n/scan.py +122 -0
- fish_speech/models/text2semantic/llama.py +27 -0
- fish_speech/models/vqgan/__init__.py +0 -3
- fish_speech/models/vqgan/modules/firefly.py +167 -196
- fish_speech/models/vqgan/modules/fsq.py +4 -27
- fish_speech/scheduler.py +19 -1
- fish_speech/text/clean.py +1 -1
- fish_speech/train.py +5 -1
- fish_speech/utils/__init__.py +2 -0
- fish_speech/utils/context.py +13 -0
- fish_speech/utils/file.py +0 -103
- fish_speech/webui/css/style.css +161 -0
- fish_speech/webui/html/footer.html +11 -0
- fish_speech/webui/js/animate.js +69 -0
- fish_speech/webui/launch_utils.py +120 -0
- fish_speech/webui/manage.py +1237 -0
- requirements.txt +2 -1
- tools/api.py +93 -135
- tools/commons.py +35 -0
- tools/download_models.py +55 -0
- tools/extract_model.py +21 -0
- tools/file.py +125 -0
- tools/llama/build_dataset.py +1 -1
- tools/llama/generate.py +28 -10
- tools/llama/merge_lora.py +1 -1
- tools/llama/quantize.py +2 -2
- tools/msgpack_api.py +34 -0
- tools/post_api.py +205 -0
- tools/sensevoice/README.md +59 -0
- tools/sensevoice/__init__.py +0 -0
- tools/sensevoice/auto_model.py +573 -0
- tools/sensevoice/fun_asr.py +332 -0
- tools/sensevoice/vad_utils.py +61 -0
- tools/smart_pad.py +47 -0
- tools/vqgan/create_train_split.py +1 -1
- tools/vqgan/extract_vq.py +3 -3
- tools/vqgan/inference.py +5 -3
- tools/webui.py +619 -0
- tools/whisper_asr.py +176 -0
app.py
CHANGED
@@ -10,7 +10,7 @@ import gc
|
|
10 |
|
11 |
# Download if not exists
|
12 |
os.makedirs("checkpoints", exist_ok=True)
|
13 |
-
snapshot_download(repo_id="fishaudio/fish-speech-1.
|
14 |
|
15 |
print("All checkpoints downloaded")
|
16 |
|
@@ -46,8 +46,8 @@ os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
|
46 |
|
47 |
HEADER_MD = """# Fish Speech
|
48 |
|
49 |
-
## The demo in this space is version 1.
|
50 |
-
## 该 Demo 为 Fish Speech 1.
|
51 |
|
52 |
A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
|
53 |
由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
|
@@ -61,8 +61,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
|
|
61 |
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
|
62 |
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
|
63 |
|
64 |
-
The model running in this WebUI is Fish Speech V1.
|
65 |
-
在此 WebUI 中运行的模型是 Fish Speech V1.
|
66 |
"""
|
67 |
|
68 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
@@ -560,12 +560,12 @@ def parse_args():
|
|
560 |
parser.add_argument(
|
561 |
"--llama-checkpoint-path",
|
562 |
type=Path,
|
563 |
-
default="checkpoints/fish-speech-1.
|
564 |
)
|
565 |
parser.add_argument(
|
566 |
"--decoder-checkpoint-path",
|
567 |
type=Path,
|
568 |
-
default="checkpoints/fish-speech-1.
|
569 |
)
|
570 |
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
571 |
parser.add_argument("--device", type=str, default="cuda")
|
|
|
10 |
|
11 |
# Download if not exists
|
12 |
os.makedirs("checkpoints", exist_ok=True)
|
13 |
+
snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
|
14 |
|
15 |
print("All checkpoints downloaded")
|
16 |
|
|
|
46 |
|
47 |
HEADER_MD = """# Fish Speech
|
48 |
|
49 |
+
## The demo in this space is version 1.4, Please check [Fish Audio](https://fish.audio) for the best model.
|
50 |
+
## 该 Demo 为 Fish Speech 1.4 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
|
51 |
|
52 |
A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
|
53 |
由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
|
|
|
61 |
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
|
62 |
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
|
63 |
|
64 |
+
The model running in this WebUI is Fish Speech V1.4 Medium.
|
65 |
+
在此 WebUI 中运行的模型是 Fish Speech V1.4 Medium.
|
66 |
"""
|
67 |
|
68 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
|
|
560 |
parser.add_argument(
|
561 |
"--llama-checkpoint-path",
|
562 |
type=Path,
|
563 |
+
default="checkpoints/fish-speech-1.4",
|
564 |
)
|
565 |
parser.add_argument(
|
566 |
"--decoder-checkpoint-path",
|
567 |
type=Path,
|
568 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
569 |
)
|
570 |
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
571 |
parser.add_argument("--device", type=str, default="cuda")
|
fish_speech/configs/firefly_gan_vq.yaml
CHANGED
@@ -22,13 +22,12 @@ head:
|
|
22 |
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
23 |
num_mels: 512
|
24 |
upsample_initial_channel: 512
|
25 |
-
use_template: false
|
26 |
pre_conv_kernel_size: 13
|
27 |
post_conv_kernel_size: 13
|
28 |
quantizer:
|
29 |
_target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
|
30 |
input_dim: 512
|
31 |
-
n_groups:
|
32 |
n_codebooks: 1
|
33 |
levels: [8, 5, 5, 5]
|
34 |
-
downsample_factor: [2]
|
|
|
22 |
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
23 |
num_mels: 512
|
24 |
upsample_initial_channel: 512
|
|
|
25 |
pre_conv_kernel_size: 13
|
26 |
post_conv_kernel_size: 13
|
27 |
quantizer:
|
28 |
_target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
|
29 |
input_dim: 512
|
30 |
+
n_groups: 8
|
31 |
n_codebooks: 1
|
32 |
levels: [8, 5, 5, 5]
|
33 |
+
downsample_factor: [2, 2]
|
fish_speech/configs/text2semantic_finetune.yaml
CHANGED
@@ -4,7 +4,7 @@ defaults:
|
|
4 |
|
5 |
project: text2semantic_finetune_dual_ar
|
6 |
max_length: 4096
|
7 |
-
pretrained_ckpt_path: checkpoints/fish-speech-1.
|
8 |
|
9 |
# Lightning Trainer
|
10 |
trainer:
|
|
|
4 |
|
5 |
project: text2semantic_finetune_dual_ar
|
6 |
max_length: 4096
|
7 |
+
pretrained_ckpt_path: checkpoints/fish-speech-1.4
|
8 |
|
9 |
# Lightning Trainer
|
10 |
trainer:
|
fish_speech/i18n/README.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## i18n Folder Attribution
|
2 |
+
|
3 |
+
The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
|
4 |
+
|
5 |
+
### fish_speech/i18n/core.py
|
6 |
+
|
7 |
+
**Related code from RVC:**
|
8 |
+
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
|
9 |
+
|
10 |
+
**Initial commit:**
|
11 |
+
add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
|
12 |
+
|
13 |
+
**Initial author:**
|
14 |
+
[@L4Ph](https://github.com/L4Ph)
|
15 |
+
|
16 |
+
### fish_speech/i18n/scan.py
|
17 |
+
|
18 |
+
**Related code from RVC:**
|
19 |
+
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
|
20 |
+
|
21 |
+
**Initial commit:**
|
22 |
+
File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
|
23 |
+
|
24 |
+
**Initial author:**
|
25 |
+
[@towzeur](https://github.com/towzeur)
|
26 |
+
|
27 |
+
We appreciate the contributions of the RVC project and its authors.
|
fish_speech/i18n/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .core import i18n
|
2 |
+
|
3 |
+
__all__ = ["i18n"]
|
fish_speech/i18n/core.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import locale
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
I18N_FILE_PATH = Path(__file__).parent / "locale"
|
6 |
+
DEFAULT_LANGUAGE = "en_US"
|
7 |
+
|
8 |
+
|
9 |
+
def load_language_list(language):
|
10 |
+
with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
|
11 |
+
language_list = json.load(f)
|
12 |
+
|
13 |
+
return language_list
|
14 |
+
|
15 |
+
|
16 |
+
class I18nAuto:
|
17 |
+
def __init__(self):
|
18 |
+
i18n_file = Path(".locale")
|
19 |
+
|
20 |
+
if i18n_file.exists():
|
21 |
+
with open(i18n_file, "r", encoding="utf-8") as f:
|
22 |
+
language = f.read().strip()
|
23 |
+
else:
|
24 |
+
# getlocale can't identify the system's language ((None, None))
|
25 |
+
language = locale.getdefaultlocale()[0]
|
26 |
+
|
27 |
+
if (I18N_FILE_PATH / f"{language}.json").exists() is False:
|
28 |
+
language = DEFAULT_LANGUAGE
|
29 |
+
|
30 |
+
self.language = language
|
31 |
+
self.language_map = load_language_list(language)
|
32 |
+
|
33 |
+
def __call__(self, key):
|
34 |
+
return self.language_map.get(key, key)
|
35 |
+
|
36 |
+
def __repr__(self):
|
37 |
+
return "Use Language: " + self.language
|
38 |
+
|
39 |
+
|
40 |
+
i18n = I18nAuto()
|
fish_speech/i18n/locale/en_US.json
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
|
5 |
+
"Accumulate Gradient Batches": "Accumulate Gradient Batches",
|
6 |
+
"Add to Processing Area": "Add to Processing Area",
|
7 |
+
"Added path successfully!": "Added path successfully!",
|
8 |
+
"Advanced Config": "Advanced Config",
|
9 |
+
"Base LLAMA Model": "Base LLAMA Model",
|
10 |
+
"Batch Inference": "Batch Inference",
|
11 |
+
"Batch Size": "Batch Size",
|
12 |
+
"Changing with the Model Path": "Changing with the Model Path",
|
13 |
+
"Chinese": "Chinese",
|
14 |
+
"Compile Model": "Compile Model",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
|
16 |
+
"Copy": "Copy",
|
17 |
+
"Data Preprocessing": "Data Preprocessing",
|
18 |
+
"Data Preprocessing Path": "Data Preprocessing Path",
|
19 |
+
"Data Source": "Data Source",
|
20 |
+
"Decoder Model Config": "Decoder Model Config",
|
21 |
+
"Decoder Model Path": "Decoder Model Path",
|
22 |
+
"Disabled": "Disabled",
|
23 |
+
"Enable Reference Audio": "Enable Reference Audio",
|
24 |
+
"English": "English",
|
25 |
+
"Error Message": "Error Message",
|
26 |
+
"File Preprocessing": "File Preprocessing",
|
27 |
+
"Generate": "Generate",
|
28 |
+
"Generated Audio": "Generated Audio",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
|
30 |
+
"Infer interface is closed": "Infer interface is closed",
|
31 |
+
"Inference Configuration": "Inference Configuration",
|
32 |
+
"Inference Server Configuration": "Inference Server Configuration",
|
33 |
+
"Inference Server Error": "Inference Server Error",
|
34 |
+
"Inferring interface is launched at {}": "Inferring interface is launched at {}",
|
35 |
+
"Initial Learning Rate": "Initial Learning Rate",
|
36 |
+
"Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
|
37 |
+
"Input Text": "Input Text",
|
38 |
+
"Invalid path: {}": "Invalid path: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
|
40 |
+
"Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
|
41 |
+
"Japanese": "Japanese",
|
42 |
+
"LLAMA Configuration": "LLAMA Configuration",
|
43 |
+
"LLAMA Model Config": "LLAMA Model Config",
|
44 |
+
"LLAMA Model Path": "LLAMA Model Path",
|
45 |
+
"Labeling Device": "Labeling Device",
|
46 |
+
"LoRA Model to be merged": "LoRA Model to be merged",
|
47 |
+
"Maximum Audio Duration": "Maximum Audio Duration",
|
48 |
+
"Maximum Length per Sample": "Maximum Length per Sample",
|
49 |
+
"Maximum Training Steps": "Maximum Training Steps",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
|
51 |
+
"Merge": "Merge",
|
52 |
+
"Merge LoRA": "Merge LoRA",
|
53 |
+
"Merge successfully": "Merge successfully",
|
54 |
+
"Minimum Audio Duration": "Minimum Audio Duration",
|
55 |
+
"Model Output Path": "Model Output Path",
|
56 |
+
"Model Size": "Model Size",
|
57 |
+
"Move": "Move",
|
58 |
+
"Move files successfully": "Move files successfully",
|
59 |
+
"No audio generated, please check the input text.": "No audio generated, please check the input text.",
|
60 |
+
"No selected options": "No selected options",
|
61 |
+
"Number of Workers": "Number of Workers",
|
62 |
+
"Open Inference Server": "Open Inference Server",
|
63 |
+
"Open Labeler WebUI": "Open Labeler WebUI",
|
64 |
+
"Open Tensorboard": "Open Tensorboard",
|
65 |
+
"Opened labeler in browser": "Opened labeler in browser",
|
66 |
+
"Optional Label Language": "Optional Label Language",
|
67 |
+
"Optional online ver": "Optional online ver",
|
68 |
+
"Output Path": "Output Path",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
|
70 |
+
"Precision": "Precision",
|
71 |
+
"Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
|
72 |
+
"Put your text here.": "Put your text here.",
|
73 |
+
"Reference Audio": "Reference Audio",
|
74 |
+
"Reference Text": "Reference Text",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
|
76 |
+
"Remove Selected Data": "Remove Selected Data",
|
77 |
+
"Removed path successfully!": "Removed path successfully!",
|
78 |
+
"Repetition Penalty": "Repetition Penalty",
|
79 |
+
"Save model every n steps": "Save model every n steps",
|
80 |
+
"Select LLAMA ckpt": "Select LLAMA ckpt",
|
81 |
+
"Select VITS ckpt": "Select VITS ckpt",
|
82 |
+
"Select VQGAN ckpt": "Select VQGAN ckpt",
|
83 |
+
"Select source file processing method": "Select source file processing method",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
|
85 |
+
"Selected: {}": "Selected: {}",
|
86 |
+
"Speaker": "Speaker",
|
87 |
+
"Speaker is identified by the folder name": "Speaker is identified by the folder name",
|
88 |
+
"Start Training": "Start Training",
|
89 |
+
"Streaming Audio": "Streaming Audio",
|
90 |
+
"Streaming Generate": "Streaming Generate",
|
91 |
+
"Tensorboard Host": "Tensorboard Host",
|
92 |
+
"Tensorboard Log Path": "Tensorboard Log Path",
|
93 |
+
"Tensorboard Port": "Tensorboard Port",
|
94 |
+
"Tensorboard interface is closed": "Tensorboard interface is closed",
|
95 |
+
"Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
|
96 |
+
"Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
|
98 |
+
"Training Configuration": "Training Configuration",
|
99 |
+
"Training Error": "Training Error",
|
100 |
+
"Training stopped": "Training stopped",
|
101 |
+
"Type name of the speaker": "Type name of the speaker",
|
102 |
+
"Type the path or select from the dropdown": "Type the path or select from the dropdown",
|
103 |
+
"Use LoRA": "Use LoRA",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
|
105 |
+
"Use filelist": "Use filelist",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
|
107 |
+
"VITS Configuration": "VITS Configuration",
|
108 |
+
"VQGAN Configuration": "VQGAN Configuration",
|
109 |
+
"Validation Batch Size": "Validation Batch Size",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
|
112 |
+
"WebUI Host": "WebUI Host",
|
113 |
+
"WebUI Port": "WebUI Port",
|
114 |
+
"Whisper Model": "Whisper Model",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
|
117 |
+
"latest": "latest",
|
118 |
+
"new": "new",
|
119 |
+
"Realtime Transform Text": "Realtime Transform Text",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
|
121 |
+
"Text Normalization": "Text Normalization"
|
122 |
+
}
|
fish_speech/i18n/locale/es_ES.json
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
|
5 |
+
"Accumulate Gradient Batches": "Acumular lotes de gradientes",
|
6 |
+
"Add to Processing Area": "Agregar al Área de Procesamiento",
|
7 |
+
"Added path successfully!": "¡Ruta agregada exitosamente!",
|
8 |
+
"Advanced Config": "Configuración Avanzada",
|
9 |
+
"Base LLAMA Model": "Modelo Base LLAMA",
|
10 |
+
"Batch Inference": "Inferencia por Lote",
|
11 |
+
"Batch Size": "Tamaño del Lote",
|
12 |
+
"Changing with the Model Path": "Cambiando con la Ruta del Modelo",
|
13 |
+
"Chinese": "Chino",
|
14 |
+
"Compile Model": "Compilar Modelo",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
|
16 |
+
"Copy": "Copiar",
|
17 |
+
"Data Preprocessing": "Preprocesamiento de Datos",
|
18 |
+
"Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
|
19 |
+
"Data Source": "Fuente de Datos",
|
20 |
+
"Decoder Model Config": "Configuración del modelo decodificador",
|
21 |
+
"Decoder Model Path": "Ruta del modelo decodificador",
|
22 |
+
"Disabled": "Desactivado",
|
23 |
+
"Enable Reference Audio": "Habilitar Audio de Referencia",
|
24 |
+
"English": "Inglés",
|
25 |
+
"Error Message": "Mensaje de Error",
|
26 |
+
"File Preprocessing": "Preprocesamiento de Archivos",
|
27 |
+
"Generate": "Generar",
|
28 |
+
"Generated Audio": "Audio Generado",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
|
30 |
+
"Infer interface is closed": "La interfaz de inferencia está cerrada",
|
31 |
+
"Inference Configuration": "Configuración de Inferencia",
|
32 |
+
"Inference Server Configuration": "Configuración del Servidor de Inferencia",
|
33 |
+
"Inference Server Error": "Error del Servidor de Inferencia",
|
34 |
+
"Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
|
35 |
+
"Initial Learning Rate": "Tasa de Aprendizaje Inicial",
|
36 |
+
"Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
|
37 |
+
"Input Text": "Texto de Entrada",
|
38 |
+
"Invalid path: {}": "Ruta inválida: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
|
40 |
+
"Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
|
41 |
+
"Japanese": "Japonés",
|
42 |
+
"LLAMA Configuration": "Configuración de LLAMA",
|
43 |
+
"LLAMA Model Config": "Configuración del Modelo LLAMA",
|
44 |
+
"LLAMA Model Path": "Ruta del Modelo LLAMA",
|
45 |
+
"Labeling Device": "Dispositivo de Etiquetado",
|
46 |
+
"LoRA Model to be merged": "Modelo LoRA a fusionar",
|
47 |
+
"Maximum Audio Duration": "Duración máxima de audio",
|
48 |
+
"Maximum Length per Sample": "Longitud Máxima por Muestra",
|
49 |
+
"Maximum Training Steps": "Pasos Máximos de Entrenamiento",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
|
51 |
+
"Merge": "Fusionar",
|
52 |
+
"Merge LoRA": "Fusionar LoRA",
|
53 |
+
"Merge successfully": "Fusionado exitosamente",
|
54 |
+
"Minimum Audio Duration": "Duración mínima de audio",
|
55 |
+
"Model Output Path": "Ruta de Salida del Modelo",
|
56 |
+
"Model Size": "Tamaño del Modelo",
|
57 |
+
"Move": "Mover",
|
58 |
+
"Move files successfully": "Archivos movidos exitosamente",
|
59 |
+
"No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
|
60 |
+
"No selected options": "No hay opciones seleccionadas",
|
61 |
+
"Number of Workers": "Número de Trabajadores",
|
62 |
+
"Open Inference Server": "Abrir Servidor de Inferencia",
|
63 |
+
"Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
|
64 |
+
"Open Tensorboard": "Abrir Tensorboard",
|
65 |
+
"Opened labeler in browser": "Se abrió el etiquetador en el navegador",
|
66 |
+
"Optional Label Language": "Idioma de Etiquetado Opcional",
|
67 |
+
"Optional online ver": "Ver en línea opcional",
|
68 |
+
"Output Path": "Ruta de Salida",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
|
70 |
+
"Precision": "Precisión",
|
71 |
+
"Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
|
72 |
+
"Put your text here.": "Ponga su texto aquí.",
|
73 |
+
"Reference Audio": "Audio de Referencia",
|
74 |
+
"Reference Text": "Texto de Referencia",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
|
76 |
+
"Remove Selected Data": "Eliminar Datos Seleccionados",
|
77 |
+
"Removed path successfully!": "¡Ruta eliminada exitosamente!",
|
78 |
+
"Repetition Penalty": "Penalización por Repetición",
|
79 |
+
"Save model every n steps": "Guardar modelo cada n pasos",
|
80 |
+
"Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
|
81 |
+
"Select VITS ckpt": "Seleccionar punto de control VITS",
|
82 |
+
"Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
|
83 |
+
"Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
|
85 |
+
"Selected: {}": "Seleccionado: {}",
|
86 |
+
"Speaker": "Hablante",
|
87 |
+
"Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
|
88 |
+
"Start Training": "Iniciar Entrenamiento",
|
89 |
+
"Streaming Audio": "transmisión de audio",
|
90 |
+
"Streaming Generate": "síntesis en flujo",
|
91 |
+
"Tensorboard Host": "Host de Tensorboard",
|
92 |
+
"Tensorboard Log Path": "Ruta de Registro de Tensorboard",
|
93 |
+
"Tensorboard Port": "Puerto de Tensorboard",
|
94 |
+
"Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
|
95 |
+
"Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
|
96 |
+
"Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
|
98 |
+
"Training Configuration": "Configuración de Entrenamiento",
|
99 |
+
"Training Error": "Error de Entrenamiento",
|
100 |
+
"Training stopped": "Entrenamiento detenido",
|
101 |
+
"Type name of the speaker": "Escriba el nombre del hablante",
|
102 |
+
"Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
|
103 |
+
"Use LoRA": "Usar LoRA",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
|
105 |
+
"Use filelist": "Usar lista de archivos",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
|
107 |
+
"VITS Configuration": "Configuración de VITS",
|
108 |
+
"VQGAN Configuration": "Configuración de VQGAN",
|
109 |
+
"Validation Batch Size": "Tamaño del Lote de Validación",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
|
112 |
+
"WebUI Host": "Host de WebUI",
|
113 |
+
"WebUI Port": "Puerto de WebUI",
|
114 |
+
"Whisper Model": "Modelo Whisper",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
|
117 |
+
"latest": "más reciente",
|
118 |
+
"new": "nuevo",
|
119 |
+
"Realtime Transform Text": "Transformación de Texto en Tiempo Real",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
|
121 |
+
"Text Normalization": "Normalización de Texto"
|
122 |
+
}
|
fish_speech/i18n/locale/ja_JP.json
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
|
5 |
+
"Accumulate Gradient Batches": "勾配バッチの累積",
|
6 |
+
"Add to Processing Area": "処理エリアに追加",
|
7 |
+
"Added path successfully!": "パスの追加に成功しました!",
|
8 |
+
"Advanced Config": "詳細設定",
|
9 |
+
"Base LLAMA Model": "基本LLAMAモデル",
|
10 |
+
"Batch Inference": "バッチ推論",
|
11 |
+
"Batch Size": "バッチサイズ",
|
12 |
+
"Changing with the Model Path": "モデルのパスに伴って変化する",
|
13 |
+
"Chinese": "中国語",
|
14 |
+
"Compile Model": "モデルのコンパイル",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
|
16 |
+
"Copy": "コピー",
|
17 |
+
"Data Preprocessing": "データ前処理",
|
18 |
+
"Data Preprocessing Path": "データ前処理パス",
|
19 |
+
"Data Source": "データソース",
|
20 |
+
"Decoder Model Config": "デコーダーモデルの構成",
|
21 |
+
"Decoder Model Path": "デコーダーモデルのパス",
|
22 |
+
"Disabled": "無効",
|
23 |
+
"Enable Reference Audio": "リファレンスオーディオを有効にする",
|
24 |
+
"English": "英語",
|
25 |
+
"Error Message": "エラーメッセージ",
|
26 |
+
"File Preprocessing": "文書前处理",
|
27 |
+
"Generate": "生成",
|
28 |
+
"Generated Audio": "生成されたオーディオ",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
|
30 |
+
"Infer interface is closed": "推論インターフェースが閉じられています",
|
31 |
+
"Inference Configuration": "推論設定",
|
32 |
+
"Inference Server Configuration": "推論サーバー設定",
|
33 |
+
"Inference Server Error": "推論サーバーエラー",
|
34 |
+
"Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
|
35 |
+
"Initial Learning Rate": "初期学習率",
|
36 |
+
"Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
|
37 |
+
"Input Text": "入力テキスト",
|
38 |
+
"Invalid path: {}": "無効なパス: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
|
40 |
+
"Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
|
41 |
+
"Japanese": "日本語",
|
42 |
+
"LLAMA Configuration": "LLAMA設定",
|
43 |
+
"LLAMA Model Config": "LLAMAモデル設定",
|
44 |
+
"LLAMA Model Path": "LLAMAモデルパス",
|
45 |
+
"Labeling Device": "ラベリングデバイス",
|
46 |
+
"LoRA Model to be merged": "マージするLoRAモデル",
|
47 |
+
"Maximum Audio Duration": "最大オーディオの長さ",
|
48 |
+
"Maximum Length per Sample": "サンプルあたりの最大長",
|
49 |
+
"Maximum Training Steps": "最大トレーニングステップ数",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
|
51 |
+
"Merge": "マージ",
|
52 |
+
"Merge LoRA": "LoRAのマージ",
|
53 |
+
"Merge successfully": "マージに成功しました",
|
54 |
+
"Minimum Audio Duration": "最小オーディオの長さ",
|
55 |
+
"Model Output Path": "モデル出力パス",
|
56 |
+
"Model Size": "モデルサイズ",
|
57 |
+
"Move": "移動",
|
58 |
+
"Move files successfully": "ファイルの移動に成功しました",
|
59 |
+
"No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
|
60 |
+
"No selected options": "選択されたオプションはありません",
|
61 |
+
"Number of Workers": "ワーカー数",
|
62 |
+
"Open Inference Server": "推論サーバーを開く",
|
63 |
+
"Open Labeler WebUI": "ラベラーWebUIを開く",
|
64 |
+
"Open Tensorboard": "Tensorboardを開く",
|
65 |
+
"Opened labeler in browser": "ブラウザでラベラーを開きました",
|
66 |
+
"Optional Label Language": "オプションのラベル言語",
|
67 |
+
"Optional online ver": "オプションのオンラインバージョン",
|
68 |
+
"Output Path": "出力パス",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
|
70 |
+
"Precision": "精度",
|
71 |
+
"Probability of applying Speaker Condition": "話者条件を適用する確率",
|
72 |
+
"Put your text here.": "ここにテキストを入力してください。",
|
73 |
+
"Reference Audio": "リファレンスオーディオ",
|
74 |
+
"Reference Text": "リファレンステキスト",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
|
76 |
+
"Remove Selected Data": "選択したデータを削除",
|
77 |
+
"Removed path successfully!": "パスの削除に成功しました!",
|
78 |
+
"Repetition Penalty": "反復ペナルティ",
|
79 |
+
"Save model every n steps": "nステップごとにモデルを保存",
|
80 |
+
"Select LLAMA ckpt": " LLAMA チェックポイントを選択",
|
81 |
+
"Select VITS ckpt": "VITS チェックポイントを選択",
|
82 |
+
"Select VQGAN ckpt": "VQGAN チェックポイントを選択",
|
83 |
+
"Select source file processing method": "ソースファイルの処理方法を選択",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
|
85 |
+
"Selected: {}": "選択済み: {}",
|
86 |
+
"Speaker": "話者",
|
87 |
+
"Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
|
88 |
+
"Start Training": "トレーニング開始",
|
89 |
+
"Streaming Audio": "ストリーミングオーディオ",
|
90 |
+
"Streaming Generate": "ストリーミング合成",
|
91 |
+
"Tensorboard Host": "Tensorboardホスト",
|
92 |
+
"Tensorboard Log Path": "Tensorboardログパス",
|
93 |
+
"Tensorboard Port": "Tensorboardポート",
|
94 |
+
"Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
|
95 |
+
"Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
|
96 |
+
"Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
|
98 |
+
"Training Configuration": "トレーニング設定",
|
99 |
+
"Training Error": "トレーニングエラー",
|
100 |
+
"Training stopped": "トレーニングが停止しました",
|
101 |
+
"Type name of the speaker": "話者の名前を入力",
|
102 |
+
"Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
|
103 |
+
"Use LoRA": "LoRAを使用",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
|
105 |
+
"Use filelist": "ファイルリストを使用",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
|
107 |
+
"VITS Configuration": "VITS の構成",
|
108 |
+
"VQGAN Configuration": "VQGAN の構成",
|
109 |
+
"Validation Batch Size": "検証バッチサイズ",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
|
112 |
+
"WebUI Host": "WebUIホスト",
|
113 |
+
"WebUI Port": "WebUIポート",
|
114 |
+
"Whisper Model": "Whisperモデル",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
|
117 |
+
"latest": "最新",
|
118 |
+
"new": "新規",
|
119 |
+
"Realtime Transform Text": "リアルタイム変換テキスト",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
|
121 |
+
"Text Normalization": "テキスト正規化"
|
122 |
+
|
123 |
+
}
|
fish_speech/i18n/locale/pt_BR.json
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
|
3 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
|
4 |
+
"Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
|
5 |
+
"Add to Processing Area": "Adicionar à Área de Processamento",
|
6 |
+
"Added path successfully!": "Caminho adicionado com sucesso!",
|
7 |
+
"Advanced Config": "Configuração Avançada",
|
8 |
+
"Base LLAMA Model": "Modelo LLAMA Base",
|
9 |
+
"Batch Inference": "Inferência em Lote",
|
10 |
+
"Batch Size": "Tamanho do Lote",
|
11 |
+
"Changing with the Model Path": "Alterando com o Caminho do Modelo",
|
12 |
+
|
13 |
+
"Compile Model": "Compilar Modelo",
|
14 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
|
15 |
+
"Copy": "Copiar",
|
16 |
+
"Data Preprocessing": "Pré-processamento de Dados",
|
17 |
+
"Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
|
18 |
+
"Data Source": "Fonte de Dados",
|
19 |
+
"Decoder Model Config": "Configuração do Modelo Decodificador",
|
20 |
+
"Decoder Model Path": "Caminho do Modelo Decodificador",
|
21 |
+
"Disabled": "Desativado",
|
22 |
+
"Enable Initial Prompt": "Habilitar Prompt Inicial",
|
23 |
+
"Enable Reference Audio": "Habilitar Áudio de Referência",
|
24 |
+
"English": "Inglês",
|
25 |
+
"Japanese": "Japonês",
|
26 |
+
"Chinese": "Chinês",
|
27 |
+
"Portuguese": "Português",
|
28 |
+
"Spanish": "Espanhol",
|
29 |
+
"Error Message": "Mensagem de Erro",
|
30 |
+
"Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
|
31 |
+
"File Preprocessing": "Pré-processamento de Arquivos",
|
32 |
+
"Generate": "Gerar",
|
33 |
+
"Generated Audio": "Áudio Gerado",
|
34 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
|
35 |
+
"Infer interface is closed": "A interface de inferência foi fechada",
|
36 |
+
"Inference Configuration": "Configuração de Inferência",
|
37 |
+
"Inference Server Configuration": "Configuração do Servidor de Inferência",
|
38 |
+
"Inference Server Error": "Erro do Servidor de Inferência",
|
39 |
+
"Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
|
40 |
+
"Initial Learning Rate": "Taxa de Aprendizagem Inicial",
|
41 |
+
"Initial Prompt": "Prompt Inicial",
|
42 |
+
"Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
|
43 |
+
"Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
|
44 |
+
"Input Text": "Texto de Entrada",
|
45 |
+
"Invalid path: {}": "Caminho inválido: {}",
|
46 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
|
47 |
+
"Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
|
48 |
+
"LLAMA Configuration": "Configuração do LLAMA",
|
49 |
+
"LLAMA Model Config": "Configuração do Modelo LLAMA",
|
50 |
+
"LLAMA Model Path": "Caminho do Modelo LLAMA",
|
51 |
+
"Labeling Device": "Dispositivo de Rotulagem",
|
52 |
+
"LoRA Model to be merged": "Modelo LoRA para mesclagem",
|
53 |
+
"Maximum Length per Sample": "Comprimento Máximo por Amostra",
|
54 |
+
"Maximum Training Steps": "Etapas Máximas de Treinamento",
|
55 |
+
"Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
|
56 |
+
"Merge": "Mesclar",
|
57 |
+
"Merge LoRA": "Mesclar LoRA",
|
58 |
+
"Merge successfully": "Mesclado com sucesso",
|
59 |
+
"Model Output Path": "Caminho de Saída do Modelo",
|
60 |
+
"Model Quantization": "Quantização do Modelo",
|
61 |
+
"Model Size": "Tamanho do Modelo",
|
62 |
+
"Move": "Mover",
|
63 |
+
"Move files successfully": "Arquivos movidos com sucesso",
|
64 |
+
"No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
|
65 |
+
"No selected options": "Nenhuma opção selecionada",
|
66 |
+
"Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
|
67 |
+
"Number of Workers": "Número de Processos",
|
68 |
+
"Open Inference Server": "Abrir Servidor de Inferência",
|
69 |
+
"Open Labeler WebUI": "Abrir WebUI de Rotulagem",
|
70 |
+
"Open Tensorboard": "Abrir Tensorboard",
|
71 |
+
"Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
|
72 |
+
"Optional Label Language": "Idioma do Rótulo (Opcional)",
|
73 |
+
"Optional online ver": "Versão online (opcional)",
|
74 |
+
"Output Path": "Caminho de Saída",
|
75 |
+
"Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
|
76 |
+
"Post-quantification Precision": "Precisão Pós-quantização",
|
77 |
+
"Precision": "Precisão",
|
78 |
+
"Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
|
79 |
+
"Put your text here.": "Insira seu texto aqui.",
|
80 |
+
"Quantify": "Quantizar",
|
81 |
+
"Quantify successfully": "Quantizado com sucesso",
|
82 |
+
"Realtime Transform Text": "Transformar Texto em Tempo Real",
|
83 |
+
"Reference Audio": "Áudio de Referência",
|
84 |
+
"Reference Text": "Texto de Referência",
|
85 |
+
"warning": "Aviso",
|
86 |
+
"Pre-processing begins...": "O pré-processamento começou!",
|
87 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
|
88 |
+
"Remove Selected Data": "Remover Dados Selecionados",
|
89 |
+
"Removed path successfully!": "Caminho removido com sucesso!",
|
90 |
+
"Repetition Penalty": "Penalidade de Repetição",
|
91 |
+
"Save model every n steps": "Salvar modelo a cada n etapas",
|
92 |
+
"Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
|
93 |
+
"Select source file processing method": "Escolha como processar o arquivo de origem",
|
94 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
|
95 |
+
"Selected: {}": "Selecionado: {}",
|
96 |
+
"Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
|
97 |
+
"Start Training": "Iniciar Treinamento",
|
98 |
+
"Streaming Audio": "Áudio em Streaming",
|
99 |
+
"Streaming Generate": "Geração em Streaming",
|
100 |
+
"Tensorboard Host": "Host do Tensorboard",
|
101 |
+
"Tensorboard Log Path": "Caminho de Log do Tensorboard",
|
102 |
+
"Tensorboard Port": "Porta do Tensorboard",
|
103 |
+
"Tensorboard interface is closed": "A interface do Tensorboard está fechada",
|
104 |
+
"Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
|
105 |
+
"Text Normalization": "Normalização de Texto",
|
106 |
+
"Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
|
107 |
+
"The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
|
108 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
|
109 |
+
"Training Configuration": "Configuração de Treinamento",
|
110 |
+
"Training Error": "Erro de Treinamento",
|
111 |
+
"Training stopped": "Treinamento interrompido!",
|
112 |
+
"Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
|
113 |
+
"Use LoRA": "Usar LoRA",
|
114 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
|
115 |
+
"Use filelist": "Usar lista de arquivos",
|
116 |
+
"VQGAN Configuration": "Configuração do VQGAN",
|
117 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
|
118 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
|
119 |
+
"WebUI Host": "Host da WebUI",
|
120 |
+
"WebUI Port": "Porta da WebUI",
|
121 |
+
"Whisper Model": "Modelo Whisper",
|
122 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
|
123 |
+
"auto": "automático",
|
124 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
|
125 |
+
"latest": "mais recente",
|
126 |
+
"new": "novo",
|
127 |
+
"This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
|
128 |
+
"You don't need to train this model!": "Não é necessário treinar este modelo!",
|
129 |
+
"Yes": "Sim",
|
130 |
+
"No": "Não",
|
131 |
+
"version:": "versão:",
|
132 |
+
"author:": "autor:"
|
133 |
+
}
|
fish_speech/i18n/locale/zh_CN.json
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
|
5 |
+
"Accumulate Gradient Batches": "梯度累积批次",
|
6 |
+
"Add to Processing Area": "加入处理区",
|
7 |
+
"Added path successfully!": "添加路径成功!",
|
8 |
+
"Advanced Config": "高级参数",
|
9 |
+
"Base LLAMA Model": "基础 LLAMA 模型",
|
10 |
+
"Batch Inference": "批量推理",
|
11 |
+
"Batch Size": "批次大小",
|
12 |
+
"Changing with the Model Path": "随模型路径变化",
|
13 |
+
"Chinese": "中文",
|
14 |
+
"Compile Model": "编译模型",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
|
16 |
+
"Copy": "复制",
|
17 |
+
"Data Preprocessing": "数据预处理",
|
18 |
+
"Data Preprocessing Path": "数据预处理路径",
|
19 |
+
"Data Source": "数据源",
|
20 |
+
"Decoder Model Config": "解码器模型配置",
|
21 |
+
"Decoder Model Path": "解码器模型路径",
|
22 |
+
"Disabled": "禁用",
|
23 |
+
"Enable Reference Audio": "启用参考音频",
|
24 |
+
"English": "英文",
|
25 |
+
"Error Message": "错误信息",
|
26 |
+
"File Preprocessing": "文件预处理",
|
27 |
+
"Generate": "生成",
|
28 |
+
"Generated Audio": "音频",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
|
30 |
+
"Infer interface is closed": "推理界面已关闭",
|
31 |
+
"Inference Configuration": "推理配置",
|
32 |
+
"Inference Server Configuration": "推理服务器配置",
|
33 |
+
"Inference Server Error": "推理服务器错误",
|
34 |
+
"Inferring interface is launched at {}": "推理界面已在 {} 上启动",
|
35 |
+
"Initial Learning Rate": "初始学习率",
|
36 |
+
"Input Audio & Source Path for Transcription": "输入音频和转录源路径",
|
37 |
+
"Input Text": "输入文本",
|
38 |
+
"Invalid path: {}": "无效路径: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
|
40 |
+
"Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
|
41 |
+
"Japanese": "日文",
|
42 |
+
"LLAMA Configuration": "LLAMA 配置",
|
43 |
+
"LLAMA Model Config": "LLAMA 模型配置",
|
44 |
+
"LLAMA Model Path": "LLAMA 模型路径",
|
45 |
+
"Labeling Device": "标注加速设备",
|
46 |
+
"LoRA Model to be merged": "要合并的 LoRA 模型",
|
47 |
+
"Maximum Audio Duration": "最大音频时长",
|
48 |
+
"Maximum Length per Sample": "每个样本的最大长度",
|
49 |
+
"Maximum Training Steps": "最大训练步数",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
|
51 |
+
"Merge": "合并",
|
52 |
+
"Merge LoRA": "合并 LoRA",
|
53 |
+
"Merge successfully": "合并成功",
|
54 |
+
"Minimum Audio Duration": "最小音频时长",
|
55 |
+
"Model Output Path": "模型输出路径",
|
56 |
+
"Model Size": "模型规模",
|
57 |
+
"Move": "移动",
|
58 |
+
"Move files successfully": "移动文件成功",
|
59 |
+
"No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
|
60 |
+
"No selected options": "没有选择的选项",
|
61 |
+
"Number of Workers": "数据加载进程数",
|
62 |
+
"Open Inference Server": "打开推理服务器",
|
63 |
+
"Open Labeler WebUI": "打开标注工具",
|
64 |
+
"Open Tensorboard": "打开 Tensorboard",
|
65 |
+
"Opened labeler in browser": "在浏览器中打开标注工具",
|
66 |
+
"Optional Label Language": "[可选] 标注语言",
|
67 |
+
"Optional online ver": "[可选] 使用在线版",
|
68 |
+
"Output Path": "输出路径",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
|
70 |
+
"Precision": "精度",
|
71 |
+
"Probability of applying Speaker Condition": "应用说话人条件的概率",
|
72 |
+
"Put your text here.": "在此处输入文本.",
|
73 |
+
"Reference Audio": "参考音频",
|
74 |
+
"Reference Text": "参考文本",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
|
76 |
+
"Remove Selected Data": "移除选中数据",
|
77 |
+
"Removed path successfully!": "移除路径成功!",
|
78 |
+
"Repetition Penalty": "重复惩罚",
|
79 |
+
"Save model every n steps": "每 n 步保存模型",
|
80 |
+
"Select LLAMA ckpt": "选择 LLAMA 检查点",
|
81 |
+
"Select VITS ckpt": "选择 VITS 检查点",
|
82 |
+
"Select VQGAN ckpt": "选择 VQGAN 检查点",
|
83 |
+
"Select source file processing method": "选择源文件处理方法",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
|
85 |
+
"Selected: {}": "已选择: {}",
|
86 |
+
"Speaker": "说话人",
|
87 |
+
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
|
88 |
+
"Start Training": "开始训练",
|
89 |
+
"Streaming Audio": "流式音频",
|
90 |
+
"Streaming Generate": "流式合成",
|
91 |
+
"Tensorboard Host": "Tensorboard 监听地址",
|
92 |
+
"Tensorboard Log Path": "Tensorboard 日志路径",
|
93 |
+
"Tensorboard Port": "Tensorboard 端口",
|
94 |
+
"Tensorboard interface is closed": "Tensorboard 界面已关闭",
|
95 |
+
"Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
|
96 |
+
"Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
|
98 |
+
"Training Configuration": "训练配置",
|
99 |
+
"Training Error": "训练错误",
|
100 |
+
"Training stopped": "训练已停止",
|
101 |
+
"Type name of the speaker": "输入说话人的名称",
|
102 |
+
"Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
|
103 |
+
"Use LoRA": "使用 LoRA",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
|
105 |
+
"Use filelist": "使用文件列表",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
|
107 |
+
"VITS Configuration": "VITS 配置",
|
108 |
+
"VQGAN Configuration": "VQGAN 配置",
|
109 |
+
"Validation Batch Size": "验证批次大小",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
|
112 |
+
"WebUI Host": "WebUI 监听地址",
|
113 |
+
"WebUI Port": "WebUI 端口",
|
114 |
+
"Whisper Model": "Whisper 模型",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
|
117 |
+
"latest": "最近的检查点",
|
118 |
+
"new": "创建新的检查点",
|
119 |
+
"Realtime Transform Text": "实时规范化文本",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
|
121 |
+
"Text Normalization": "文本规范化"
|
122 |
+
}
|
fish_speech/i18n/scan.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
from collections import OrderedDict
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
|
10 |
+
|
11 |
+
|
12 |
+
def extract_i18n_strings(node):
|
13 |
+
i18n_strings = []
|
14 |
+
|
15 |
+
if (
|
16 |
+
isinstance(node, ast.Call)
|
17 |
+
and isinstance(node.func, ast.Name)
|
18 |
+
and node.func.id == "i18n"
|
19 |
+
):
|
20 |
+
for arg in node.args:
|
21 |
+
if isinstance(arg, ast.Str):
|
22 |
+
i18n_strings.append(arg.s)
|
23 |
+
|
24 |
+
for child_node in ast.iter_child_nodes(node):
|
25 |
+
i18n_strings.extend(extract_i18n_strings(child_node))
|
26 |
+
|
27 |
+
return i18n_strings
|
28 |
+
|
29 |
+
|
30 |
+
# scan the directory for all .py files (recursively)
|
31 |
+
# for each file, parse the code into an AST
|
32 |
+
# for each AST, extract the i18n strings
|
33 |
+
|
34 |
+
strings = []
|
35 |
+
folders = ["fish_speech", "tools"]
|
36 |
+
# for filename in glob.iglob("**/*.py", recursive=True):
|
37 |
+
for folder in folders:
|
38 |
+
for f in Path(folder).rglob("*.py"):
|
39 |
+
code = f.read_text(encoding="utf-8")
|
40 |
+
if "i18n(" in code:
|
41 |
+
tree = ast.parse(code)
|
42 |
+
i18n_strings = extract_i18n_strings(tree)
|
43 |
+
logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
|
44 |
+
strings.extend(i18n_strings)
|
45 |
+
|
46 |
+
code_keys = set(strings)
|
47 |
+
logger.info(f"Total unique: {len(code_keys)}")
|
48 |
+
|
49 |
+
|
50 |
+
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
|
51 |
+
with open(standard_file, "r", encoding="utf-8") as f:
|
52 |
+
standard_data = json.load(f, object_pairs_hook=OrderedDict)
|
53 |
+
standard_keys = set(standard_data.keys())
|
54 |
+
|
55 |
+
# Define the standard file name
|
56 |
+
unused_keys = standard_keys - code_keys
|
57 |
+
logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
|
58 |
+
for unused_key in unused_keys:
|
59 |
+
logger.info(f"\t{unused_key}")
|
60 |
+
|
61 |
+
missing_keys = code_keys - standard_keys
|
62 |
+
logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
|
63 |
+
for missing_key in missing_keys:
|
64 |
+
logger.info(f"\t{missing_key}")
|
65 |
+
|
66 |
+
code_keys_dict = OrderedDict()
|
67 |
+
for s in strings:
|
68 |
+
code_keys_dict[s] = s
|
69 |
+
|
70 |
+
# write back
|
71 |
+
with open(standard_file, "w", encoding="utf-8") as f:
|
72 |
+
json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
|
73 |
+
f.write("\n")
|
74 |
+
|
75 |
+
logger.info(f"Updated {standard_file}")
|
76 |
+
|
77 |
+
|
78 |
+
# Define the standard file name
|
79 |
+
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
|
80 |
+
|
81 |
+
# Find all JSON files in the directory
|
82 |
+
dir_path = I18N_FILE_PATH
|
83 |
+
languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
|
84 |
+
|
85 |
+
# Load the standard file
|
86 |
+
with open(standard_file, "r", encoding="utf-8") as f:
|
87 |
+
standard_data = json.load(f, object_pairs_hook=OrderedDict)
|
88 |
+
|
89 |
+
# Loop through each language file
|
90 |
+
for lang_file in languages:
|
91 |
+
# Load the language file
|
92 |
+
with open(lang_file, "r", encoding="utf-8") as f:
|
93 |
+
lang_data = json.load(f, object_pairs_hook=OrderedDict)
|
94 |
+
|
95 |
+
# Find the difference between the language file and the standard file
|
96 |
+
diff = set(standard_data.keys()) - set(lang_data.keys())
|
97 |
+
|
98 |
+
miss = set(lang_data.keys()) - set(standard_data.keys())
|
99 |
+
|
100 |
+
# Add any missing keys to the language file
|
101 |
+
for key in diff:
|
102 |
+
lang_data[key] = "#!" + key
|
103 |
+
logger.info(f"Added missing key: {key} to {lang_file}")
|
104 |
+
|
105 |
+
# Del any extra keys to the language file
|
106 |
+
for key in miss:
|
107 |
+
del lang_data[key]
|
108 |
+
logger.info(f"Del extra key: {key} from {lang_file}")
|
109 |
+
|
110 |
+
# Sort the keys of the language file to match the order of the standard file
|
111 |
+
lang_data = OrderedDict(
|
112 |
+
sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
|
113 |
+
)
|
114 |
+
|
115 |
+
# Save the updated language file
|
116 |
+
with open(lang_file, "w", encoding="utf-8") as f:
|
117 |
+
json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
|
118 |
+
f.write("\n")
|
119 |
+
|
120 |
+
logger.info(f"Updated {lang_file}")
|
121 |
+
|
122 |
+
logger.info("Done")
|
fish_speech/models/text2semantic/llama.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import json
|
2 |
import math
|
|
|
3 |
from dataclasses import dataclass
|
4 |
from pathlib import Path
|
5 |
from typing import Optional
|
@@ -370,6 +371,32 @@ class BaseTransformer(nn.Module):
|
|
370 |
weights = torch.load(
|
371 |
Path(path) / "model.pth", map_location="cpu", mmap=True
|
372 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
err = model.load_state_dict(weights, strict=False, assign=True)
|
374 |
log.info(f"Loaded weights with error: {err}")
|
375 |
|
|
|
1 |
import json
|
2 |
import math
|
3 |
+
from collections import OrderedDict
|
4 |
from dataclasses import dataclass
|
5 |
from pathlib import Path
|
6 |
from typing import Optional
|
|
|
371 |
weights = torch.load(
|
372 |
Path(path) / "model.pth", map_location="cpu", mmap=True
|
373 |
)
|
374 |
+
|
375 |
+
if "state_dict" in weights:
|
376 |
+
logger.warning(
|
377 |
+
"Using a TextToSemantic LightningModule checkpoint, "
|
378 |
+
"please make sure it is a full model, not a LoRA model."
|
379 |
+
)
|
380 |
+
weights = weights["state_dict"]
|
381 |
+
|
382 |
+
if next(iter(weights.keys())).startswith("model."):
|
383 |
+
logger.info(
|
384 |
+
f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
|
385 |
+
)
|
386 |
+
new_weights = OrderedDict()
|
387 |
+
for k, v in weights.items():
|
388 |
+
new_weights[k.replace("model.", "")] = v
|
389 |
+
weights = new_weights
|
390 |
+
|
391 |
+
# Verify the name and shape of parameters since strict=False in load_state_dict.
|
392 |
+
for k, v in model.named_parameters():
|
393 |
+
if k not in weights:
|
394 |
+
logger.warning(f"No weight for {k}")
|
395 |
+
elif v.shape != weights[k].shape:
|
396 |
+
logger.warning(
|
397 |
+
f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
|
398 |
+
)
|
399 |
+
|
400 |
err = model.load_state_dict(weights, strict=False, assign=True)
|
401 |
log.info(f"Loaded weights with error: {err}")
|
402 |
|
fish_speech/models/vqgan/__init__.py
CHANGED
@@ -1,3 +0,0 @@
|
|
1 |
-
from .lit_module import VQGAN
|
2 |
-
|
3 |
-
__all__ = ["VQGAN"]
|
|
|
|
|
|
|
|
fish_speech/models/vqgan/modules/firefly.py
CHANGED
@@ -1,25 +1,26 @@
|
|
1 |
-
# A inference only version of the FireflyGAN model
|
2 |
-
|
3 |
import math
|
4 |
from functools import partial
|
5 |
from math import prod
|
6 |
from typing import Callable
|
7 |
|
8 |
-
import numpy as np
|
9 |
import torch
|
10 |
import torch.nn.functional as F
|
11 |
from torch import nn
|
12 |
-
from torch.nn import Conv1d
|
13 |
from torch.nn.utils.parametrizations import weight_norm
|
14 |
from torch.nn.utils.parametrize import remove_parametrizations
|
15 |
from torch.utils.checkpoint import checkpoint
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
def init_weights(m, mean=0.0, std=0.01):
|
21 |
classname = m.__class__.__name__
|
22 |
-
if classname.find("
|
23 |
m.weight.data.normal_(mean, std)
|
24 |
|
25 |
|
@@ -27,78 +28,141 @@ def get_padding(kernel_size, dilation=1):
|
|
27 |
return (kernel_size * dilation - dilation) // 2
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
class ResBlock1(torch.nn.Module):
|
31 |
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
32 |
super().__init__()
|
33 |
|
34 |
self.convs1 = nn.ModuleList(
|
35 |
[
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
),
|
46 |
-
weight_norm(
|
47 |
-
Conv1d(
|
48 |
-
channels,
|
49 |
-
channels,
|
50 |
-
kernel_size,
|
51 |
-
1,
|
52 |
-
dilation=dilation[1],
|
53 |
-
padding=get_padding(kernel_size, dilation[1]),
|
54 |
-
)
|
55 |
-
),
|
56 |
-
weight_norm(
|
57 |
-
Conv1d(
|
58 |
-
channels,
|
59 |
-
channels,
|
60 |
-
kernel_size,
|
61 |
-
1,
|
62 |
-
dilation=dilation[2],
|
63 |
-
padding=get_padding(kernel_size, dilation[2]),
|
64 |
-
)
|
65 |
-
),
|
66 |
]
|
67 |
)
|
68 |
self.convs1.apply(init_weights)
|
69 |
|
70 |
self.convs2 = nn.ModuleList(
|
71 |
[
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
),
|
82 |
-
weight_norm(
|
83 |
-
Conv1d(
|
84 |
-
channels,
|
85 |
-
channels,
|
86 |
-
kernel_size,
|
87 |
-
1,
|
88 |
-
dilation=1,
|
89 |
-
padding=get_padding(kernel_size, 1),
|
90 |
-
)
|
91 |
-
),
|
92 |
-
weight_norm(
|
93 |
-
Conv1d(
|
94 |
-
channels,
|
95 |
-
channels,
|
96 |
-
kernel_size,
|
97 |
-
1,
|
98 |
-
dilation=1,
|
99 |
-
padding=get_padding(kernel_size, 1),
|
100 |
-
)
|
101 |
-
),
|
102 |
]
|
103 |
)
|
104 |
self.convs2.apply(init_weights)
|
@@ -119,7 +183,7 @@ class ResBlock1(torch.nn.Module):
|
|
119 |
remove_parametrizations(conv, tensor_name="weight")
|
120 |
|
121 |
|
122 |
-
class
|
123 |
def __init__(
|
124 |
self,
|
125 |
channels: int,
|
@@ -153,7 +217,6 @@ class HiFiGANGenerator(nn.Module):
|
|
153 |
resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
154 |
num_mels: int = 128,
|
155 |
upsample_initial_channel: int = 512,
|
156 |
-
use_template: bool = True,
|
157 |
pre_conv_kernel_size: int = 7,
|
158 |
post_conv_kernel_size: int = 7,
|
159 |
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
@@ -164,84 +227,50 @@ class HiFiGANGenerator(nn.Module):
|
|
164 |
prod(upsample_rates) == hop_length
|
165 |
), f"hop_length must be {prod(upsample_rates)}"
|
166 |
|
167 |
-
self.conv_pre =
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
padding=get_padding(pre_conv_kernel_size),
|
174 |
-
)
|
175 |
-
)
|
176 |
|
177 |
self.num_upsamples = len(upsample_rates)
|
178 |
self.num_kernels = len(resblock_kernel_sizes)
|
179 |
|
180 |
self.noise_convs = nn.ModuleList()
|
181 |
-
self.use_template = use_template
|
182 |
self.ups = nn.ModuleList()
|
183 |
|
184 |
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
185 |
-
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
186 |
self.ups.append(
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
padding=(k - u) // 2,
|
194 |
-
)
|
195 |
-
)
|
196 |
)
|
197 |
|
198 |
-
if not use_template:
|
199 |
-
continue
|
200 |
-
|
201 |
-
if i + 1 < len(upsample_rates):
|
202 |
-
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
203 |
-
self.noise_convs.append(
|
204 |
-
Conv1d(
|
205 |
-
1,
|
206 |
-
c_cur,
|
207 |
-
kernel_size=stride_f0 * 2,
|
208 |
-
stride=stride_f0,
|
209 |
-
padding=stride_f0 // 2,
|
210 |
-
)
|
211 |
-
)
|
212 |
-
else:
|
213 |
-
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
214 |
-
|
215 |
self.resblocks = nn.ModuleList()
|
216 |
for i in range(len(self.ups)):
|
217 |
ch = upsample_initial_channel // (2 ** (i + 1))
|
218 |
self.resblocks.append(
|
219 |
-
|
220 |
)
|
221 |
|
222 |
self.activation_post = post_activation()
|
223 |
-
self.conv_post =
|
224 |
-
|
225 |
-
|
226 |
-
1,
|
227 |
-
post_conv_kernel_size,
|
228 |
-
1,
|
229 |
-
padding=get_padding(post_conv_kernel_size),
|
230 |
-
)
|
231 |
-
)
|
232 |
self.ups.apply(init_weights)
|
233 |
self.conv_post.apply(init_weights)
|
234 |
|
235 |
-
def forward(self, x
|
236 |
x = self.conv_pre(x)
|
237 |
|
238 |
for i in range(self.num_upsamples):
|
239 |
x = F.silu(x, inplace=True)
|
240 |
x = self.ups[i](x)
|
241 |
|
242 |
-
if self.use_template:
|
243 |
-
x = x + self.noise_convs[i](template)
|
244 |
-
|
245 |
if self.training and self.checkpointing:
|
246 |
x = checkpoint(
|
247 |
self.resblocks[i],
|
@@ -364,11 +393,11 @@ class ConvNeXtBlock(nn.Module):
|
|
364 |
):
|
365 |
super().__init__()
|
366 |
|
367 |
-
self.dwconv =
|
368 |
dim,
|
369 |
dim,
|
370 |
kernel_size=kernel_size,
|
371 |
-
padding=int(dilation * (kernel_size - 1) / 2),
|
372 |
groups=dim,
|
373 |
) # depthwise conv
|
374 |
self.norm = LayerNorm(dim, eps=1e-6)
|
@@ -421,12 +450,13 @@ class ConvNeXtEncoder(nn.Module):
|
|
421 |
|
422 |
self.downsample_layers = nn.ModuleList()
|
423 |
stem = nn.Sequential(
|
424 |
-
|
425 |
input_channels,
|
426 |
dims[0],
|
427 |
-
kernel_size=
|
428 |
-
padding=
|
429 |
-
padding_mode="
|
|
|
430 |
),
|
431 |
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
432 |
)
|
@@ -491,6 +521,7 @@ class FireflyArchitecture(nn.Module):
|
|
491 |
self.head = head
|
492 |
self.quantizer = quantizer
|
493 |
self.spec_transform = spec_transform
|
|
|
494 |
|
495 |
def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
|
496 |
if self.spec_transform is not None:
|
@@ -528,25 +559,30 @@ class FireflyArchitecture(nn.Module):
|
|
528 |
|
529 |
# Encode
|
530 |
encoded_features = self.backbone(mels) * mel_masks_float_conv
|
531 |
-
feature_lengths = mel_lengths //
|
532 |
|
533 |
return self.quantizer.encode(encoded_features), feature_lengths
|
534 |
|
535 |
def decode(self, indices, feature_lengths) -> torch.Tensor:
|
536 |
-
|
537 |
-
|
|
|
|
|
538 |
mel_masks_float_conv = mel_masks[:, None, :].float()
|
|
|
|
|
|
|
539 |
|
540 |
audio_masks = sequence_mask(
|
541 |
-
|
542 |
-
indices.shape[2] *
|
543 |
)
|
544 |
audio_masks_float_conv = audio_masks[:, None, :].float()
|
545 |
|
546 |
z = self.quantizer.decode(indices) * mel_masks_float_conv
|
547 |
x = self.head(z) * audio_masks_float_conv
|
548 |
|
549 |
-
return x
|
550 |
|
551 |
def remove_parametrizations(self):
|
552 |
if hasattr(self.backbone, "remove_parametrizations"):
|
@@ -558,68 +594,3 @@ class FireflyArchitecture(nn.Module):
|
|
558 |
@property
|
559 |
def device(self):
|
560 |
return next(self.parameters()).device
|
561 |
-
|
562 |
-
|
563 |
-
class FireflyBase(nn.Module):
|
564 |
-
def __init__(self, ckpt_path: str = None, pretrained: bool = True):
|
565 |
-
super().__init__()
|
566 |
-
|
567 |
-
self.backbone = ConvNeXtEncoder(
|
568 |
-
input_channels=128,
|
569 |
-
depths=[3, 3, 9, 3],
|
570 |
-
dims=[128, 256, 384, 512],
|
571 |
-
drop_path_rate=0.2,
|
572 |
-
kernel_size=7,
|
573 |
-
)
|
574 |
-
|
575 |
-
self.head = HiFiGANGenerator(
|
576 |
-
hop_length=512,
|
577 |
-
upsample_rates=[8, 8, 2, 2, 2],
|
578 |
-
upsample_kernel_sizes=[16, 16, 4, 4, 4],
|
579 |
-
resblock_kernel_sizes=[3, 7, 11],
|
580 |
-
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
581 |
-
num_mels=512,
|
582 |
-
upsample_initial_channel=512,
|
583 |
-
use_template=False,
|
584 |
-
pre_conv_kernel_size=13,
|
585 |
-
post_conv_kernel_size=13,
|
586 |
-
)
|
587 |
-
|
588 |
-
if ckpt_path is not None:
|
589 |
-
state_dict = torch.load(ckpt_path, map_location="cpu")
|
590 |
-
elif pretrained:
|
591 |
-
state_dict = torch.hub.load_state_dict_from_url(
|
592 |
-
"https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
|
593 |
-
map_location="cpu",
|
594 |
-
model_dir="checkpoints",
|
595 |
-
)
|
596 |
-
|
597 |
-
if "state_dict" in state_dict:
|
598 |
-
state_dict = state_dict["state_dict"]
|
599 |
-
|
600 |
-
if any("generator." in k for k in state_dict):
|
601 |
-
state_dict = {
|
602 |
-
k.replace("generator.", ""): v
|
603 |
-
for k, v in state_dict.items()
|
604 |
-
if "generator." in k
|
605 |
-
}
|
606 |
-
|
607 |
-
self.load_state_dict(state_dict, strict=True)
|
608 |
-
self.head.remove_parametrizations()
|
609 |
-
|
610 |
-
@torch.no_grad()
|
611 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
612 |
-
x = self.backbone(x)
|
613 |
-
x = self.head(x)
|
614 |
-
if x.ndim == 2:
|
615 |
-
x = x[:, None, :]
|
616 |
-
return x
|
617 |
-
|
618 |
-
|
619 |
-
if __name__ == "__main__":
|
620 |
-
model = FireflyBase()
|
621 |
-
model.eval()
|
622 |
-
x = torch.randn(1, 128, 128)
|
623 |
-
with torch.no_grad():
|
624 |
-
y = model(x)
|
625 |
-
print(y.shape)
|
|
|
|
|
|
|
1 |
import math
|
2 |
from functools import partial
|
3 |
from math import prod
|
4 |
from typing import Callable
|
5 |
|
|
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
from torch import nn
|
|
|
9 |
from torch.nn.utils.parametrizations import weight_norm
|
10 |
from torch.nn.utils.parametrize import remove_parametrizations
|
11 |
from torch.utils.checkpoint import checkpoint
|
12 |
|
13 |
+
|
14 |
+
def sequence_mask(length, max_length=None):
|
15 |
+
if max_length is None:
|
16 |
+
max_length = length.max()
|
17 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
18 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
19 |
|
20 |
|
21 |
def init_weights(m, mean=0.0, std=0.01):
|
22 |
classname = m.__class__.__name__
|
23 |
+
if classname.find("Conv1D") != -1:
|
24 |
m.weight.data.normal_(mean, std)
|
25 |
|
26 |
|
|
|
28 |
return (kernel_size * dilation - dilation) // 2
|
29 |
|
30 |
|
31 |
+
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
|
32 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
33 |
+
padding_left, padding_right = paddings
|
34 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
35 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
36 |
+
end = x.shape[-1] - padding_right
|
37 |
+
return x[..., padding_left:end]
|
38 |
+
|
39 |
+
|
40 |
+
def get_extra_padding_for_conv1d(
|
41 |
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
42 |
+
) -> int:
|
43 |
+
"""See `pad_for_conv1d`."""
|
44 |
+
length = x.shape[-1]
|
45 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
46 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
47 |
+
return ideal_length - length
|
48 |
+
|
49 |
+
|
50 |
+
def pad1d(
|
51 |
+
x: torch.Tensor,
|
52 |
+
paddings: tuple[int, int],
|
53 |
+
mode: str = "zeros",
|
54 |
+
value: float = 0.0,
|
55 |
+
):
|
56 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
57 |
+
If this is the case, we insert extra 0 padding to the right
|
58 |
+
before the reflection happen.
|
59 |
+
"""
|
60 |
+
length = x.shape[-1]
|
61 |
+
padding_left, padding_right = paddings
|
62 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
63 |
+
if mode == "reflect":
|
64 |
+
max_pad = max(padding_left, padding_right)
|
65 |
+
extra_pad = 0
|
66 |
+
if length <= max_pad:
|
67 |
+
extra_pad = max_pad - length + 1
|
68 |
+
x = F.pad(x, (0, extra_pad))
|
69 |
+
padded = F.pad(x, paddings, mode, value)
|
70 |
+
end = padded.shape[-1] - extra_pad
|
71 |
+
return padded[..., :end]
|
72 |
+
else:
|
73 |
+
return F.pad(x, paddings, mode, value)
|
74 |
+
|
75 |
+
|
76 |
+
class FishConvNet(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
|
79 |
+
):
|
80 |
+
super(FishConvNet, self).__init__()
|
81 |
+
self.conv = nn.Conv1d(
|
82 |
+
in_channels,
|
83 |
+
out_channels,
|
84 |
+
kernel_size,
|
85 |
+
stride=stride,
|
86 |
+
dilation=dilation,
|
87 |
+
groups=groups,
|
88 |
+
)
|
89 |
+
self.stride = stride
|
90 |
+
self.kernel_size = (kernel_size - 1) * dilation + 1
|
91 |
+
self.dilation = dilation
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
pad = self.kernel_size - self.stride
|
95 |
+
extra_padding = get_extra_padding_for_conv1d(
|
96 |
+
x, self.kernel_size, self.stride, pad
|
97 |
+
)
|
98 |
+
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
|
99 |
+
return self.conv(x).contiguous()
|
100 |
+
|
101 |
+
def weight_norm(self, name="weight", dim=0):
|
102 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
103 |
+
return self
|
104 |
+
|
105 |
+
def remove_weight_norm(self):
|
106 |
+
self.conv = remove_parametrizations(self.conv)
|
107 |
+
return self
|
108 |
+
|
109 |
+
|
110 |
+
class FishTransConvNet(nn.Module):
|
111 |
+
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
|
112 |
+
super(FishTransConvNet, self).__init__()
|
113 |
+
self.conv = nn.ConvTranspose1d(
|
114 |
+
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
|
115 |
+
)
|
116 |
+
self.stride = stride
|
117 |
+
self.kernel_size = kernel_size
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = self.conv(x)
|
121 |
+
pad = self.kernel_size - self.stride
|
122 |
+
padding_right = math.ceil(pad)
|
123 |
+
padding_left = pad - padding_right
|
124 |
+
x = unpad1d(x, (padding_left, padding_right))
|
125 |
+
return x.contiguous()
|
126 |
+
|
127 |
+
def weight_norm(self, name="weight", dim=0):
|
128 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
129 |
+
return self
|
130 |
+
|
131 |
+
def remove_weight_norm(self):
|
132 |
+
self.conv = remove_parametrizations(self.conv)
|
133 |
+
return self
|
134 |
+
|
135 |
+
|
136 |
class ResBlock1(torch.nn.Module):
|
137 |
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
138 |
super().__init__()
|
139 |
|
140 |
self.convs1 = nn.ModuleList(
|
141 |
[
|
142 |
+
FishConvNet(
|
143 |
+
channels, channels, kernel_size, stride=1, dilation=dilation[0]
|
144 |
+
).weight_norm(),
|
145 |
+
FishConvNet(
|
146 |
+
channels, channels, kernel_size, stride=1, dilation=dilation[1]
|
147 |
+
).weight_norm(),
|
148 |
+
FishConvNet(
|
149 |
+
channels, channels, kernel_size, stride=1, dilation=dilation[2]
|
150 |
+
).weight_norm(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
]
|
152 |
)
|
153 |
self.convs1.apply(init_weights)
|
154 |
|
155 |
self.convs2 = nn.ModuleList(
|
156 |
[
|
157 |
+
FishConvNet(
|
158 |
+
channels, channels, kernel_size, stride=1, dilation=dilation[0]
|
159 |
+
).weight_norm(),
|
160 |
+
FishConvNet(
|
161 |
+
channels, channels, kernel_size, stride=1, dilation=dilation[1]
|
162 |
+
).weight_norm(),
|
163 |
+
FishConvNet(
|
164 |
+
channels, channels, kernel_size, stride=1, dilation=dilation[2]
|
165 |
+
).weight_norm(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
]
|
167 |
)
|
168 |
self.convs2.apply(init_weights)
|
|
|
183 |
remove_parametrizations(conv, tensor_name="weight")
|
184 |
|
185 |
|
186 |
+
class ParallelBlock(nn.Module):
|
187 |
def __init__(
|
188 |
self,
|
189 |
channels: int,
|
|
|
217 |
resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
218 |
num_mels: int = 128,
|
219 |
upsample_initial_channel: int = 512,
|
|
|
220 |
pre_conv_kernel_size: int = 7,
|
221 |
post_conv_kernel_size: int = 7,
|
222 |
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
|
|
227 |
prod(upsample_rates) == hop_length
|
228 |
), f"hop_length must be {prod(upsample_rates)}"
|
229 |
|
230 |
+
self.conv_pre = FishConvNet(
|
231 |
+
num_mels,
|
232 |
+
upsample_initial_channel,
|
233 |
+
pre_conv_kernel_size,
|
234 |
+
stride=1,
|
235 |
+
).weight_norm()
|
|
|
|
|
|
|
236 |
|
237 |
self.num_upsamples = len(upsample_rates)
|
238 |
self.num_kernels = len(resblock_kernel_sizes)
|
239 |
|
240 |
self.noise_convs = nn.ModuleList()
|
|
|
241 |
self.ups = nn.ModuleList()
|
242 |
|
243 |
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
|
|
244 |
self.ups.append(
|
245 |
+
FishTransConvNet(
|
246 |
+
upsample_initial_channel // (2**i),
|
247 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
248 |
+
k,
|
249 |
+
stride=u,
|
250 |
+
).weight_norm()
|
|
|
|
|
|
|
251 |
)
|
252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
self.resblocks = nn.ModuleList()
|
254 |
for i in range(len(self.ups)):
|
255 |
ch = upsample_initial_channel // (2 ** (i + 1))
|
256 |
self.resblocks.append(
|
257 |
+
ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
|
258 |
)
|
259 |
|
260 |
self.activation_post = post_activation()
|
261 |
+
self.conv_post = FishConvNet(
|
262 |
+
ch, 1, post_conv_kernel_size, stride=1
|
263 |
+
).weight_norm()
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
self.ups.apply(init_weights)
|
265 |
self.conv_post.apply(init_weights)
|
266 |
|
267 |
+
def forward(self, x):
|
268 |
x = self.conv_pre(x)
|
269 |
|
270 |
for i in range(self.num_upsamples):
|
271 |
x = F.silu(x, inplace=True)
|
272 |
x = self.ups[i](x)
|
273 |
|
|
|
|
|
|
|
274 |
if self.training and self.checkpointing:
|
275 |
x = checkpoint(
|
276 |
self.resblocks[i],
|
|
|
393 |
):
|
394 |
super().__init__()
|
395 |
|
396 |
+
self.dwconv = FishConvNet(
|
397 |
dim,
|
398 |
dim,
|
399 |
kernel_size=kernel_size,
|
400 |
+
# padding=int(dilation * (kernel_size - 1) / 2),
|
401 |
groups=dim,
|
402 |
) # depthwise conv
|
403 |
self.norm = LayerNorm(dim, eps=1e-6)
|
|
|
450 |
|
451 |
self.downsample_layers = nn.ModuleList()
|
452 |
stem = nn.Sequential(
|
453 |
+
FishConvNet(
|
454 |
input_channels,
|
455 |
dims[0],
|
456 |
+
kernel_size=7,
|
457 |
+
# padding=3,
|
458 |
+
# padding_mode="replicate",
|
459 |
+
# padding_mode="zeros",
|
460 |
),
|
461 |
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
462 |
)
|
|
|
521 |
self.head = head
|
522 |
self.quantizer = quantizer
|
523 |
self.spec_transform = spec_transform
|
524 |
+
self.downsample_factor = math.prod(self.quantizer.downsample_factor)
|
525 |
|
526 |
def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
|
527 |
if self.spec_transform is not None:
|
|
|
559 |
|
560 |
# Encode
|
561 |
encoded_features = self.backbone(mels) * mel_masks_float_conv
|
562 |
+
feature_lengths = mel_lengths // self.downsample_factor
|
563 |
|
564 |
return self.quantizer.encode(encoded_features), feature_lengths
|
565 |
|
566 |
def decode(self, indices, feature_lengths) -> torch.Tensor:
|
567 |
+
mel_masks = sequence_mask(
|
568 |
+
feature_lengths * self.downsample_factor,
|
569 |
+
indices.shape[2] * self.downsample_factor,
|
570 |
+
)
|
571 |
mel_masks_float_conv = mel_masks[:, None, :].float()
|
572 |
+
audio_lengths = (
|
573 |
+
feature_lengths * self.downsample_factor * self.spec_transform.hop_length
|
574 |
+
)
|
575 |
|
576 |
audio_masks = sequence_mask(
|
577 |
+
audio_lengths,
|
578 |
+
indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
|
579 |
)
|
580 |
audio_masks_float_conv = audio_masks[:, None, :].float()
|
581 |
|
582 |
z = self.quantizer.decode(indices) * mel_masks_float_conv
|
583 |
x = self.head(z) * audio_masks_float_conv
|
584 |
|
585 |
+
return x, audio_lengths
|
586 |
|
587 |
def remove_parametrizations(self):
|
588 |
if hasattr(self.backbone, "remove_parametrizations"):
|
|
|
594 |
@property
|
595 |
def device(self):
|
596 |
return next(self.parameters()).device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fish_speech/models/vqgan/modules/fsq.py
CHANGED
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|
6 |
from einops import rearrange
|
7 |
from vector_quantize_pytorch import GroupedResidualFSQ
|
8 |
|
9 |
-
from .firefly import ConvNeXtBlock
|
10 |
|
11 |
|
12 |
@dataclass
|
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
20 |
def __init__(
|
21 |
self,
|
22 |
input_dim: int = 512,
|
23 |
-
n_codebooks: int =
|
24 |
n_groups: int = 1,
|
25 |
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
26 |
downsample_factor: tuple[int] = (2, 2),
|
@@ -46,7 +46,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
46 |
self.downsample = nn.Sequential(
|
47 |
*[
|
48 |
nn.Sequential(
|
49 |
-
|
50 |
all_dims[idx],
|
51 |
all_dims[idx + 1],
|
52 |
kernel_size=factor,
|
@@ -61,7 +61,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
61 |
self.upsample = nn.Sequential(
|
62 |
*[
|
63 |
nn.Sequential(
|
64 |
-
|
65 |
all_dims[idx + 1],
|
66 |
all_dims[idx],
|
67 |
kernel_size=factor,
|
@@ -114,26 +114,3 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
114 |
z_q = self.residual_fsq.get_output_from_indices(indices)
|
115 |
z_q = self.upsample(z_q.mT)
|
116 |
return z_q
|
117 |
-
|
118 |
-
# def from_latents(self, latents: torch.Tensor):
|
119 |
-
# z_q, z_p, codes = super().from_latents(latents)
|
120 |
-
# z_q = self.upsample(z_q)
|
121 |
-
# return z_q, z_p, codes
|
122 |
-
|
123 |
-
|
124 |
-
if __name__ == "__main__":
|
125 |
-
rvq = DownsampleFiniteScalarQuantize(
|
126 |
-
n_codebooks=1,
|
127 |
-
downsample_factor=(2, 2),
|
128 |
-
)
|
129 |
-
x = torch.randn(16, 512, 80)
|
130 |
-
|
131 |
-
result = rvq(x)
|
132 |
-
print(rvq)
|
133 |
-
print(result.latents.shape, result.codes.shape, result.z.shape)
|
134 |
-
|
135 |
-
# y = rvq.from_codes(result.codes)
|
136 |
-
# print(y[0].shape)
|
137 |
-
|
138 |
-
# y = rvq.from_latents(result.latents)
|
139 |
-
# print(y[0].shape)
|
|
|
6 |
from einops import rearrange
|
7 |
from vector_quantize_pytorch import GroupedResidualFSQ
|
8 |
|
9 |
+
from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
|
10 |
|
11 |
|
12 |
@dataclass
|
|
|
20 |
def __init__(
|
21 |
self,
|
22 |
input_dim: int = 512,
|
23 |
+
n_codebooks: int = 9,
|
24 |
n_groups: int = 1,
|
25 |
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
26 |
downsample_factor: tuple[int] = (2, 2),
|
|
|
46 |
self.downsample = nn.Sequential(
|
47 |
*[
|
48 |
nn.Sequential(
|
49 |
+
FishConvNet(
|
50 |
all_dims[idx],
|
51 |
all_dims[idx + 1],
|
52 |
kernel_size=factor,
|
|
|
61 |
self.upsample = nn.Sequential(
|
62 |
*[
|
63 |
nn.Sequential(
|
64 |
+
FishTransConvNet(
|
65 |
all_dims[idx + 1],
|
66 |
all_dims[idx],
|
67 |
kernel_size=factor,
|
|
|
114 |
z_q = self.residual_fsq.get_output_from_indices(indices)
|
115 |
z_q = self.upsample(z_q.mT)
|
116 |
return z_q
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fish_speech/scheduler.py
CHANGED
@@ -4,11 +4,14 @@ import math
|
|
4 |
def get_cosine_schedule_with_warmup_lr_lambda(
|
5 |
current_step: int,
|
6 |
*,
|
7 |
-
num_warmup_steps: int,
|
8 |
num_training_steps: int,
|
9 |
num_cycles: float = 0.5,
|
10 |
final_lr_ratio: float = 0.0,
|
11 |
):
|
|
|
|
|
|
|
12 |
if current_step < num_warmup_steps:
|
13 |
return float(current_step) / float(max(1, num_warmup_steps))
|
14 |
|
@@ -20,3 +23,18 @@ def get_cosine_schedule_with_warmup_lr_lambda(
|
|
20 |
final_lr_ratio,
|
21 |
0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
22 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def get_cosine_schedule_with_warmup_lr_lambda(
|
5 |
current_step: int,
|
6 |
*,
|
7 |
+
num_warmup_steps: int | float,
|
8 |
num_training_steps: int,
|
9 |
num_cycles: float = 0.5,
|
10 |
final_lr_ratio: float = 0.0,
|
11 |
):
|
12 |
+
if 0 < num_warmup_steps < 1: # float mode
|
13 |
+
num_warmup_steps = int(num_warmup_steps * num_training_steps)
|
14 |
+
|
15 |
if current_step < num_warmup_steps:
|
16 |
return float(current_step) / float(max(1, num_warmup_steps))
|
17 |
|
|
|
23 |
final_lr_ratio,
|
24 |
0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
25 |
)
|
26 |
+
|
27 |
+
|
28 |
+
def get_constant_schedule_with_warmup_lr_lambda(
|
29 |
+
current_step: int,
|
30 |
+
*,
|
31 |
+
num_warmup_steps: int | float,
|
32 |
+
num_training_steps: int | None = None,
|
33 |
+
):
|
34 |
+
if 0 < num_warmup_steps < 1: # float mode
|
35 |
+
num_warmup_steps = int(num_warmup_steps * num_training_steps)
|
36 |
+
|
37 |
+
if current_step < num_warmup_steps:
|
38 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
39 |
+
|
40 |
+
return 1.0
|
fish_speech/text/clean.py
CHANGED
@@ -64,6 +64,6 @@ def clean_text(text):
|
|
64 |
|
65 |
# Replace all chinese symbols with their english counterparts
|
66 |
text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
|
67 |
-
text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
|
68 |
|
69 |
return text
|
|
|
64 |
|
65 |
# Replace all chinese symbols with their english counterparts
|
66 |
text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
|
67 |
+
# text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
|
68 |
|
69 |
return text
|
fish_speech/train.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
from typing import Optional
|
3 |
|
4 |
import hydra
|
@@ -7,6 +8,7 @@ import pyrootutils
|
|
7 |
import torch
|
8 |
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
9 |
from lightning.pytorch.loggers import Logger
|
|
|
10 |
from omegaconf import DictConfig, OmegaConf
|
11 |
|
12 |
os.environ.pop("SLURM_NTASKS", None)
|
@@ -61,7 +63,9 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
|
|
61 |
|
62 |
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
|
63 |
trainer: Trainer = hydra.utils.instantiate(
|
64 |
-
cfg.trainer,
|
|
|
|
|
65 |
)
|
66 |
|
67 |
object_dict = {
|
|
|
1 |
import os
|
2 |
+
import sys
|
3 |
from typing import Optional
|
4 |
|
5 |
import hydra
|
|
|
8 |
import torch
|
9 |
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
10 |
from lightning.pytorch.loggers import Logger
|
11 |
+
from lightning.pytorch.strategies import DDPStrategy
|
12 |
from omegaconf import DictConfig, OmegaConf
|
13 |
|
14 |
os.environ.pop("SLURM_NTASKS", None)
|
|
|
63 |
|
64 |
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
|
65 |
trainer: Trainer = hydra.utils.instantiate(
|
66 |
+
cfg.trainer,
|
67 |
+
callbacks=callbacks,
|
68 |
+
logger=logger,
|
69 |
)
|
70 |
|
71 |
object_dict = {
|
fish_speech/utils/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from .braceexpand import braceexpand
|
|
|
2 |
from .file import get_latest_checkpoint
|
3 |
from .instantiators import instantiate_callbacks, instantiate_loggers
|
4 |
from .logger import RankedLogger
|
@@ -18,4 +19,5 @@ __all__ = [
|
|
18 |
"task_wrapper",
|
19 |
"braceexpand",
|
20 |
"get_latest_checkpoint",
|
|
|
21 |
]
|
|
|
1 |
from .braceexpand import braceexpand
|
2 |
+
from .context import autocast_exclude_mps
|
3 |
from .file import get_latest_checkpoint
|
4 |
from .instantiators import instantiate_callbacks, instantiate_loggers
|
5 |
from .logger import RankedLogger
|
|
|
19 |
"task_wrapper",
|
20 |
"braceexpand",
|
21 |
"get_latest_checkpoint",
|
22 |
+
"autocast_exclude_mps",
|
23 |
]
|
fish_speech/utils/context.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import nullcontext
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def autocast_exclude_mps(
|
7 |
+
device_type: str, dtype: torch.dtype
|
8 |
+
) -> nullcontext | torch.autocast:
|
9 |
+
return (
|
10 |
+
nullcontext()
|
11 |
+
if torch.backends.mps.is_available()
|
12 |
+
else torch.autocast(device_type, dtype)
|
13 |
+
)
|
fish_speech/utils/file.py
CHANGED
@@ -1,55 +1,5 @@
|
|
1 |
import os
|
2 |
-
from glob import glob
|
3 |
from pathlib import Path
|
4 |
-
from typing import Union
|
5 |
-
|
6 |
-
from loguru import logger
|
7 |
-
from natsort import natsorted
|
8 |
-
|
9 |
-
AUDIO_EXTENSIONS = {
|
10 |
-
".mp3",
|
11 |
-
".wav",
|
12 |
-
".flac",
|
13 |
-
".ogg",
|
14 |
-
".m4a",
|
15 |
-
".wma",
|
16 |
-
".aac",
|
17 |
-
".aiff",
|
18 |
-
".aif",
|
19 |
-
".aifc",
|
20 |
-
}
|
21 |
-
|
22 |
-
|
23 |
-
def list_files(
|
24 |
-
path: Union[Path, str],
|
25 |
-
extensions: set[str] = None,
|
26 |
-
recursive: bool = False,
|
27 |
-
sort: bool = True,
|
28 |
-
) -> list[Path]:
|
29 |
-
"""List files in a directory.
|
30 |
-
|
31 |
-
Args:
|
32 |
-
path (Path): Path to the directory.
|
33 |
-
extensions (set, optional): Extensions to filter. Defaults to None.
|
34 |
-
recursive (bool, optional): Whether to search recursively. Defaults to False.
|
35 |
-
sort (bool, optional): Whether to sort the files. Defaults to True.
|
36 |
-
|
37 |
-
Returns:
|
38 |
-
list: List of files.
|
39 |
-
"""
|
40 |
-
|
41 |
-
if isinstance(path, str):
|
42 |
-
path = Path(path)
|
43 |
-
|
44 |
-
if not path.exists():
|
45 |
-
raise FileNotFoundError(f"Directory {path} does not exist.")
|
46 |
-
|
47 |
-
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
|
48 |
-
|
49 |
-
if sort:
|
50 |
-
files = natsorted(files)
|
51 |
-
|
52 |
-
return files
|
53 |
|
54 |
|
55 |
def get_latest_checkpoint(path: Path | str) -> Path | None:
|
@@ -64,56 +14,3 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
|
|
64 |
return None
|
65 |
|
66 |
return ckpts[-1]
|
67 |
-
|
68 |
-
|
69 |
-
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
|
70 |
-
"""
|
71 |
-
Load a Bert-VITS2 style filelist.
|
72 |
-
"""
|
73 |
-
|
74 |
-
files = set()
|
75 |
-
results = []
|
76 |
-
count_duplicated, count_not_found = 0, 0
|
77 |
-
|
78 |
-
LANGUAGE_TO_LANGUAGES = {
|
79 |
-
"zh": ["zh", "en"],
|
80 |
-
"jp": ["jp", "en"],
|
81 |
-
"en": ["en"],
|
82 |
-
}
|
83 |
-
|
84 |
-
with open(path, "r", encoding="utf-8") as f:
|
85 |
-
for line in f.readlines():
|
86 |
-
splits = line.strip().split("|", maxsplit=3)
|
87 |
-
if len(splits) != 4:
|
88 |
-
logger.warning(f"Invalid line: {line}")
|
89 |
-
continue
|
90 |
-
|
91 |
-
filename, speaker, language, text = splits
|
92 |
-
file = Path(filename)
|
93 |
-
language = language.strip().lower()
|
94 |
-
|
95 |
-
if language == "ja":
|
96 |
-
language = "jp"
|
97 |
-
|
98 |
-
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
|
99 |
-
languages = LANGUAGE_TO_LANGUAGES[language]
|
100 |
-
|
101 |
-
if file in files:
|
102 |
-
logger.warning(f"Duplicated file: {file}")
|
103 |
-
count_duplicated += 1
|
104 |
-
continue
|
105 |
-
|
106 |
-
if not file.exists():
|
107 |
-
logger.warning(f"File not found: {file}")
|
108 |
-
count_not_found += 1
|
109 |
-
continue
|
110 |
-
|
111 |
-
results.append((file, speaker, languages, text))
|
112 |
-
|
113 |
-
if count_duplicated > 0:
|
114 |
-
logger.warning(f"Total duplicated files: {count_duplicated}")
|
115 |
-
|
116 |
-
if count_not_found > 0:
|
117 |
-
logger.warning(f"Total files not found: {count_not_found}")
|
118 |
-
|
119 |
-
return results
|
|
|
1 |
import os
|
|
|
2 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
|
5 |
def get_latest_checkpoint(path: Path | str) -> Path | None:
|
|
|
14 |
return None
|
15 |
|
16 |
return ckpts[-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fish_speech/webui/css/style.css
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
:root {
|
2 |
+
--my-200: #80eeee;
|
3 |
+
--my-50: #ecfdf5;
|
4 |
+
--water-width: 300px;
|
5 |
+
--water-heigh: 300px;
|
6 |
+
}
|
7 |
+
|
8 |
+
|
9 |
+
/* general styled components */
|
10 |
+
.tools {
|
11 |
+
align-items: center;
|
12 |
+
justify-content: center;
|
13 |
+
}
|
14 |
+
|
15 |
+
.gradio-button {
|
16 |
+
max-width: 2.2em;
|
17 |
+
min-width: 2.2em !important;
|
18 |
+
height: 2.4em;
|
19 |
+
align-self: end;
|
20 |
+
line-height: 1em;
|
21 |
+
border-radius: 0.5em;
|
22 |
+
|
23 |
+
}
|
24 |
+
|
25 |
+
.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
|
26 |
+
box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
|
27 |
+
}
|
28 |
+
|
29 |
+
/* replace original footer with ours */
|
30 |
+
a{
|
31 |
+
font-weight: bold;
|
32 |
+
cursor: pointer;
|
33 |
+
color: #030C14 !important;
|
34 |
+
}
|
35 |
+
|
36 |
+
footer {
|
37 |
+
display: none !important;
|
38 |
+
}
|
39 |
+
|
40 |
+
#footer{
|
41 |
+
text-align: center;
|
42 |
+
}
|
43 |
+
|
44 |
+
#footer div{
|
45 |
+
display: inline-block;
|
46 |
+
}
|
47 |
+
|
48 |
+
#footer .versions{
|
49 |
+
font-size: 85%;
|
50 |
+
opacity: 0.85;
|
51 |
+
}
|
52 |
+
|
53 |
+
/*@keyframes moveBackground {*/
|
54 |
+
/* 0% {*/
|
55 |
+
/* background-position: 0 0;*/
|
56 |
+
/* }*/
|
57 |
+
/* 100% {*/
|
58 |
+
/* background-position: -100px 100px;*/
|
59 |
+
/* }*/
|
60 |
+
/*}*/
|
61 |
+
@keyframes moveJellyBackground {
|
62 |
+
0% {
|
63 |
+
background-position: 0% 50%;
|
64 |
+
}
|
65 |
+
50% {
|
66 |
+
background-position: 100% 50%;
|
67 |
+
}
|
68 |
+
100% {
|
69 |
+
background-position: 0% 50%;
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
.gradio-container {
|
74 |
+
position: absolute;
|
75 |
+
z-index: 10;
|
76 |
+
}
|
77 |
+
|
78 |
+
|
79 |
+
.quan {
|
80 |
+
position: absolute;
|
81 |
+
bottom: 0;
|
82 |
+
width: var(--water-width);
|
83 |
+
height: var(--water-heigh);
|
84 |
+
border-radius: 0;
|
85 |
+
/*border: 3px solid rgb(246, 247, 248);*/
|
86 |
+
/*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
|
87 |
+
z-index: 0;
|
88 |
+
|
89 |
+
}
|
90 |
+
|
91 |
+
.quan:last-child {
|
92 |
+
margin-right: 0;
|
93 |
+
}
|
94 |
+
|
95 |
+
.shui {
|
96 |
+
position: absolute;
|
97 |
+
top: 0;
|
98 |
+
left: 0;
|
99 |
+
width: 100%;
|
100 |
+
height: 100%;
|
101 |
+
background-color: rgb(23, 106, 201);
|
102 |
+
border-radius: 0;
|
103 |
+
overflow: hidden;
|
104 |
+
z-index: 0;
|
105 |
+
}
|
106 |
+
|
107 |
+
.shui::after {
|
108 |
+
|
109 |
+
content: '';
|
110 |
+
position: absolute;
|
111 |
+
top: 20%;
|
112 |
+
left: 50%;
|
113 |
+
width: 150%;
|
114 |
+
height: 150%;
|
115 |
+
border-radius: 40%;
|
116 |
+
background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
|
117 |
+
animation: shi 5s linear infinite;
|
118 |
+
}
|
119 |
+
|
120 |
+
@keyframes shi {
|
121 |
+
0% {
|
122 |
+
transform: translate(-50%, -65%) rotate(0deg);
|
123 |
+
}
|
124 |
+
100% {
|
125 |
+
transform: translate(-50%, -65%) rotate(360deg);
|
126 |
+
}
|
127 |
+
}
|
128 |
+
|
129 |
+
.shui::before {
|
130 |
+
content: '';
|
131 |
+
position: absolute;
|
132 |
+
top: 20%;
|
133 |
+
left: 50%;
|
134 |
+
width: 150%;
|
135 |
+
height: 150%;
|
136 |
+
border-radius: 42%;
|
137 |
+
background-color: rgb(240, 228, 228, 0.2);
|
138 |
+
animation: xu 7s linear infinite;
|
139 |
+
}
|
140 |
+
|
141 |
+
@keyframes xu {
|
142 |
+
0% {
|
143 |
+
transform: translate(-50%, -60%) rotate(0deg);
|
144 |
+
}
|
145 |
+
100% {
|
146 |
+
transform: translate(-50%, -60%) rotate(360deg);
|
147 |
+
}
|
148 |
+
}
|
149 |
+
|
150 |
+
fieldset.data_src div.wrap label {
|
151 |
+
background: #f8bffee0 !important;
|
152 |
+
}
|
153 |
+
|
154 |
+
.scrollable-component {
|
155 |
+
max-height: 100px;
|
156 |
+
overflow-y: auto;
|
157 |
+
}
|
158 |
+
|
159 |
+
#file_accordion {
|
160 |
+
max-height: 220px !important;
|
161 |
+
}
|
fish_speech/webui/html/footer.html
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div style="color: rgba(25,255,205,0.7) !important;">
|
2 |
+
<a href="{api_docs}">API</a>
|
3 |
+
•
|
4 |
+
<a href="https://github.com/fishaudio/fish-speech">Github</a>
|
5 |
+
•
|
6 |
+
<a href="https://gradio.app">Gradio</a>
|
7 |
+
</div>
|
8 |
+
<br />
|
9 |
+
<div class="versions" style="color: rgba(25,255,205,0.7) !important;">
|
10 |
+
{versions}
|
11 |
+
</div>
|
fish_speech/webui/js/animate.js
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
function createGradioAnimation() {
|
3 |
+
const params = new URLSearchParams(window.location.search);
|
4 |
+
if (!params.has('__theme')) {
|
5 |
+
params.set('__theme', 'light');
|
6 |
+
window.location.search = params.toString();
|
7 |
+
}
|
8 |
+
|
9 |
+
var gradioApp = document.querySelector('gradio-app');
|
10 |
+
if (gradioApp) {
|
11 |
+
|
12 |
+
document.documentElement.style.setProperty('--my-200', '#80eeee');
|
13 |
+
document.documentElement.style.setProperty('--my-50', '#ecfdf5');
|
14 |
+
|
15 |
+
// gradioApp.style.position = 'relative';
|
16 |
+
// gradioApp.style.backgroundSize = '200% 200%';
|
17 |
+
// gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
|
18 |
+
// gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
|
19 |
+
// gradioApp.style.display = 'flex';
|
20 |
+
// gradioApp.style.justifyContent = 'flex-start';
|
21 |
+
// gradioApp.style.flexWrap = 'nowrap';
|
22 |
+
// gradioApp.style.overflowX = 'auto';
|
23 |
+
|
24 |
+
// for (let i = 0; i < 6; i++) {
|
25 |
+
// var quan = document.createElement('div');
|
26 |
+
// quan.className = 'quan';
|
27 |
+
// gradioApp.insertBefore(quan, gradioApp.firstChild);
|
28 |
+
// quan.id = 'quan' + i.toString();
|
29 |
+
// quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
|
30 |
+
// var quanContainer = document.querySelector('.quan');
|
31 |
+
// if (quanContainer) {
|
32 |
+
// var shui = document.createElement('div');
|
33 |
+
// shui.className = 'shui';
|
34 |
+
// quanContainer.insertBefore(shui, quanContainer.firstChild)
|
35 |
+
// }
|
36 |
+
// }
|
37 |
+
}
|
38 |
+
|
39 |
+
var container = document.createElement('div');
|
40 |
+
container.id = 'gradio-animation';
|
41 |
+
container.style.fontSize = '2em';
|
42 |
+
container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
|
43 |
+
container.style.fontWeight = 'bold';
|
44 |
+
container.style.textAlign = 'center';
|
45 |
+
container.style.marginBottom = '20px';
|
46 |
+
|
47 |
+
var text = 'Welcome to Fish-Speech!';
|
48 |
+
for (var i = 0; i < text.length; i++) {
|
49 |
+
(function(i){
|
50 |
+
setTimeout(function(){
|
51 |
+
var letter = document.createElement('span');
|
52 |
+
letter.style.opacity = '0';
|
53 |
+
letter.style.transition = 'opacity 0.5s';
|
54 |
+
letter.innerText = text[i];
|
55 |
+
|
56 |
+
container.appendChild(letter);
|
57 |
+
|
58 |
+
setTimeout(function() {
|
59 |
+
letter.style.opacity = '1';
|
60 |
+
}, 50);
|
61 |
+
}, i * 200);
|
62 |
+
})(i);
|
63 |
+
}
|
64 |
+
|
65 |
+
var gradioContainer = document.querySelector('.gradio-container');
|
66 |
+
gradioContainer.insertBefore(container, gradioContainer.firstChild);
|
67 |
+
|
68 |
+
return 'Animation created';
|
69 |
+
}
|
fish_speech/webui/launch_utils.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.util
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
import sys
|
5 |
+
from functools import lru_cache
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Iterable
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
from gradio.themes.base import Base
|
11 |
+
from gradio.themes.utils import colors, fonts, sizes
|
12 |
+
|
13 |
+
GIT = (
|
14 |
+
(Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
|
15 |
+
if sys.platform == "win32"
|
16 |
+
else "git"
|
17 |
+
)
|
18 |
+
GIT = str(GIT)
|
19 |
+
|
20 |
+
|
21 |
+
def is_module_installed(module_name: str) -> bool:
|
22 |
+
spec = importlib.util.find_spec(module_name)
|
23 |
+
return spec is not None
|
24 |
+
|
25 |
+
|
26 |
+
@lru_cache()
|
27 |
+
def commit_hash():
|
28 |
+
try:
|
29 |
+
return subprocess.check_output(
|
30 |
+
[GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
|
31 |
+
).strip()
|
32 |
+
except Exception:
|
33 |
+
return "<none>"
|
34 |
+
|
35 |
+
|
36 |
+
def versions_html():
|
37 |
+
import torch
|
38 |
+
|
39 |
+
python_version = ".".join([str(x) for x in sys.version_info[0:3]])
|
40 |
+
commit = commit_hash()
|
41 |
+
hash = commit.strip("'").split(" ")[0]
|
42 |
+
|
43 |
+
return f"""
|
44 |
+
version: <a href="https://github.com/fishaudio/fish-speech/commit/{hash}">{hash}</a>
|
45 |
+
 • 
|
46 |
+
python: <span title="{sys.version}">{python_version}</span>
|
47 |
+
 • 
|
48 |
+
torch: {getattr(torch, '__long_version__',torch.__version__)}
|
49 |
+
 • 
|
50 |
+
gradio: {gr.__version__}
|
51 |
+
 • 
|
52 |
+
author: <a href="https://github.com/fishaudio">fishaudio</a>
|
53 |
+
"""
|
54 |
+
|
55 |
+
|
56 |
+
def version_check(commit):
|
57 |
+
try:
|
58 |
+
import requests
|
59 |
+
|
60 |
+
commits = requests.get(
|
61 |
+
"https://api.github.com/repos/fishaudio/fish-speech/branches/main"
|
62 |
+
).json()
|
63 |
+
if commit != "<none>" and commits["commit"]["sha"] != commit:
|
64 |
+
print("--------------------------------------------------------")
|
65 |
+
print("| You are not up to date with the most recent release. |")
|
66 |
+
print("| Consider running `git pull` to update. |")
|
67 |
+
print("--------------------------------------------------------")
|
68 |
+
elif commits["commit"]["sha"] == commit:
|
69 |
+
print("You are up to date with the most recent release.")
|
70 |
+
else:
|
71 |
+
print("Not a git clone, can't perform version check.")
|
72 |
+
except Exception as e:
|
73 |
+
print("version check failed", e)
|
74 |
+
|
75 |
+
|
76 |
+
class Seafoam(Base):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
*,
|
80 |
+
primary_hue: colors.Color | str = colors.emerald,
|
81 |
+
secondary_hue: colors.Color | str = colors.blue,
|
82 |
+
neutral_hue: colors.Color | str = colors.blue,
|
83 |
+
spacing_size: sizes.Size | str = sizes.spacing_md,
|
84 |
+
radius_size: sizes.Size | str = sizes.radius_md,
|
85 |
+
text_size: sizes.Size | str = sizes.text_lg,
|
86 |
+
font: fonts.Font | str | Iterable[fonts.Font | str] = (
|
87 |
+
fonts.GoogleFont("Quicksand"),
|
88 |
+
"ui-sans-serif",
|
89 |
+
"sans-serif",
|
90 |
+
),
|
91 |
+
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
|
92 |
+
fonts.GoogleFont("IBM Plex Mono"),
|
93 |
+
"ui-monospace",
|
94 |
+
"monospace",
|
95 |
+
),
|
96 |
+
):
|
97 |
+
super().__init__(
|
98 |
+
primary_hue=primary_hue,
|
99 |
+
secondary_hue=secondary_hue,
|
100 |
+
neutral_hue=neutral_hue,
|
101 |
+
spacing_size=spacing_size,
|
102 |
+
radius_size=radius_size,
|
103 |
+
text_size=text_size,
|
104 |
+
font=font,
|
105 |
+
font_mono=font_mono,
|
106 |
+
)
|
107 |
+
super().set(
|
108 |
+
button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
|
109 |
+
button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
|
110 |
+
button_primary_text_color="white",
|
111 |
+
button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
|
112 |
+
slider_color="*secondary_300",
|
113 |
+
slider_color_dark="*secondary_600",
|
114 |
+
block_title_text_weight="600",
|
115 |
+
block_border_width="3px",
|
116 |
+
block_shadow="*shadow_drop_lg",
|
117 |
+
button_shadow="*shadow_drop_lg",
|
118 |
+
button_small_padding="0px",
|
119 |
+
button_large_padding="3px",
|
120 |
+
)
|
fish_speech/webui/manage.py
ADDED
@@ -0,0 +1,1237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import html
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import platform
|
8 |
+
import shutil
|
9 |
+
import signal
|
10 |
+
import subprocess
|
11 |
+
import sys
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
import psutil
|
16 |
+
import yaml
|
17 |
+
from loguru import logger
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
|
21 |
+
sys.path.insert(0, "")
|
22 |
+
print(sys.path)
|
23 |
+
cur_work_dir = Path(os.getcwd()).resolve()
|
24 |
+
print("You are in ", str(cur_work_dir))
|
25 |
+
|
26 |
+
from fish_speech.i18n import i18n
|
27 |
+
from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
|
28 |
+
|
29 |
+
config_path = cur_work_dir / "fish_speech" / "configs"
|
30 |
+
vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
|
31 |
+
llama_yml_path = config_path / "text2semantic_finetune.yaml"
|
32 |
+
|
33 |
+
env = os.environ.copy()
|
34 |
+
env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
|
35 |
+
|
36 |
+
seafoam = Seafoam()
|
37 |
+
|
38 |
+
|
39 |
+
def build_html_error_message(error):
|
40 |
+
return f"""
|
41 |
+
<div style="color: red; font-weight: bold;">
|
42 |
+
{html.escape(error)}
|
43 |
+
</div>
|
44 |
+
"""
|
45 |
+
|
46 |
+
|
47 |
+
def build_html_ok_message(msg):
|
48 |
+
return f"""
|
49 |
+
<div style="color: green; font-weight: bold;">
|
50 |
+
{html.escape(msg)}
|
51 |
+
</div>
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
def build_html_href(link, desc, msg):
|
56 |
+
return f"""
|
57 |
+
<span style="color: green; font-weight: bold; display: inline-block">
|
58 |
+
{html.escape(msg)}
|
59 |
+
<a href="{link}">{desc}</a>
|
60 |
+
</span>
|
61 |
+
"""
|
62 |
+
|
63 |
+
|
64 |
+
def load_data_in_raw(path):
|
65 |
+
with open(path, "r", encoding="utf-8") as file:
|
66 |
+
data = file.read()
|
67 |
+
return str(data)
|
68 |
+
|
69 |
+
|
70 |
+
def kill_proc_tree(pid, including_parent=True):
|
71 |
+
try:
|
72 |
+
parent = psutil.Process(pid)
|
73 |
+
except psutil.NoSuchProcess:
|
74 |
+
# Process already terminated
|
75 |
+
return
|
76 |
+
|
77 |
+
children = parent.children(recursive=True)
|
78 |
+
for child in children:
|
79 |
+
try:
|
80 |
+
os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
|
81 |
+
except OSError:
|
82 |
+
pass
|
83 |
+
if including_parent:
|
84 |
+
try:
|
85 |
+
os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
|
86 |
+
except OSError:
|
87 |
+
pass
|
88 |
+
|
89 |
+
|
90 |
+
system = platform.system()
|
91 |
+
p_label = None
|
92 |
+
p_infer = None
|
93 |
+
p_tensorboard = None
|
94 |
+
|
95 |
+
|
96 |
+
def kill_process(pid):
|
97 |
+
if system == "Windows":
|
98 |
+
cmd = "taskkill /t /f /pid %s" % pid
|
99 |
+
# os.system(cmd)
|
100 |
+
subprocess.run(cmd)
|
101 |
+
else:
|
102 |
+
kill_proc_tree(pid)
|
103 |
+
|
104 |
+
|
105 |
+
def change_label(if_label):
|
106 |
+
global p_label
|
107 |
+
if if_label == True and p_label is None:
|
108 |
+
url = "http://localhost:3000"
|
109 |
+
remote_url = "https://text-labeler.pages.dev/"
|
110 |
+
try:
|
111 |
+
p_label = subprocess.Popen(
|
112 |
+
[
|
113 |
+
(
|
114 |
+
"asr-label-linux-x64"
|
115 |
+
if sys.platform == "linux"
|
116 |
+
else "asr-label-win-x64.exe"
|
117 |
+
)
|
118 |
+
]
|
119 |
+
)
|
120 |
+
except FileNotFoundError:
|
121 |
+
logger.warning("asr-label execution not found!")
|
122 |
+
|
123 |
+
yield build_html_href(
|
124 |
+
link=remote_url,
|
125 |
+
desc=i18n("Optional online ver"),
|
126 |
+
msg=i18n("Opened labeler in browser"),
|
127 |
+
)
|
128 |
+
|
129 |
+
elif if_label == False and p_label is not None:
|
130 |
+
kill_process(p_label.pid)
|
131 |
+
p_label = None
|
132 |
+
yield build_html_ok_message("Nothing")
|
133 |
+
|
134 |
+
|
135 |
+
def clean_infer_cache():
|
136 |
+
import tempfile
|
137 |
+
|
138 |
+
temp_dir = Path(tempfile.gettempdir())
|
139 |
+
gradio_dir = str(temp_dir / "gradio")
|
140 |
+
try:
|
141 |
+
shutil.rmtree(gradio_dir)
|
142 |
+
logger.info(f"Deleted cached audios: {gradio_dir}")
|
143 |
+
except PermissionError:
|
144 |
+
logger.info(f"Permission denied: Unable to delete {gradio_dir}")
|
145 |
+
except FileNotFoundError:
|
146 |
+
logger.info(f"{gradio_dir} was not found")
|
147 |
+
except Exception as e:
|
148 |
+
logger.info(f"An error occurred: {e}")
|
149 |
+
|
150 |
+
|
151 |
+
def change_infer(
|
152 |
+
if_infer,
|
153 |
+
host,
|
154 |
+
port,
|
155 |
+
infer_decoder_model,
|
156 |
+
infer_decoder_config,
|
157 |
+
infer_llama_model,
|
158 |
+
infer_compile,
|
159 |
+
):
|
160 |
+
global p_infer
|
161 |
+
if if_infer == True and p_infer == None:
|
162 |
+
env = os.environ.copy()
|
163 |
+
|
164 |
+
env["GRADIO_SERVER_NAME"] = host
|
165 |
+
env["GRADIO_SERVER_PORT"] = port
|
166 |
+
# 启动第二个进程
|
167 |
+
url = f"http://{host}:{port}"
|
168 |
+
yield build_html_ok_message(
|
169 |
+
i18n("Inferring interface is launched at {}").format(url)
|
170 |
+
)
|
171 |
+
|
172 |
+
clean_infer_cache()
|
173 |
+
|
174 |
+
p_infer = subprocess.Popen(
|
175 |
+
[
|
176 |
+
PYTHON,
|
177 |
+
"tools/webui.py",
|
178 |
+
"--decoder-checkpoint-path",
|
179 |
+
infer_decoder_model,
|
180 |
+
"--decoder-config-name",
|
181 |
+
infer_decoder_config,
|
182 |
+
"--llama-checkpoint-path",
|
183 |
+
infer_llama_model,
|
184 |
+
]
|
185 |
+
+ (["--compile"] if infer_compile == "Yes" else []),
|
186 |
+
env=env,
|
187 |
+
)
|
188 |
+
|
189 |
+
elif if_infer == False and p_infer is not None:
|
190 |
+
kill_process(p_infer.pid)
|
191 |
+
p_infer = None
|
192 |
+
yield build_html_error_message(i18n("Infer interface is closed"))
|
193 |
+
|
194 |
+
|
195 |
+
js = load_data_in_raw("fish_speech/webui/js/animate.js")
|
196 |
+
css = load_data_in_raw("fish_speech/webui/css/style.css")
|
197 |
+
|
198 |
+
data_pre_output = (cur_work_dir / "data").resolve()
|
199 |
+
default_model_output = (cur_work_dir / "results").resolve()
|
200 |
+
default_filelist = data_pre_output / "detect.list"
|
201 |
+
data_pre_output.mkdir(parents=True, exist_ok=True)
|
202 |
+
|
203 |
+
items = []
|
204 |
+
dict_items = {}
|
205 |
+
|
206 |
+
|
207 |
+
def load_yaml_data_in_fact(yml_path):
|
208 |
+
with open(yml_path, "r", encoding="utf-8") as file:
|
209 |
+
yml = yaml.safe_load(file)
|
210 |
+
return yml
|
211 |
+
|
212 |
+
|
213 |
+
def write_yaml_data_in_fact(yml, yml_path):
|
214 |
+
with open(yml_path, "w", encoding="utf-8") as file:
|
215 |
+
yaml.safe_dump(yml, file, allow_unicode=True)
|
216 |
+
return yml
|
217 |
+
|
218 |
+
|
219 |
+
def generate_tree(directory, depth=0, max_depth=None, prefix=""):
|
220 |
+
if max_depth is not None and depth > max_depth:
|
221 |
+
return ""
|
222 |
+
|
223 |
+
tree_str = ""
|
224 |
+
files = []
|
225 |
+
directories = []
|
226 |
+
for item in os.listdir(directory):
|
227 |
+
if os.path.isdir(os.path.join(directory, item)):
|
228 |
+
directories.append(item)
|
229 |
+
else:
|
230 |
+
files.append(item)
|
231 |
+
|
232 |
+
entries = directories + files
|
233 |
+
for i, entry in enumerate(entries):
|
234 |
+
connector = "├── " if i < len(entries) - 1 else "└── "
|
235 |
+
tree_str += f"{prefix}{connector}{entry}<br />"
|
236 |
+
if i < len(directories):
|
237 |
+
extension = "│ " if i < len(entries) - 1 else " "
|
238 |
+
tree_str += generate_tree(
|
239 |
+
os.path.join(directory, entry),
|
240 |
+
depth + 1,
|
241 |
+
max_depth,
|
242 |
+
prefix=prefix + extension,
|
243 |
+
)
|
244 |
+
return tree_str
|
245 |
+
|
246 |
+
|
247 |
+
def new_explorer(data_path, max_depth):
|
248 |
+
return gr.Markdown(
|
249 |
+
elem_classes=["scrollable-component"],
|
250 |
+
value=generate_tree(data_path, max_depth=max_depth),
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
def add_item(
|
255 |
+
folder: str,
|
256 |
+
method: str,
|
257 |
+
label_lang: str,
|
258 |
+
if_initial_prompt: bool,
|
259 |
+
initial_prompt: str | None,
|
260 |
+
):
|
261 |
+
folder = folder.strip(" ").strip('"')
|
262 |
+
|
263 |
+
folder_path = Path(folder)
|
264 |
+
|
265 |
+
if folder and folder not in items and data_pre_output not in folder_path.parents:
|
266 |
+
if folder_path.is_dir():
|
267 |
+
items.append(folder)
|
268 |
+
dict_items[folder] = dict(
|
269 |
+
type="folder",
|
270 |
+
method=method,
|
271 |
+
label_lang=label_lang,
|
272 |
+
initial_prompt=initial_prompt if if_initial_prompt else None,
|
273 |
+
)
|
274 |
+
elif folder:
|
275 |
+
err = folder
|
276 |
+
return gr.Checkboxgroup(choices=items), build_html_error_message(
|
277 |
+
i18n("Invalid path: {}").format(err)
|
278 |
+
)
|
279 |
+
|
280 |
+
formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
|
281 |
+
logger.info("After Adding: " + formatted_data)
|
282 |
+
gr.Info(formatted_data)
|
283 |
+
return gr.Checkboxgroup(choices=items), build_html_ok_message(
|
284 |
+
i18n("Added path successfully!")
|
285 |
+
)
|
286 |
+
|
287 |
+
|
288 |
+
def remove_items(selected_items):
|
289 |
+
global items, dict_items
|
290 |
+
to_remove = [item for item in items if item in selected_items]
|
291 |
+
for item in to_remove:
|
292 |
+
del dict_items[item]
|
293 |
+
items = [item for item in items if item in dict_items.keys()]
|
294 |
+
formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
|
295 |
+
logger.info(formatted_data)
|
296 |
+
gr.Warning("After Removing: " + formatted_data)
|
297 |
+
return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
|
298 |
+
i18n("Removed path successfully!")
|
299 |
+
)
|
300 |
+
|
301 |
+
|
302 |
+
def show_selected(options):
|
303 |
+
selected_options = ", ".join(options)
|
304 |
+
|
305 |
+
if options:
|
306 |
+
return i18n("Selected: {}").format(selected_options)
|
307 |
+
else:
|
308 |
+
return i18n("No selected options")
|
309 |
+
|
310 |
+
|
311 |
+
from pydub import AudioSegment
|
312 |
+
|
313 |
+
|
314 |
+
def convert_to_mono_in_place(audio_path: Path):
|
315 |
+
audio = AudioSegment.from_file(audio_path)
|
316 |
+
if audio.channels > 1:
|
317 |
+
mono_audio = audio.set_channels(1)
|
318 |
+
mono_audio.export(audio_path, format=audio_path.suffix[1:])
|
319 |
+
logger.info(f"Convert {audio_path} successfully")
|
320 |
+
|
321 |
+
|
322 |
+
def list_copy(list_file_path, method):
|
323 |
+
wav_root = data_pre_output
|
324 |
+
lst = []
|
325 |
+
with list_file_path.open("r", encoding="utf-8") as file:
|
326 |
+
for line in tqdm(file, desc="Processing audio/transcript"):
|
327 |
+
wav_path, speaker_name, language, text = line.strip().split("|")
|
328 |
+
original_wav_path = Path(wav_path)
|
329 |
+
target_wav_path = (
|
330 |
+
wav_root / original_wav_path.parent.name / original_wav_path.name
|
331 |
+
)
|
332 |
+
lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
|
333 |
+
if target_wav_path.is_file():
|
334 |
+
continue
|
335 |
+
target_wav_path.parent.mkdir(parents=True, exist_ok=True)
|
336 |
+
if method == i18n("Copy"):
|
337 |
+
shutil.copy(original_wav_path, target_wav_path)
|
338 |
+
else:
|
339 |
+
shutil.move(original_wav_path, target_wav_path.parent)
|
340 |
+
convert_to_mono_in_place(target_wav_path)
|
341 |
+
original_lab_path = original_wav_path.with_suffix(".lab")
|
342 |
+
target_lab_path = (
|
343 |
+
wav_root
|
344 |
+
/ original_wav_path.parent.name
|
345 |
+
/ original_wav_path.with_suffix(".lab").name
|
346 |
+
)
|
347 |
+
if target_lab_path.is_file():
|
348 |
+
continue
|
349 |
+
if method == i18n("Copy"):
|
350 |
+
shutil.copy(original_lab_path, target_lab_path)
|
351 |
+
else:
|
352 |
+
shutil.move(original_lab_path, target_lab_path.parent)
|
353 |
+
|
354 |
+
if method == i18n("Move"):
|
355 |
+
with list_file_path.open("w", encoding="utf-8") as file:
|
356 |
+
file.writelines("\n".join(lst))
|
357 |
+
|
358 |
+
del lst
|
359 |
+
return build_html_ok_message(i18n("Use filelist"))
|
360 |
+
|
361 |
+
|
362 |
+
def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
|
363 |
+
global dict_items
|
364 |
+
data_path = Path(data_path)
|
365 |
+
gr.Warning("Pre-processing begins...")
|
366 |
+
for item, content in dict_items.items():
|
367 |
+
item_path = Path(item)
|
368 |
+
tar_path = data_path / item_path.name
|
369 |
+
|
370 |
+
if content["type"] == "folder" and item_path.is_dir():
|
371 |
+
if content["method"] == i18n("Copy"):
|
372 |
+
os.makedirs(tar_path, exist_ok=True)
|
373 |
+
shutil.copytree(
|
374 |
+
src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
|
375 |
+
)
|
376 |
+
elif not tar_path.is_dir():
|
377 |
+
shutil.move(src=str(item_path), dst=str(tar_path))
|
378 |
+
|
379 |
+
for suf in ["wav", "flac", "mp3"]:
|
380 |
+
for audio_path in tar_path.glob(f"**/*.{suf}"):
|
381 |
+
convert_to_mono_in_place(audio_path)
|
382 |
+
|
383 |
+
cur_lang = content["label_lang"]
|
384 |
+
initial_prompt = content["initial_prompt"]
|
385 |
+
|
386 |
+
transcribe_cmd = [
|
387 |
+
PYTHON,
|
388 |
+
"tools/whisper_asr.py",
|
389 |
+
"--model-size",
|
390 |
+
label_model,
|
391 |
+
"--device",
|
392 |
+
label_device,
|
393 |
+
"--audio-dir",
|
394 |
+
tar_path,
|
395 |
+
"--save-dir",
|
396 |
+
tar_path,
|
397 |
+
"--language",
|
398 |
+
cur_lang,
|
399 |
+
]
|
400 |
+
|
401 |
+
if initial_prompt is not None:
|
402 |
+
transcribe_cmd += ["--initial-prompt", initial_prompt]
|
403 |
+
|
404 |
+
if cur_lang != "IGNORE":
|
405 |
+
try:
|
406 |
+
gr.Warning("Begin To Transcribe")
|
407 |
+
subprocess.run(
|
408 |
+
transcribe_cmd,
|
409 |
+
env=env,
|
410 |
+
)
|
411 |
+
except Exception:
|
412 |
+
print("Transcription error occurred")
|
413 |
+
|
414 |
+
elif content["type"] == "file" and item_path.is_file():
|
415 |
+
list_copy(item_path, content["method"])
|
416 |
+
|
417 |
+
return build_html_ok_message(i18n("Move files successfully")), new_explorer(
|
418 |
+
data_path, max_depth=max_depth
|
419 |
+
)
|
420 |
+
|
421 |
+
|
422 |
+
def generate_folder_name():
|
423 |
+
now = datetime.datetime.now()
|
424 |
+
folder_name = now.strftime("%Y%m%d_%H%M%S")
|
425 |
+
return folder_name
|
426 |
+
|
427 |
+
|
428 |
+
def train_process(
|
429 |
+
data_path: str,
|
430 |
+
option: str,
|
431 |
+
# llama config
|
432 |
+
llama_ckpt,
|
433 |
+
llama_base_config,
|
434 |
+
llama_lr,
|
435 |
+
llama_maxsteps,
|
436 |
+
llama_data_num_workers,
|
437 |
+
llama_data_batch_size,
|
438 |
+
llama_data_max_length,
|
439 |
+
llama_precision,
|
440 |
+
llama_check_interval,
|
441 |
+
llama_grad_batches,
|
442 |
+
llama_use_speaker,
|
443 |
+
llama_use_lora,
|
444 |
+
):
|
445 |
+
|
446 |
+
backend = "nccl" if sys.platform == "linux" else "gloo"
|
447 |
+
|
448 |
+
new_project = generate_folder_name()
|
449 |
+
print("New Project Name: ", new_project)
|
450 |
+
|
451 |
+
if option == "VQGAN":
|
452 |
+
msg = "Skipped VQGAN Training."
|
453 |
+
gr.Warning(msg)
|
454 |
+
logger.info(msg)
|
455 |
+
|
456 |
+
if option == "LLAMA":
|
457 |
+
msg = "LLAMA Training begins..."
|
458 |
+
gr.Warning(msg)
|
459 |
+
logger.info(msg)
|
460 |
+
subprocess.run(
|
461 |
+
[
|
462 |
+
PYTHON,
|
463 |
+
"tools/vqgan/extract_vq.py",
|
464 |
+
str(data_pre_output),
|
465 |
+
"--num-workers",
|
466 |
+
"1",
|
467 |
+
"--batch-size",
|
468 |
+
"16",
|
469 |
+
"--config-name",
|
470 |
+
"firefly_gan_vq",
|
471 |
+
"--checkpoint-path",
|
472 |
+
"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
473 |
+
]
|
474 |
+
)
|
475 |
+
|
476 |
+
subprocess.run(
|
477 |
+
[
|
478 |
+
PYTHON,
|
479 |
+
"tools/llama/build_dataset.py",
|
480 |
+
"--input",
|
481 |
+
str(data_pre_output),
|
482 |
+
"--text-extension",
|
483 |
+
".lab",
|
484 |
+
"--num-workers",
|
485 |
+
"16",
|
486 |
+
]
|
487 |
+
)
|
488 |
+
ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
|
489 |
+
lora_prefix = "lora_" if llama_use_lora else ""
|
490 |
+
llama_name = lora_prefix + "text2semantic_" + new_project
|
491 |
+
latest = next(
|
492 |
+
iter(
|
493 |
+
sorted(
|
494 |
+
[
|
495 |
+
str(p.relative_to("results"))
|
496 |
+
for p in Path("results").glob(lora_prefix + "text2sem*/")
|
497 |
+
],
|
498 |
+
reverse=True,
|
499 |
+
)
|
500 |
+
),
|
501 |
+
llama_name,
|
502 |
+
)
|
503 |
+
project = (
|
504 |
+
llama_name
|
505 |
+
if llama_ckpt == i18n("new")
|
506 |
+
else (
|
507 |
+
latest
|
508 |
+
if llama_ckpt == i18n("latest")
|
509 |
+
else Path(llama_ckpt).relative_to("results")
|
510 |
+
)
|
511 |
+
)
|
512 |
+
logger.info(project)
|
513 |
+
|
514 |
+
if llama_check_interval > llama_maxsteps:
|
515 |
+
llama_check_interval = llama_maxsteps
|
516 |
+
|
517 |
+
train_cmd = [
|
518 |
+
PYTHON,
|
519 |
+
"fish_speech/train.py",
|
520 |
+
"--config-name",
|
521 |
+
"text2semantic_finetune",
|
522 |
+
f"project={project}",
|
523 |
+
f"trainer.strategy.process_group_backend={backend}",
|
524 |
+
f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
|
525 |
+
f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
|
526 |
+
f"model.optimizer.lr={llama_lr}",
|
527 |
+
f"trainer.max_steps={llama_maxsteps}",
|
528 |
+
f"data.num_workers={llama_data_num_workers}",
|
529 |
+
f"data.batch_size={llama_data_batch_size}",
|
530 |
+
f"max_length={llama_data_max_length}",
|
531 |
+
f"trainer.precision={llama_precision}",
|
532 |
+
f"trainer.val_check_interval={llama_check_interval}",
|
533 |
+
f"trainer.accumulate_grad_batches={llama_grad_batches}",
|
534 |
+
f"train_dataset.interactive_prob={llama_use_speaker}",
|
535 |
+
] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
|
536 |
+
logger.info(train_cmd)
|
537 |
+
subprocess.run(train_cmd)
|
538 |
+
|
539 |
+
return build_html_ok_message(i18n("Training stopped"))
|
540 |
+
|
541 |
+
|
542 |
+
def tensorboard_process(
|
543 |
+
if_tensorboard: bool,
|
544 |
+
tensorboard_dir: str,
|
545 |
+
host: str,
|
546 |
+
port: str,
|
547 |
+
):
|
548 |
+
global p_tensorboard
|
549 |
+
if if_tensorboard == True and p_tensorboard == None:
|
550 |
+
url = f"http://{host}:{port}"
|
551 |
+
yield build_html_ok_message(
|
552 |
+
i18n("Tensorboard interface is launched at {}").format(url)
|
553 |
+
)
|
554 |
+
prefix = ["tensorboard"]
|
555 |
+
if Path("fishenv").exists():
|
556 |
+
prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
|
557 |
+
|
558 |
+
p_tensorboard = subprocess.Popen(
|
559 |
+
prefix
|
560 |
+
+ [
|
561 |
+
"--logdir",
|
562 |
+
tensorboard_dir,
|
563 |
+
"--host",
|
564 |
+
host,
|
565 |
+
"--port",
|
566 |
+
port,
|
567 |
+
"--reload_interval",
|
568 |
+
"120",
|
569 |
+
]
|
570 |
+
)
|
571 |
+
elif if_tensorboard == False and p_tensorboard != None:
|
572 |
+
kill_process(p_tensorboard.pid)
|
573 |
+
p_tensorboard = None
|
574 |
+
yield build_html_error_message(i18n("Tensorboard interface is closed"))
|
575 |
+
|
576 |
+
|
577 |
+
def fresh_tb_dir():
|
578 |
+
return gr.Dropdown(
|
579 |
+
choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
|
580 |
+
)
|
581 |
+
|
582 |
+
|
583 |
+
def list_decoder_models():
|
584 |
+
paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
|
585 |
+
if not paths:
|
586 |
+
logger.warning("No decoder model found")
|
587 |
+
return paths
|
588 |
+
|
589 |
+
|
590 |
+
def list_llama_models():
|
591 |
+
choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
|
592 |
+
choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
|
593 |
+
choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
|
594 |
+
choices = sorted(choices, reverse=True)
|
595 |
+
if not choices:
|
596 |
+
logger.warning("No LLaMA model found")
|
597 |
+
return choices
|
598 |
+
|
599 |
+
|
600 |
+
def list_lora_llama_models():
|
601 |
+
choices = sorted(
|
602 |
+
[str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
|
603 |
+
)
|
604 |
+
if not choices:
|
605 |
+
logger.warning("No LoRA LLaMA model found")
|
606 |
+
return choices
|
607 |
+
|
608 |
+
|
609 |
+
def fresh_decoder_model():
|
610 |
+
return gr.Dropdown(choices=list_decoder_models())
|
611 |
+
|
612 |
+
|
613 |
+
def fresh_llama_ckpt(llama_use_lora):
|
614 |
+
return gr.Dropdown(
|
615 |
+
choices=[i18n("latest"), i18n("new")]
|
616 |
+
+ (
|
617 |
+
[str(p) for p in Path("results").glob("text2sem*/")]
|
618 |
+
if not llama_use_lora
|
619 |
+
else [str(p) for p in Path("results").glob("lora_*/")]
|
620 |
+
)
|
621 |
+
)
|
622 |
+
|
623 |
+
|
624 |
+
def fresh_llama_model():
|
625 |
+
return gr.Dropdown(choices=list_llama_models())
|
626 |
+
|
627 |
+
|
628 |
+
def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
|
629 |
+
if (
|
630 |
+
lora_weight is None
|
631 |
+
or not Path(lora_weight).exists()
|
632 |
+
or not Path(llama_weight).exists()
|
633 |
+
):
|
634 |
+
return build_html_error_message(
|
635 |
+
i18n(
|
636 |
+
"Path error, please check the model file exists in the corresponding path"
|
637 |
+
)
|
638 |
+
)
|
639 |
+
gr.Warning("Merging begins...")
|
640 |
+
merge_cmd = [
|
641 |
+
PYTHON,
|
642 |
+
"tools/llama/merge_lora.py",
|
643 |
+
"--lora-config",
|
644 |
+
"r_8_alpha_16",
|
645 |
+
"--lora-weight",
|
646 |
+
lora_weight,
|
647 |
+
"--output",
|
648 |
+
llama_lora_output + "_" + generate_folder_name(),
|
649 |
+
]
|
650 |
+
logger.info(merge_cmd)
|
651 |
+
subprocess.run(merge_cmd)
|
652 |
+
return build_html_ok_message(i18n("Merge successfully"))
|
653 |
+
|
654 |
+
|
655 |
+
def llama_quantify(llama_weight, quantify_mode):
|
656 |
+
if llama_weight is None or not Path(llama_weight).exists():
|
657 |
+
return build_html_error_message(
|
658 |
+
i18n(
|
659 |
+
"Path error, please check the model file exists in the corresponding path"
|
660 |
+
)
|
661 |
+
)
|
662 |
+
|
663 |
+
gr.Warning("Quantifying begins...")
|
664 |
+
|
665 |
+
now = generate_folder_name()
|
666 |
+
quantify_cmd = [
|
667 |
+
PYTHON,
|
668 |
+
"tools/llama/quantize.py",
|
669 |
+
"--checkpoint-path",
|
670 |
+
llama_weight,
|
671 |
+
"--mode",
|
672 |
+
quantify_mode,
|
673 |
+
"--timestamp",
|
674 |
+
now,
|
675 |
+
]
|
676 |
+
logger.info(quantify_cmd)
|
677 |
+
subprocess.run(quantify_cmd)
|
678 |
+
if quantify_mode == "int8":
|
679 |
+
quantize_path = str(
|
680 |
+
Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
|
681 |
+
)
|
682 |
+
else:
|
683 |
+
quantize_path = str(
|
684 |
+
Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
|
685 |
+
)
|
686 |
+
return build_html_ok_message(
|
687 |
+
i18n("Quantify successfully") + f"Path: {quantize_path}"
|
688 |
+
)
|
689 |
+
|
690 |
+
|
691 |
+
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
|
692 |
+
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
|
693 |
+
|
694 |
+
with gr.Blocks(
|
695 |
+
head="<style>\n" + css + "\n</style>",
|
696 |
+
js=js,
|
697 |
+
theme=seafoam,
|
698 |
+
analytics_enabled=False,
|
699 |
+
title="Fish Speech",
|
700 |
+
) as demo:
|
701 |
+
with gr.Row():
|
702 |
+
with gr.Column():
|
703 |
+
with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
|
704 |
+
with gr.Row():
|
705 |
+
textbox = gr.Textbox(
|
706 |
+
label="\U0000270F "
|
707 |
+
+ i18n("Input Audio & Source Path for Transcription"),
|
708 |
+
info=i18n("Speaker is identified by the folder name"),
|
709 |
+
interactive=True,
|
710 |
+
)
|
711 |
+
with gr.Row(equal_height=False):
|
712 |
+
with gr.Column():
|
713 |
+
output_radio = gr.Radio(
|
714 |
+
label="\U0001F4C1 "
|
715 |
+
+ i18n("Select source file processing method"),
|
716 |
+
choices=[i18n("Copy"), i18n("Move")],
|
717 |
+
value=i18n("Copy"),
|
718 |
+
interactive=True,
|
719 |
+
)
|
720 |
+
with gr.Column():
|
721 |
+
error = gr.HTML(label=i18n("Error Message"))
|
722 |
+
if_label = gr.Checkbox(
|
723 |
+
label=i18n("Open Labeler WebUI"), scale=0, show_label=True
|
724 |
+
)
|
725 |
+
|
726 |
+
with gr.Row():
|
727 |
+
label_device = gr.Dropdown(
|
728 |
+
label=i18n("Labeling Device"),
|
729 |
+
info=i18n(
|
730 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU"
|
731 |
+
),
|
732 |
+
choices=["cpu", "cuda"],
|
733 |
+
value="cuda",
|
734 |
+
interactive=True,
|
735 |
+
)
|
736 |
+
label_model = gr.Dropdown(
|
737 |
+
label=i18n("Whisper Model"),
|
738 |
+
info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
|
739 |
+
choices=["large-v3", "medium"],
|
740 |
+
value="large-v3",
|
741 |
+
interactive=True,
|
742 |
+
)
|
743 |
+
label_radio = gr.Dropdown(
|
744 |
+
label=i18n("Optional Label Language"),
|
745 |
+
info=i18n(
|
746 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
|
747 |
+
),
|
748 |
+
choices=[
|
749 |
+
(i18n("Chinese"), "zh"),
|
750 |
+
(i18n("English"), "en"),
|
751 |
+
(i18n("Japanese"), "ja"),
|
752 |
+
(i18n("Disabled"), "IGNORE"),
|
753 |
+
(i18n("auto"), "auto"),
|
754 |
+
],
|
755 |
+
value="IGNORE",
|
756 |
+
interactive=True,
|
757 |
+
)
|
758 |
+
|
759 |
+
with gr.Row():
|
760 |
+
if_initial_prompt = gr.Checkbox(
|
761 |
+
value=False,
|
762 |
+
label=i18n("Enable Initial Prompt"),
|
763 |
+
min_width=120,
|
764 |
+
scale=0,
|
765 |
+
)
|
766 |
+
initial_prompt = gr.Textbox(
|
767 |
+
label=i18n("Initial Prompt"),
|
768 |
+
info=i18n(
|
769 |
+
"Initial prompt can provide contextual or vocabulary-specific guidance to the model."
|
770 |
+
),
|
771 |
+
placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
|
772 |
+
interactive=False,
|
773 |
+
)
|
774 |
+
|
775 |
+
with gr.Row():
|
776 |
+
add_button = gr.Button(
|
777 |
+
"\U000027A1 " + i18n("Add to Processing Area"),
|
778 |
+
variant="primary",
|
779 |
+
)
|
780 |
+
remove_button = gr.Button(
|
781 |
+
"\U000026D4 " + i18n("Remove Selected Data")
|
782 |
+
)
|
783 |
+
|
784 |
+
with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
|
785 |
+
with gr.Row():
|
786 |
+
model_type_radio = gr.Radio(
|
787 |
+
label=i18n(
|
788 |
+
"Select the model to be trained (Depending on the Tab page you are on)"
|
789 |
+
),
|
790 |
+
interactive=False,
|
791 |
+
choices=["VQGAN", "LLAMA"],
|
792 |
+
value="VQGAN",
|
793 |
+
)
|
794 |
+
with gr.Row():
|
795 |
+
with gr.Tabs():
|
796 |
+
with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
|
797 |
+
gr.HTML("You don't need to train this model!")
|
798 |
+
|
799 |
+
with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
|
800 |
+
with gr.Row(equal_height=False):
|
801 |
+
llama_use_lora = gr.Checkbox(
|
802 |
+
label=i18n("Use LoRA"),
|
803 |
+
info=i18n(
|
804 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model"
|
805 |
+
),
|
806 |
+
value=True,
|
807 |
+
interactive=True,
|
808 |
+
)
|
809 |
+
llama_ckpt = gr.Dropdown(
|
810 |
+
label=i18n("Select LLAMA ckpt"),
|
811 |
+
choices=[i18n("latest"), i18n("new")]
|
812 |
+
+ [
|
813 |
+
str(p)
|
814 |
+
for p in Path("results").glob("text2sem*/")
|
815 |
+
]
|
816 |
+
+ [str(p) for p in Path("results").glob("lora*/")],
|
817 |
+
value=i18n("latest"),
|
818 |
+
interactive=True,
|
819 |
+
)
|
820 |
+
with gr.Row(equal_height=False):
|
821 |
+
llama_lr_slider = gr.Slider(
|
822 |
+
label=i18n("Initial Learning Rate"),
|
823 |
+
info=i18n(
|
824 |
+
"lr smaller -> usually train slower but more stable"
|
825 |
+
),
|
826 |
+
interactive=True,
|
827 |
+
minimum=1e-5,
|
828 |
+
maximum=1e-4,
|
829 |
+
step=1e-5,
|
830 |
+
value=5e-5,
|
831 |
+
)
|
832 |
+
llama_maxsteps_slider = gr.Slider(
|
833 |
+
label=i18n("Maximum Training Steps"),
|
834 |
+
info=i18n(
|
835 |
+
"recommend: max_steps = num_audios // batch_size * (2 to 5)"
|
836 |
+
),
|
837 |
+
interactive=True,
|
838 |
+
minimum=1,
|
839 |
+
maximum=10000,
|
840 |
+
step=1,
|
841 |
+
value=50,
|
842 |
+
)
|
843 |
+
with gr.Row(equal_height=False):
|
844 |
+
llama_base_config = gr.Dropdown(
|
845 |
+
label=i18n("Model Size"),
|
846 |
+
choices=[
|
847 |
+
"text2semantic_finetune",
|
848 |
+
],
|
849 |
+
value="text2semantic_finetune",
|
850 |
+
)
|
851 |
+
llama_data_num_workers_slider = gr.Slider(
|
852 |
+
label=i18n("Number of Workers"),
|
853 |
+
minimum=1,
|
854 |
+
maximum=32,
|
855 |
+
step=1,
|
856 |
+
value=4,
|
857 |
+
)
|
858 |
+
with gr.Row(equal_height=False):
|
859 |
+
llama_data_batch_size_slider = gr.Slider(
|
860 |
+
label=i18n("Batch Size"),
|
861 |
+
interactive=True,
|
862 |
+
minimum=1,
|
863 |
+
maximum=32,
|
864 |
+
step=1,
|
865 |
+
value=4,
|
866 |
+
)
|
867 |
+
llama_data_max_length_slider = gr.Slider(
|
868 |
+
label=i18n("Maximum Length per Sample"),
|
869 |
+
interactive=True,
|
870 |
+
minimum=1024,
|
871 |
+
maximum=4096,
|
872 |
+
step=128,
|
873 |
+
value=1024,
|
874 |
+
)
|
875 |
+
with gr.Row(equal_height=False):
|
876 |
+
llama_precision_dropdown = gr.Dropdown(
|
877 |
+
label=i18n("Precision"),
|
878 |
+
info=i18n(
|
879 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
|
880 |
+
),
|
881 |
+
interactive=True,
|
882 |
+
choices=["32", "bf16-true", "16-mixed"],
|
883 |
+
value="bf16-true",
|
884 |
+
)
|
885 |
+
llama_check_interval_slider = gr.Slider(
|
886 |
+
label=i18n("Save model every n steps"),
|
887 |
+
info=i18n(
|
888 |
+
"make sure that it's not greater than max_steps"
|
889 |
+
),
|
890 |
+
interactive=True,
|
891 |
+
minimum=1,
|
892 |
+
maximum=1000,
|
893 |
+
step=1,
|
894 |
+
value=50,
|
895 |
+
)
|
896 |
+
with gr.Row(equal_height=False):
|
897 |
+
llama_grad_batches = gr.Slider(
|
898 |
+
label=i18n("Accumulate Gradient Batches"),
|
899 |
+
interactive=True,
|
900 |
+
minimum=1,
|
901 |
+
maximum=20,
|
902 |
+
step=1,
|
903 |
+
value=init_llama_yml["trainer"][
|
904 |
+
"accumulate_grad_batches"
|
905 |
+
],
|
906 |
+
)
|
907 |
+
llama_use_speaker = gr.Slider(
|
908 |
+
label=i18n(
|
909 |
+
"Probability of applying Speaker Condition"
|
910 |
+
),
|
911 |
+
interactive=True,
|
912 |
+
minimum=0.1,
|
913 |
+
maximum=1.0,
|
914 |
+
step=0.05,
|
915 |
+
value=init_llama_yml["train_dataset"][
|
916 |
+
"interactive_prob"
|
917 |
+
],
|
918 |
+
)
|
919 |
+
|
920 |
+
with gr.Tab(label=i18n("Merge LoRA"), id=4):
|
921 |
+
with gr.Row(equal_height=False):
|
922 |
+
llama_weight = gr.Dropdown(
|
923 |
+
label=i18n("Base LLAMA Model"),
|
924 |
+
info=i18n(
|
925 |
+
"Type the path or select from the dropdown"
|
926 |
+
),
|
927 |
+
choices=[
|
928 |
+
"checkpoints/fish-speech-1.4/model.pth",
|
929 |
+
],
|
930 |
+
value="checkpoints/fish-speech-1.4/model.pth",
|
931 |
+
allow_custom_value=True,
|
932 |
+
interactive=True,
|
933 |
+
)
|
934 |
+
with gr.Row(equal_height=False):
|
935 |
+
lora_weight = gr.Dropdown(
|
936 |
+
label=i18n("LoRA Model to be merged"),
|
937 |
+
info=i18n(
|
938 |
+
"Type the path or select from the dropdown"
|
939 |
+
),
|
940 |
+
choices=[
|
941 |
+
str(p)
|
942 |
+
for p in Path("results").glob("lora*/**/*.ckpt")
|
943 |
+
],
|
944 |
+
allow_custom_value=True,
|
945 |
+
interactive=True,
|
946 |
+
)
|
947 |
+
lora_llama_config = gr.Dropdown(
|
948 |
+
label=i18n("LLAMA Model Config"),
|
949 |
+
info=i18n(
|
950 |
+
"Type the path or select from the dropdown"
|
951 |
+
),
|
952 |
+
choices=[
|
953 |
+
"text2semantic_finetune",
|
954 |
+
],
|
955 |
+
value="text2semantic_finetune",
|
956 |
+
allow_custom_value=True,
|
957 |
+
)
|
958 |
+
with gr.Row(equal_height=False):
|
959 |
+
llama_lora_output = gr.Dropdown(
|
960 |
+
label=i18n("Output Path"),
|
961 |
+
info=i18n(
|
962 |
+
"Type the path or select from the dropdown"
|
963 |
+
),
|
964 |
+
value="checkpoints/merged",
|
965 |
+
choices=["checkpoints/merged"],
|
966 |
+
allow_custom_value=True,
|
967 |
+
interactive=True,
|
968 |
+
)
|
969 |
+
with gr.Row(equal_height=False):
|
970 |
+
llama_lora_merge_btn = gr.Button(
|
971 |
+
value=i18n("Merge"), variant="primary"
|
972 |
+
)
|
973 |
+
|
974 |
+
with gr.Tab(label=i18n("Model Quantization"), id=5):
|
975 |
+
with gr.Row(equal_height=False):
|
976 |
+
llama_weight_to_quantify = gr.Dropdown(
|
977 |
+
label=i18n("Base LLAMA Model"),
|
978 |
+
info=i18n(
|
979 |
+
"Type the path or select from the dropdown"
|
980 |
+
),
|
981 |
+
choices=list_llama_models(),
|
982 |
+
value="checkpoints/fish-speech-1.4",
|
983 |
+
allow_custom_value=True,
|
984 |
+
interactive=True,
|
985 |
+
)
|
986 |
+
quantify_mode = gr.Dropdown(
|
987 |
+
label=i18n("Post-quantification Precision"),
|
988 |
+
info=i18n(
|
989 |
+
"The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
|
990 |
+
),
|
991 |
+
choices=["int8", "int4"],
|
992 |
+
value="int8",
|
993 |
+
allow_custom_value=False,
|
994 |
+
interactive=True,
|
995 |
+
)
|
996 |
+
with gr.Row(equal_height=False):
|
997 |
+
llama_quantify_btn = gr.Button(
|
998 |
+
value=i18n("Quantify"), variant="primary"
|
999 |
+
)
|
1000 |
+
|
1001 |
+
with gr.Tab(label="Tensorboard", id=6):
|
1002 |
+
with gr.Row(equal_height=False):
|
1003 |
+
tb_host = gr.Textbox(
|
1004 |
+
label=i18n("Tensorboard Host"), value="127.0.0.1"
|
1005 |
+
)
|
1006 |
+
tb_port = gr.Textbox(
|
1007 |
+
label=i18n("Tensorboard Port"), value="11451"
|
1008 |
+
)
|
1009 |
+
with gr.Row(equal_height=False):
|
1010 |
+
tb_dir = gr.Dropdown(
|
1011 |
+
label=i18n("Tensorboard Log Path"),
|
1012 |
+
allow_custom_value=True,
|
1013 |
+
choices=[
|
1014 |
+
str(p)
|
1015 |
+
for p in Path("results").glob("**/tensorboard/")
|
1016 |
+
],
|
1017 |
+
)
|
1018 |
+
with gr.Row(equal_height=False):
|
1019 |
+
if_tb = gr.Checkbox(
|
1020 |
+
label=i18n("Open Tensorboard"),
|
1021 |
+
)
|
1022 |
+
|
1023 |
+
with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
|
1024 |
+
with gr.Column():
|
1025 |
+
with gr.Row():
|
1026 |
+
with gr.Accordion(
|
1027 |
+
label="\U0001F5A5 "
|
1028 |
+
+ i18n("Inference Server Configuration"),
|
1029 |
+
open=False,
|
1030 |
+
):
|
1031 |
+
with gr.Row():
|
1032 |
+
infer_host_textbox = gr.Textbox(
|
1033 |
+
label=i18n("WebUI Host"), value="127.0.0.1"
|
1034 |
+
)
|
1035 |
+
infer_port_textbox = gr.Textbox(
|
1036 |
+
label=i18n("WebUI Port"), value="7862"
|
1037 |
+
)
|
1038 |
+
with gr.Row():
|
1039 |
+
infer_decoder_model = gr.Dropdown(
|
1040 |
+
label=i18n("Decoder Model Path"),
|
1041 |
+
info=i18n(
|
1042 |
+
"Type the path or select from the dropdown"
|
1043 |
+
),
|
1044 |
+
choices=list_decoder_models(),
|
1045 |
+
value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
1046 |
+
allow_custom_value=True,
|
1047 |
+
)
|
1048 |
+
infer_decoder_config = gr.Dropdown(
|
1049 |
+
label=i18n("Decoder Model Config"),
|
1050 |
+
info=i18n("Changing with the Model Path"),
|
1051 |
+
value="firefly_gan_vq",
|
1052 |
+
choices=[
|
1053 |
+
"firefly_gan_vq",
|
1054 |
+
],
|
1055 |
+
allow_custom_value=True,
|
1056 |
+
)
|
1057 |
+
with gr.Row():
|
1058 |
+
infer_llama_model = gr.Dropdown(
|
1059 |
+
label=i18n("LLAMA Model Path"),
|
1060 |
+
info=i18n(
|
1061 |
+
"Type the path or select from the dropdown"
|
1062 |
+
),
|
1063 |
+
value="checkpoints/fish-speech-1.4",
|
1064 |
+
choices=list_llama_models(),
|
1065 |
+
allow_custom_value=True,
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
with gr.Row():
|
1069 |
+
infer_compile = gr.Radio(
|
1070 |
+
label=i18n("Compile Model"),
|
1071 |
+
info=i18n(
|
1072 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time"
|
1073 |
+
),
|
1074 |
+
choices=["Yes", "No"],
|
1075 |
+
value=(
|
1076 |
+
"Yes" if (sys.platform == "linux") else "No"
|
1077 |
+
),
|
1078 |
+
interactive=is_module_installed("triton"),
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
with gr.Row():
|
1082 |
+
infer_checkbox = gr.Checkbox(
|
1083 |
+
label=i18n("Open Inference Server")
|
1084 |
+
)
|
1085 |
+
infer_error = gr.HTML(label=i18n("Inference Server Error"))
|
1086 |
+
|
1087 |
+
with gr.Column():
|
1088 |
+
train_error = gr.HTML(label=i18n("Training Error"))
|
1089 |
+
checkbox_group = gr.CheckboxGroup(
|
1090 |
+
label="\U0001F4CA " + i18n("Data Source"),
|
1091 |
+
info=i18n(
|
1092 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
|
1093 |
+
),
|
1094 |
+
elem_classes=["data_src"],
|
1095 |
+
)
|
1096 |
+
train_box = gr.Textbox(
|
1097 |
+
label=i18n("Data Preprocessing Path"),
|
1098 |
+
value=str(data_pre_output),
|
1099 |
+
interactive=False,
|
1100 |
+
)
|
1101 |
+
model_box = gr.Textbox(
|
1102 |
+
label="\U0001F4BE " + i18n("Model Output Path"),
|
1103 |
+
value=str(default_model_output),
|
1104 |
+
interactive=False,
|
1105 |
+
)
|
1106 |
+
|
1107 |
+
with gr.Accordion(
|
1108 |
+
i18n(
|
1109 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)"
|
1110 |
+
),
|
1111 |
+
elem_classes=["scrollable-component"],
|
1112 |
+
elem_id="file_accordion",
|
1113 |
+
):
|
1114 |
+
tree_slider = gr.Slider(
|
1115 |
+
minimum=0,
|
1116 |
+
maximum=3,
|
1117 |
+
value=0,
|
1118 |
+
step=1,
|
1119 |
+
show_label=False,
|
1120 |
+
container=False,
|
1121 |
+
)
|
1122 |
+
file_markdown = new_explorer(str(data_pre_output), 0)
|
1123 |
+
with gr.Row(equal_height=False):
|
1124 |
+
admit_btn = gr.Button(
|
1125 |
+
"\U00002705 " + i18n("File Preprocessing"),
|
1126 |
+
variant="primary",
|
1127 |
+
)
|
1128 |
+
fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
|
1129 |
+
help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
|
1130 |
+
train_btn = gr.Button(i18n("Start Training"), variant="primary")
|
1131 |
+
|
1132 |
+
footer = load_data_in_raw("fish_speech/webui/html/footer.html")
|
1133 |
+
footer = footer.format(
|
1134 |
+
versions=versions_html(),
|
1135 |
+
api_docs="https://speech.fish.audio/inference/#http-api",
|
1136 |
+
)
|
1137 |
+
gr.HTML(footer, elem_id="footer")
|
1138 |
+
vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
|
1139 |
+
llama_page.select(lambda: "LLAMA", None, model_type_radio)
|
1140 |
+
add_button.click(
|
1141 |
+
fn=add_item,
|
1142 |
+
inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
|
1143 |
+
outputs=[checkbox_group, error],
|
1144 |
+
)
|
1145 |
+
remove_button.click(
|
1146 |
+
fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
|
1147 |
+
)
|
1148 |
+
checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
|
1149 |
+
help_button.click(
|
1150 |
+
fn=None,
|
1151 |
+
js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
|
1152 |
+
'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
|
1153 |
+
)
|
1154 |
+
if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
|
1155 |
+
if_initial_prompt.change(
|
1156 |
+
fn=lambda x: gr.Textbox(value="", interactive=x),
|
1157 |
+
inputs=[if_initial_prompt],
|
1158 |
+
outputs=[initial_prompt],
|
1159 |
+
)
|
1160 |
+
train_btn.click(
|
1161 |
+
fn=train_process,
|
1162 |
+
inputs=[
|
1163 |
+
train_box,
|
1164 |
+
model_type_radio,
|
1165 |
+
# llama config
|
1166 |
+
llama_ckpt,
|
1167 |
+
llama_base_config,
|
1168 |
+
llama_lr_slider,
|
1169 |
+
llama_maxsteps_slider,
|
1170 |
+
llama_data_num_workers_slider,
|
1171 |
+
llama_data_batch_size_slider,
|
1172 |
+
llama_data_max_length_slider,
|
1173 |
+
llama_precision_dropdown,
|
1174 |
+
llama_check_interval_slider,
|
1175 |
+
llama_grad_batches,
|
1176 |
+
llama_use_speaker,
|
1177 |
+
llama_use_lora,
|
1178 |
+
],
|
1179 |
+
outputs=[train_error],
|
1180 |
+
)
|
1181 |
+
if_tb.change(
|
1182 |
+
fn=tensorboard_process,
|
1183 |
+
inputs=[if_tb, tb_dir, tb_host, tb_port],
|
1184 |
+
outputs=[train_error],
|
1185 |
+
)
|
1186 |
+
tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
|
1187 |
+
infer_decoder_model.change(
|
1188 |
+
fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
|
1189 |
+
)
|
1190 |
+
infer_llama_model.change(
|
1191 |
+
fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
|
1192 |
+
)
|
1193 |
+
llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
|
1194 |
+
admit_btn.click(
|
1195 |
+
fn=check_files,
|
1196 |
+
inputs=[train_box, tree_slider, label_model, label_device],
|
1197 |
+
outputs=[error, file_markdown],
|
1198 |
+
)
|
1199 |
+
fresh_btn.click(
|
1200 |
+
fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
|
1201 |
+
)
|
1202 |
+
llama_use_lora.change(
|
1203 |
+
fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
|
1204 |
+
)
|
1205 |
+
llama_ckpt.change(
|
1206 |
+
fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
|
1207 |
+
)
|
1208 |
+
lora_weight.change(
|
1209 |
+
fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
|
1210 |
+
inputs=[],
|
1211 |
+
outputs=[lora_weight],
|
1212 |
+
)
|
1213 |
+
llama_lora_merge_btn.click(
|
1214 |
+
fn=llama_lora_merge,
|
1215 |
+
inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
|
1216 |
+
outputs=[train_error],
|
1217 |
+
)
|
1218 |
+
llama_quantify_btn.click(
|
1219 |
+
fn=llama_quantify,
|
1220 |
+
inputs=[llama_weight_to_quantify, quantify_mode],
|
1221 |
+
outputs=[train_error],
|
1222 |
+
)
|
1223 |
+
infer_checkbox.change(
|
1224 |
+
fn=change_infer,
|
1225 |
+
inputs=[
|
1226 |
+
infer_checkbox,
|
1227 |
+
infer_host_textbox,
|
1228 |
+
infer_port_textbox,
|
1229 |
+
infer_decoder_model,
|
1230 |
+
infer_decoder_config,
|
1231 |
+
infer_llama_model,
|
1232 |
+
infer_compile,
|
1233 |
+
],
|
1234 |
+
outputs=[infer_error],
|
1235 |
+
)
|
1236 |
+
|
1237 |
+
demo.launch(inbrowser=True)
|
requirements.txt
CHANGED
@@ -24,4 +24,5 @@ resampy>=0.4.3
|
|
24 |
spaces>=0.26.1
|
25 |
einx[torch]==0.2.0
|
26 |
opencc
|
27 |
-
faster-whisper
|
|
|
|
24 |
spaces>=0.26.1
|
25 |
einx[torch]==0.2.0
|
26 |
opencc
|
27 |
+
faster-whisper
|
28 |
+
ormsgpack
|
tools/api.py
CHANGED
@@ -3,21 +3,26 @@ import io
|
|
3 |
import json
|
4 |
import queue
|
5 |
import random
|
|
|
6 |
import traceback
|
7 |
import wave
|
8 |
from argparse import ArgumentParser
|
9 |
from http import HTTPStatus
|
10 |
from pathlib import Path
|
11 |
-
from typing import Annotated, Literal, Optional
|
12 |
|
13 |
-
import librosa
|
14 |
import numpy as np
|
|
|
15 |
import pyrootutils
|
16 |
import soundfile as sf
|
17 |
import torch
|
|
|
|
|
18 |
from kui.asgi import (
|
19 |
Body,
|
|
|
20 |
HTTPException,
|
|
|
21 |
HttpView,
|
22 |
JSONResponse,
|
23 |
Kui,
|
@@ -26,13 +31,16 @@ from kui.asgi import (
|
|
26 |
)
|
27 |
from kui.asgi.routing import MultimethodRoutes
|
28 |
from loguru import logger
|
29 |
-
from pydantic import BaseModel, Field
|
30 |
|
31 |
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
32 |
|
33 |
# from fish_speech.models.vqgan.lit_module import VQGAN
|
34 |
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
35 |
-
from
|
|
|
|
|
|
|
36 |
from tools.llama.generate import (
|
37 |
GenerateRequest,
|
38 |
GenerateResponse,
|
@@ -80,13 +88,21 @@ async def other_exception_handler(exc: "Exception"):
|
|
80 |
|
81 |
def load_audio(reference_audio, sr):
|
82 |
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
83 |
-
|
84 |
-
|
85 |
-
reference_audio = io.BytesIO(audio_data)
|
86 |
-
except base64.binascii.Error:
|
87 |
-
raise ValueError("Invalid path or base64 string")
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
return audio
|
91 |
|
92 |
|
@@ -132,7 +148,7 @@ def decode_vq_tokens(
|
|
132 |
return decoder_model.decode(
|
133 |
indices=codes[None],
|
134 |
feature_lengths=feature_lengths,
|
135 |
-
).squeeze()
|
136 |
|
137 |
raise ValueError(f"Unknown model type: {type(decoder_model)}")
|
138 |
|
@@ -140,58 +156,6 @@ def decode_vq_tokens(
|
|
140 |
routes = MultimethodRoutes(base_class=HttpView)
|
141 |
|
142 |
|
143 |
-
def get_random_paths(base_path, data, speaker, emotion):
|
144 |
-
if base_path and data and speaker and emotion and (Path(base_path).exists()):
|
145 |
-
if speaker in data and emotion in data[speaker]:
|
146 |
-
files = data[speaker][emotion]
|
147 |
-
lab_files = [f for f in files if f.endswith(".lab")]
|
148 |
-
wav_files = [f for f in files if f.endswith(".wav")]
|
149 |
-
|
150 |
-
if lab_files and wav_files:
|
151 |
-
selected_lab = random.choice(lab_files)
|
152 |
-
selected_wav = random.choice(wav_files)
|
153 |
-
|
154 |
-
lab_path = Path(base_path) / speaker / emotion / selected_lab
|
155 |
-
wav_path = Path(base_path) / speaker / emotion / selected_wav
|
156 |
-
if lab_path.exists() and wav_path.exists():
|
157 |
-
return lab_path, wav_path
|
158 |
-
|
159 |
-
return None, None
|
160 |
-
|
161 |
-
|
162 |
-
def load_json(json_file):
|
163 |
-
if not json_file:
|
164 |
-
logger.info("Not using a json file")
|
165 |
-
return None
|
166 |
-
try:
|
167 |
-
with open(json_file, "r", encoding="utf-8") as file:
|
168 |
-
data = json.load(file)
|
169 |
-
except FileNotFoundError:
|
170 |
-
logger.warning(f"ref json not found: {json_file}")
|
171 |
-
data = None
|
172 |
-
except Exception as e:
|
173 |
-
logger.warning(f"Loading json failed: {e}")
|
174 |
-
data = None
|
175 |
-
return data
|
176 |
-
|
177 |
-
|
178 |
-
class InvokeRequest(BaseModel):
|
179 |
-
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
|
180 |
-
reference_text: Optional[str] = None
|
181 |
-
reference_audio: Optional[str] = None
|
182 |
-
max_new_tokens: int = 1024
|
183 |
-
chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
|
184 |
-
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
185 |
-
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
186 |
-
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
187 |
-
emotion: Optional[str] = None
|
188 |
-
format: Literal["wav", "mp3", "flac"] = "wav"
|
189 |
-
streaming: bool = False
|
190 |
-
ref_json: Optional[str] = "ref_data.json"
|
191 |
-
ref_base: Optional[str] = "ref_data"
|
192 |
-
speaker: Optional[str] = None
|
193 |
-
|
194 |
-
|
195 |
def get_content_type(audio_format):
|
196 |
if audio_format == "wav":
|
197 |
return "audio/wav"
|
@@ -204,35 +168,52 @@ def get_content_type(audio_format):
|
|
204 |
|
205 |
|
206 |
@torch.inference_mode()
|
207 |
-
def inference(req:
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
# LLAMA Inference
|
232 |
request = dict(
|
233 |
device=decoder_model.device,
|
234 |
max_new_tokens=req.max_new_tokens,
|
235 |
-
text=
|
|
|
|
|
|
|
|
|
236 |
top_p=req.top_p,
|
237 |
repetition_penalty=req.repetition_penalty,
|
238 |
temperature=req.temperature,
|
@@ -241,7 +222,7 @@ def inference(req: InvokeRequest):
|
|
241 |
chunk_length=req.chunk_length,
|
242 |
max_length=2048,
|
243 |
prompt_tokens=prompt_tokens,
|
244 |
-
prompt_text=
|
245 |
)
|
246 |
|
247 |
response_queue = queue.Queue()
|
@@ -266,7 +247,7 @@ def inference(req: InvokeRequest):
|
|
266 |
if result.action == "next":
|
267 |
break
|
268 |
|
269 |
-
with
|
270 |
device_type=decoder_model.device.type, dtype=args.precision
|
271 |
):
|
272 |
fake_audios = decode_vq_tokens(
|
@@ -294,40 +275,7 @@ def inference(req: InvokeRequest):
|
|
294 |
yield fake_audios
|
295 |
|
296 |
|
297 |
-
def
|
298 |
-
if not use_auto_rerank:
|
299 |
-
# 如果不使用 auto_rerank,直接调用原始的 inference 函数
|
300 |
-
return inference(req)
|
301 |
-
|
302 |
-
zh_model, en_model = load_model()
|
303 |
-
max_attempts = 5
|
304 |
-
best_wer = float("inf")
|
305 |
-
best_audio = None
|
306 |
-
|
307 |
-
for attempt in range(max_attempts):
|
308 |
-
# 调用原始的 inference 函数
|
309 |
-
audio_generator = inference(req)
|
310 |
-
fake_audios = next(audio_generator)
|
311 |
-
|
312 |
-
asr_result = batch_asr(
|
313 |
-
zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
|
314 |
-
)[0]
|
315 |
-
wer = calculate_wer(req.text, asr_result["text"])
|
316 |
-
|
317 |
-
if wer <= 0.1 and not asr_result["huge_gap"]:
|
318 |
-
return fake_audios
|
319 |
-
|
320 |
-
if wer < best_wer:
|
321 |
-
best_wer = wer
|
322 |
-
best_audio = fake_audios
|
323 |
-
|
324 |
-
if attempt == max_attempts - 1:
|
325 |
-
break
|
326 |
-
|
327 |
-
return best_audio
|
328 |
-
|
329 |
-
|
330 |
-
async def inference_async(req: InvokeRequest):
|
331 |
for chunk in inference(req):
|
332 |
yield chunk
|
333 |
|
@@ -336,9 +284,9 @@ async def buffer_to_async_generator(buffer):
|
|
336 |
yield buffer
|
337 |
|
338 |
|
339 |
-
@routes.http.post("/v1/
|
340 |
async def api_invoke_model(
|
341 |
-
req: Annotated[
|
342 |
):
|
343 |
"""
|
344 |
Invoke model and generate audio
|
@@ -397,19 +345,19 @@ def parse_args():
|
|
397 |
parser.add_argument(
|
398 |
"--llama-checkpoint-path",
|
399 |
type=str,
|
400 |
-
default="checkpoints/fish-speech-1.
|
401 |
)
|
402 |
parser.add_argument(
|
403 |
"--decoder-checkpoint-path",
|
404 |
type=str,
|
405 |
-
default="checkpoints/fish-speech-1.
|
406 |
)
|
407 |
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
408 |
parser.add_argument("--device", type=str, default="cuda")
|
409 |
parser.add_argument("--half", action="store_true")
|
410 |
parser.add_argument("--compile", action="store_true")
|
411 |
parser.add_argument("--max-text-length", type=int, default=0)
|
412 |
-
parser.add_argument("--listen", type=str, default="127.0.0.1:
|
413 |
parser.add_argument("--workers", type=int, default=1)
|
414 |
parser.add_argument("--use-auto-rerank", type=bool, default=True)
|
415 |
|
@@ -423,18 +371,30 @@ openapi = OpenAPI(
|
|
423 |
},
|
424 |
).routes
|
425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
app = Kui(
|
427 |
routes=routes + openapi[1:], # Remove the default route
|
428 |
exception_handlers={
|
429 |
HTTPException: http_execption_handler,
|
430 |
Exception: other_exception_handler,
|
431 |
},
|
|
|
432 |
cors_config={},
|
433 |
)
|
434 |
|
435 |
|
436 |
if __name__ == "__main__":
|
437 |
-
import threading
|
438 |
|
439 |
import uvicorn
|
440 |
|
@@ -461,18 +421,16 @@ if __name__ == "__main__":
|
|
461 |
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
462 |
list(
|
463 |
inference(
|
464 |
-
|
465 |
text="Hello world.",
|
466 |
-
|
467 |
-
|
468 |
max_new_tokens=0,
|
469 |
top_p=0.7,
|
470 |
repetition_penalty=1.2,
|
471 |
temperature=0.7,
|
472 |
emotion=None,
|
473 |
format="wav",
|
474 |
-
ref_base=None,
|
475 |
-
ref_json=None,
|
476 |
)
|
477 |
)
|
478 |
)
|
|
|
3 |
import json
|
4 |
import queue
|
5 |
import random
|
6 |
+
import sys
|
7 |
import traceback
|
8 |
import wave
|
9 |
from argparse import ArgumentParser
|
10 |
from http import HTTPStatus
|
11 |
from pathlib import Path
|
12 |
+
from typing import Annotated, Any, Literal, Optional
|
13 |
|
|
|
14 |
import numpy as np
|
15 |
+
import ormsgpack
|
16 |
import pyrootutils
|
17 |
import soundfile as sf
|
18 |
import torch
|
19 |
+
import torchaudio
|
20 |
+
from baize.datastructures import ContentType
|
21 |
from kui.asgi import (
|
22 |
Body,
|
23 |
+
FactoryClass,
|
24 |
HTTPException,
|
25 |
+
HttpRequest,
|
26 |
HttpView,
|
27 |
JSONResponse,
|
28 |
Kui,
|
|
|
31 |
)
|
32 |
from kui.asgi.routing import MultimethodRoutes
|
33 |
from loguru import logger
|
34 |
+
from pydantic import BaseModel, Field, conint
|
35 |
|
36 |
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
37 |
|
38 |
# from fish_speech.models.vqgan.lit_module import VQGAN
|
39 |
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
40 |
+
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
41 |
+
from fish_speech.utils import autocast_exclude_mps
|
42 |
+
from tools.commons import ServeReferenceAudio, ServeTTSRequest
|
43 |
+
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
44 |
from tools.llama.generate import (
|
45 |
GenerateRequest,
|
46 |
GenerateResponse,
|
|
|
88 |
|
89 |
def load_audio(reference_audio, sr):
|
90 |
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
91 |
+
audio_data = reference_audio
|
92 |
+
reference_audio = io.BytesIO(audio_data)
|
|
|
|
|
|
|
93 |
|
94 |
+
waveform, original_sr = torchaudio.load(
|
95 |
+
reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
|
96 |
+
)
|
97 |
+
|
98 |
+
if waveform.shape[0] > 1:
|
99 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
100 |
+
|
101 |
+
if original_sr != sr:
|
102 |
+
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
|
103 |
+
waveform = resampler(waveform)
|
104 |
+
|
105 |
+
audio = waveform.squeeze().numpy()
|
106 |
return audio
|
107 |
|
108 |
|
|
|
148 |
return decoder_model.decode(
|
149 |
indices=codes[None],
|
150 |
feature_lengths=feature_lengths,
|
151 |
+
)[0].squeeze()
|
152 |
|
153 |
raise ValueError(f"Unknown model type: {type(decoder_model)}")
|
154 |
|
|
|
156 |
routes = MultimethodRoutes(base_class=HttpView)
|
157 |
|
158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
def get_content_type(audio_format):
|
160 |
if audio_format == "wav":
|
161 |
return "audio/wav"
|
|
|
168 |
|
169 |
|
170 |
@torch.inference_mode()
|
171 |
+
def inference(req: ServeTTSRequest):
|
172 |
+
|
173 |
+
idstr: str | None = req.reference_id
|
174 |
+
if idstr is not None:
|
175 |
+
ref_folder = Path("references") / idstr
|
176 |
+
ref_folder.mkdir(parents=True, exist_ok=True)
|
177 |
+
ref_audios = list_files(
|
178 |
+
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
179 |
+
)
|
180 |
+
prompt_tokens = [
|
181 |
+
encode_reference(
|
182 |
+
decoder_model=decoder_model,
|
183 |
+
reference_audio=audio_to_bytes(str(ref_audio)),
|
184 |
+
enable_reference_audio=True,
|
185 |
+
)
|
186 |
+
for ref_audio in ref_audios
|
187 |
+
]
|
188 |
+
prompt_texts = [
|
189 |
+
read_ref_text(str(ref_audio.with_suffix(".lab")))
|
190 |
+
for ref_audio in ref_audios
|
191 |
+
]
|
192 |
+
|
193 |
+
else:
|
194 |
+
# Parse reference audio aka prompt
|
195 |
+
refs = req.references
|
196 |
+
if refs is None:
|
197 |
+
refs = []
|
198 |
+
prompt_tokens = [
|
199 |
+
encode_reference(
|
200 |
+
decoder_model=decoder_model,
|
201 |
+
reference_audio=ref.audio,
|
202 |
+
enable_reference_audio=True,
|
203 |
+
)
|
204 |
+
for ref in refs
|
205 |
+
]
|
206 |
+
prompt_texts = [ref.text for ref in refs]
|
207 |
+
|
208 |
# LLAMA Inference
|
209 |
request = dict(
|
210 |
device=decoder_model.device,
|
211 |
max_new_tokens=req.max_new_tokens,
|
212 |
+
text=(
|
213 |
+
req.text
|
214 |
+
if not req.normalize
|
215 |
+
else ChnNormedText(raw_text=req.text).normalize()
|
216 |
+
),
|
217 |
top_p=req.top_p,
|
218 |
repetition_penalty=req.repetition_penalty,
|
219 |
temperature=req.temperature,
|
|
|
222 |
chunk_length=req.chunk_length,
|
223 |
max_length=2048,
|
224 |
prompt_tokens=prompt_tokens,
|
225 |
+
prompt_text=prompt_texts,
|
226 |
)
|
227 |
|
228 |
response_queue = queue.Queue()
|
|
|
247 |
if result.action == "next":
|
248 |
break
|
249 |
|
250 |
+
with autocast_exclude_mps(
|
251 |
device_type=decoder_model.device.type, dtype=args.precision
|
252 |
):
|
253 |
fake_audios = decode_vq_tokens(
|
|
|
275 |
yield fake_audios
|
276 |
|
277 |
|
278 |
+
async def inference_async(req: ServeTTSRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
for chunk in inference(req):
|
280 |
yield chunk
|
281 |
|
|
|
284 |
yield buffer
|
285 |
|
286 |
|
287 |
+
@routes.http.post("/v1/tts")
|
288 |
async def api_invoke_model(
|
289 |
+
req: Annotated[ServeTTSRequest, Body(exclusive=True)],
|
290 |
):
|
291 |
"""
|
292 |
Invoke model and generate audio
|
|
|
345 |
parser.add_argument(
|
346 |
"--llama-checkpoint-path",
|
347 |
type=str,
|
348 |
+
default="checkpoints/fish-speech-1.4",
|
349 |
)
|
350 |
parser.add_argument(
|
351 |
"--decoder-checkpoint-path",
|
352 |
type=str,
|
353 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
354 |
)
|
355 |
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
356 |
parser.add_argument("--device", type=str, default="cuda")
|
357 |
parser.add_argument("--half", action="store_true")
|
358 |
parser.add_argument("--compile", action="store_true")
|
359 |
parser.add_argument("--max-text-length", type=int, default=0)
|
360 |
+
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
361 |
parser.add_argument("--workers", type=int, default=1)
|
362 |
parser.add_argument("--use-auto-rerank", type=bool, default=True)
|
363 |
|
|
|
371 |
},
|
372 |
).routes
|
373 |
|
374 |
+
|
375 |
+
class MsgPackRequest(HttpRequest):
|
376 |
+
async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
|
377 |
+
if self.content_type == "application/msgpack":
|
378 |
+
return ormsgpack.unpackb(await self.body)
|
379 |
+
|
380 |
+
raise HTTPException(
|
381 |
+
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
382 |
+
headers={"Accept": "application/msgpack"},
|
383 |
+
)
|
384 |
+
|
385 |
+
|
386 |
app = Kui(
|
387 |
routes=routes + openapi[1:], # Remove the default route
|
388 |
exception_handlers={
|
389 |
HTTPException: http_execption_handler,
|
390 |
Exception: other_exception_handler,
|
391 |
},
|
392 |
+
factory_class=FactoryClass(http=MsgPackRequest),
|
393 |
cors_config={},
|
394 |
)
|
395 |
|
396 |
|
397 |
if __name__ == "__main__":
|
|
|
398 |
|
399 |
import uvicorn
|
400 |
|
|
|
421 |
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
422 |
list(
|
423 |
inference(
|
424 |
+
ServeTTSRequest(
|
425 |
text="Hello world.",
|
426 |
+
references=[],
|
427 |
+
reference_id=None,
|
428 |
max_new_tokens=0,
|
429 |
top_p=0.7,
|
430 |
repetition_penalty=1.2,
|
431 |
temperature=0.7,
|
432 |
emotion=None,
|
433 |
format="wav",
|
|
|
|
|
434 |
)
|
435 |
)
|
436 |
)
|
tools/commons.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, Literal, Optional
|
2 |
+
|
3 |
+
from pydantic import BaseModel, Field, conint
|
4 |
+
|
5 |
+
|
6 |
+
class ServeReferenceAudio(BaseModel):
|
7 |
+
audio: bytes
|
8 |
+
text: str
|
9 |
+
|
10 |
+
|
11 |
+
class ServeTTSRequest(BaseModel):
|
12 |
+
text: str
|
13 |
+
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
14 |
+
# Audio format
|
15 |
+
format: Literal["wav", "pcm", "mp3"] = "wav"
|
16 |
+
mp3_bitrate: Literal[64, 128, 192] = 128
|
17 |
+
# References audios for in-context learning
|
18 |
+
references: list[ServeReferenceAudio] = []
|
19 |
+
# Reference id
|
20 |
+
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
21 |
+
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
22 |
+
reference_id: str | None = None
|
23 |
+
# Normalize text for en & zh, this increase stability for numbers
|
24 |
+
normalize: bool = True
|
25 |
+
mp3_bitrate: Optional[int] = 64
|
26 |
+
opus_bitrate: Optional[int] = -1000
|
27 |
+
# Balance mode will reduce latency to 300ms, but may decrease stability
|
28 |
+
latency: Literal["normal", "balanced"] = "normal"
|
29 |
+
# not usually used below
|
30 |
+
streaming: bool = False
|
31 |
+
emotion: Optional[str] = None
|
32 |
+
max_new_tokens: int = 1024
|
33 |
+
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
34 |
+
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
35 |
+
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
tools/download_models.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
|
5 |
+
|
6 |
+
# Download
|
7 |
+
def check_and_download_files(repo_id, file_list, local_dir):
|
8 |
+
os.makedirs(local_dir, exist_ok=True)
|
9 |
+
for file in file_list:
|
10 |
+
file_path = os.path.join(local_dir, file)
|
11 |
+
if not os.path.exists(file_path):
|
12 |
+
print(f"{file} 不存在,从 Hugging Face 仓库下载...")
|
13 |
+
hf_hub_download(
|
14 |
+
repo_id=repo_id,
|
15 |
+
filename=file,
|
16 |
+
resume_download=True,
|
17 |
+
local_dir=local_dir,
|
18 |
+
local_dir_use_symlinks=False,
|
19 |
+
)
|
20 |
+
else:
|
21 |
+
print(f"{file} 已存在,跳过下载。")
|
22 |
+
|
23 |
+
|
24 |
+
# 1st
|
25 |
+
repo_id_1 = "fishaudio/fish-speech-1.4"
|
26 |
+
local_dir_1 = "./checkpoints/fish-speech-1.4"
|
27 |
+
files_1 = [
|
28 |
+
"model.pth",
|
29 |
+
"README.md",
|
30 |
+
"special_tokens_map.json",
|
31 |
+
"tokenizer_config.json",
|
32 |
+
"tokenizer.json",
|
33 |
+
"config.json",
|
34 |
+
"firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
35 |
+
]
|
36 |
+
|
37 |
+
# 3rd
|
38 |
+
repo_id_3 = "fishaudio/fish-speech-1"
|
39 |
+
local_dir_3 = "./"
|
40 |
+
files_3 = [
|
41 |
+
"ffmpeg.exe",
|
42 |
+
"ffprobe.exe",
|
43 |
+
]
|
44 |
+
|
45 |
+
# 4th
|
46 |
+
repo_id_4 = "SpicyqSama007/fish-speech-packed"
|
47 |
+
local_dir_4 = "./"
|
48 |
+
files_4 = [
|
49 |
+
"asr-label-win-x64.exe",
|
50 |
+
]
|
51 |
+
|
52 |
+
check_and_download_files(repo_id_1, files_1, local_dir_1)
|
53 |
+
|
54 |
+
check_and_download_files(repo_id_3, files_3, local_dir_3)
|
55 |
+
check_and_download_files(repo_id_4, files_4, local_dir_4)
|
tools/extract_model.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import click
|
2 |
+
import torch
|
3 |
+
from loguru import logger
|
4 |
+
|
5 |
+
|
6 |
+
@click.command()
|
7 |
+
@click.argument("model_path")
|
8 |
+
@click.argument("output_path")
|
9 |
+
def main(model_path, output_path):
|
10 |
+
if model_path == output_path:
|
11 |
+
logger.error("Model path and output path are the same")
|
12 |
+
return
|
13 |
+
|
14 |
+
logger.info(f"Loading model from {model_path}")
|
15 |
+
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
|
16 |
+
torch.save(state_dict, output_path)
|
17 |
+
logger.info(f"Model saved to {output_path}")
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
main()
|
tools/file.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from loguru import logger
|
6 |
+
from natsort import natsorted
|
7 |
+
|
8 |
+
AUDIO_EXTENSIONS = {
|
9 |
+
".mp3",
|
10 |
+
".wav",
|
11 |
+
".flac",
|
12 |
+
".ogg",
|
13 |
+
".m4a",
|
14 |
+
".wma",
|
15 |
+
".aac",
|
16 |
+
".aiff",
|
17 |
+
".aif",
|
18 |
+
".aifc",
|
19 |
+
}
|
20 |
+
|
21 |
+
VIDEO_EXTENSIONS = {
|
22 |
+
".mp4",
|
23 |
+
".avi",
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def audio_to_bytes(file_path):
|
28 |
+
if not file_path or not Path(file_path).exists():
|
29 |
+
return None
|
30 |
+
with open(file_path, "rb") as wav_file:
|
31 |
+
wav = wav_file.read()
|
32 |
+
return wav
|
33 |
+
|
34 |
+
|
35 |
+
def read_ref_text(ref_text):
|
36 |
+
path = Path(ref_text)
|
37 |
+
if path.exists() and path.is_file():
|
38 |
+
with path.open("r", encoding="utf-8") as file:
|
39 |
+
return file.read()
|
40 |
+
return ref_text
|
41 |
+
|
42 |
+
|
43 |
+
def list_files(
|
44 |
+
path: Union[Path, str],
|
45 |
+
extensions: set[str] = None,
|
46 |
+
recursive: bool = False,
|
47 |
+
sort: bool = True,
|
48 |
+
) -> list[Path]:
|
49 |
+
"""List files in a directory.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
path (Path): Path to the directory.
|
53 |
+
extensions (set, optional): Extensions to filter. Defaults to None.
|
54 |
+
recursive (bool, optional): Whether to search recursively. Defaults to False.
|
55 |
+
sort (bool, optional): Whether to sort the files. Defaults to True.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
list: List of files.
|
59 |
+
"""
|
60 |
+
|
61 |
+
if isinstance(path, str):
|
62 |
+
path = Path(path)
|
63 |
+
|
64 |
+
if not path.exists():
|
65 |
+
raise FileNotFoundError(f"Directory {path} does not exist.")
|
66 |
+
|
67 |
+
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
|
68 |
+
|
69 |
+
if sort:
|
70 |
+
files = natsorted(files)
|
71 |
+
|
72 |
+
return files
|
73 |
+
|
74 |
+
|
75 |
+
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
|
76 |
+
"""
|
77 |
+
Load a Bert-VITS2 style filelist.
|
78 |
+
"""
|
79 |
+
|
80 |
+
files = set()
|
81 |
+
results = []
|
82 |
+
count_duplicated, count_not_found = 0, 0
|
83 |
+
|
84 |
+
LANGUAGE_TO_LANGUAGES = {
|
85 |
+
"zh": ["zh", "en"],
|
86 |
+
"jp": ["jp", "en"],
|
87 |
+
"en": ["en"],
|
88 |
+
}
|
89 |
+
|
90 |
+
with open(path, "r", encoding="utf-8") as f:
|
91 |
+
for line in f.readlines():
|
92 |
+
splits = line.strip().split("|", maxsplit=3)
|
93 |
+
if len(splits) != 4:
|
94 |
+
logger.warning(f"Invalid line: {line}")
|
95 |
+
continue
|
96 |
+
|
97 |
+
filename, speaker, language, text = splits
|
98 |
+
file = Path(filename)
|
99 |
+
language = language.strip().lower()
|
100 |
+
|
101 |
+
if language == "ja":
|
102 |
+
language = "jp"
|
103 |
+
|
104 |
+
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
|
105 |
+
languages = LANGUAGE_TO_LANGUAGES[language]
|
106 |
+
|
107 |
+
if file in files:
|
108 |
+
logger.warning(f"Duplicated file: {file}")
|
109 |
+
count_duplicated += 1
|
110 |
+
continue
|
111 |
+
|
112 |
+
if not file.exists():
|
113 |
+
logger.warning(f"File not found: {file}")
|
114 |
+
count_not_found += 1
|
115 |
+
continue
|
116 |
+
|
117 |
+
results.append((file, speaker, languages, text))
|
118 |
+
|
119 |
+
if count_duplicated > 0:
|
120 |
+
logger.warning(f"Total duplicated files: {count_duplicated}")
|
121 |
+
|
122 |
+
if count_not_found > 0:
|
123 |
+
logger.warning(f"Total files not found: {count_not_found}")
|
124 |
+
|
125 |
+
return results
|
tools/llama/build_dataset.py
CHANGED
@@ -13,7 +13,7 @@ from tqdm import tqdm
|
|
13 |
|
14 |
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
15 |
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
16 |
-
from
|
17 |
|
18 |
# To avoid CPU overload
|
19 |
os.environ["MKL_NUM_THREADS"] = "1"
|
|
|
13 |
|
14 |
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
15 |
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
16 |
+
from tools.file import load_filelist
|
17 |
|
18 |
# To avoid CPU overload
|
19 |
os.environ["MKL_NUM_THREADS"] = "1"
|
tools/llama/generate.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import queue
|
3 |
import threading
|
4 |
import time
|
|
|
5 |
from dataclasses import dataclass
|
6 |
from pathlib import Path
|
7 |
from typing import Literal, Optional, Tuple, Union
|
@@ -93,15 +94,20 @@ def decode_one_token_ar(
|
|
93 |
**sampling_kwargs,
|
94 |
) -> torch.Tensor:
|
95 |
x = model.forward_generate(x, input_pos)
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
codebooks = [
|
97 |
sample(
|
98 |
x.logits,
|
99 |
-
previous_tokens=
|
100 |
-
|
101 |
-
), # Disable repetition penalty for the token codebook
|
102 |
-
**sampling_kwargs,
|
103 |
)[0]
|
104 |
]
|
|
|
105 |
x = x.hidden_states
|
106 |
|
107 |
# Cleanup the cache
|
@@ -136,11 +142,16 @@ def decode_one_token_naive(
|
|
136 |
) -> torch.Tensor:
|
137 |
x = model.forward_generate(x, input_pos)
|
138 |
|
|
|
|
|
|
|
|
|
|
|
139 |
codebooks = [
|
140 |
sample(
|
141 |
-
x.
|
142 |
previous_tokens=None, # Disable repetition penalty for the token codebook
|
143 |
-
**
|
144 |
)[0]
|
145 |
]
|
146 |
|
@@ -181,8 +192,12 @@ def decode_n_tokens(
|
|
181 |
else:
|
182 |
window = previous_tokens[:, i - win_size : i]
|
183 |
|
184 |
-
with
|
185 |
-
|
|
|
|
|
|
|
|
|
186 |
): # Actually better for Inductor to codegen attention here
|
187 |
next_token = decode_one_token(
|
188 |
model=model,
|
@@ -356,7 +371,10 @@ def load_model(checkpoint_path, device, precision, compile=False):
|
|
356 |
if compile:
|
357 |
logger.info("Compiling function...")
|
358 |
decode_one_token = torch.compile(
|
359 |
-
decode_one_token,
|
|
|
|
|
|
|
360 |
)
|
361 |
|
362 |
return model.eval(), decode_one_token
|
@@ -604,7 +622,7 @@ def launch_thread_safe_queue(
|
|
604 |
@click.option(
|
605 |
"--checkpoint-path",
|
606 |
type=click.Path(path_type=Path, exists=True),
|
607 |
-
default="checkpoints/fish-speech-1.
|
608 |
)
|
609 |
@click.option("--device", type=str, default="cuda")
|
610 |
@click.option("--compile/--no-compile", default=False)
|
|
|
2 |
import queue
|
3 |
import threading
|
4 |
import time
|
5 |
+
from contextlib import nullcontext
|
6 |
from dataclasses import dataclass
|
7 |
from pathlib import Path
|
8 |
from typing import Literal, Optional, Tuple, Union
|
|
|
94 |
**sampling_kwargs,
|
95 |
) -> torch.Tensor:
|
96 |
x = model.forward_generate(x, input_pos)
|
97 |
+
|
98 |
+
sampling_kwargs_main = sampling_kwargs.copy()
|
99 |
+
sampling_kwargs_main["temperature"] = 0.1
|
100 |
+
sampling_kwargs_main["top_p"] = 0.1
|
101 |
+
sampling_kwargs_main["repetition_penalty"] = 1.0
|
102 |
+
|
103 |
codebooks = [
|
104 |
sample(
|
105 |
x.logits,
|
106 |
+
previous_tokens=None, # Disable repetition penalty for the token codebook
|
107 |
+
**sampling_kwargs_main,
|
|
|
|
|
108 |
)[0]
|
109 |
]
|
110 |
+
|
111 |
x = x.hidden_states
|
112 |
|
113 |
# Cleanup the cache
|
|
|
142 |
) -> torch.Tensor:
|
143 |
x = model.forward_generate(x, input_pos)
|
144 |
|
145 |
+
sampling_kwargs_main = sampling_kwargs.copy()
|
146 |
+
sampling_kwargs_main["temperature"] = 0.1
|
147 |
+
sampling_kwargs_main["top_p"] = 0.1
|
148 |
+
sampling_kwargs_main["repetition_penalty"] = 1.0
|
149 |
+
|
150 |
codebooks = [
|
151 |
sample(
|
152 |
+
x.logits,
|
153 |
previous_tokens=None, # Disable repetition penalty for the token codebook
|
154 |
+
**sampling_kwargs_main,
|
155 |
)[0]
|
156 |
]
|
157 |
|
|
|
192 |
else:
|
193 |
window = previous_tokens[:, i - win_size : i]
|
194 |
|
195 |
+
with (
|
196 |
+
torch.backends.cuda.sdp_kernel(
|
197 |
+
enable_flash=False, enable_mem_efficient=False, enable_math=True
|
198 |
+
)
|
199 |
+
if torch.cuda.is_available()
|
200 |
+
else nullcontext()
|
201 |
): # Actually better for Inductor to codegen attention here
|
202 |
next_token = decode_one_token(
|
203 |
model=model,
|
|
|
371 |
if compile:
|
372 |
logger.info("Compiling function...")
|
373 |
decode_one_token = torch.compile(
|
374 |
+
decode_one_token,
|
375 |
+
fullgraph=True,
|
376 |
+
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
377 |
+
mode="reduce-overhead" if torch.cuda.is_available() else None,
|
378 |
)
|
379 |
|
380 |
return model.eval(), decode_one_token
|
|
|
622 |
@click.option(
|
623 |
"--checkpoint-path",
|
624 |
type=click.Path(path_type=Path, exists=True),
|
625 |
+
default="checkpoints/fish-speech-1.4",
|
626 |
)
|
627 |
@click.option("--device", type=str, default="cuda")
|
628 |
@click.option("--compile/--no-compile", default=False)
|
tools/llama/merge_lora.py
CHANGED
@@ -15,7 +15,7 @@ from fish_speech.models.text2semantic.lora import get_merged_state_dict
|
|
15 |
|
16 |
@click.command()
|
17 |
@click.option("--lora-config", type=str, default="r_8_alpha_16")
|
18 |
-
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.
|
19 |
@click.option("--lora-weight", type=str, required=True)
|
20 |
@click.option("--output", type=str, required=True)
|
21 |
def merge(lora_config, base_weight, lora_weight, output):
|
|
|
15 |
|
16 |
@click.command()
|
17 |
@click.option("--lora-config", type=str, default="r_8_alpha_16")
|
18 |
+
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
|
19 |
@click.option("--lora-weight", type=str, required=True)
|
20 |
@click.option("--output", type=str, required=True)
|
21 |
def merge(lora_config, base_weight, lora_weight, output):
|
tools/llama/quantize.py
CHANGED
@@ -428,7 +428,7 @@ def generate_folder_name():
|
|
428 |
@click.option(
|
429 |
"--checkpoint-path",
|
430 |
type=click.Path(path_type=Path, exists=True),
|
431 |
-
default="checkpoints/fish-speech-1.
|
432 |
)
|
433 |
@click.option(
|
434 |
"--mode", type=str, default="int8", help="type of quantization to perform"
|
@@ -451,7 +451,7 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -
|
|
451 |
precision=precision,
|
452 |
compile=False,
|
453 |
)
|
454 |
-
vq_model = "firefly-gan-vq-fsq-
|
455 |
now = timestamp if timestamp != "None" else generate_folder_name()
|
456 |
|
457 |
if mode == "int8":
|
|
|
428 |
@click.option(
|
429 |
"--checkpoint-path",
|
430 |
type=click.Path(path_type=Path, exists=True),
|
431 |
+
default="checkpoints/fish-speech-1.4",
|
432 |
)
|
433 |
@click.option(
|
434 |
"--mode", type=str, default="int8", help="type of quantization to perform"
|
|
|
451 |
precision=precision,
|
452 |
compile=False,
|
453 |
)
|
454 |
+
vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
|
455 |
now = timestamp if timestamp != "None" else generate_folder_name()
|
456 |
|
457 |
if mode == "int8":
|
tools/msgpack_api.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import httpx
|
2 |
+
import ormsgpack
|
3 |
+
|
4 |
+
from tools.commons import ServeReferenceAudio, ServeTTSRequest
|
5 |
+
|
6 |
+
# priority: ref_id > references
|
7 |
+
request = ServeTTSRequest(
|
8 |
+
text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
9 |
+
# reference_id="114514",
|
10 |
+
references=[
|
11 |
+
ServeReferenceAudio(
|
12 |
+
audio=open("lengyue.wav", "rb").read(),
|
13 |
+
text=open("lengyue.lab", "r", encoding="utf-8").read(),
|
14 |
+
)
|
15 |
+
],
|
16 |
+
streaming=True,
|
17 |
+
)
|
18 |
+
|
19 |
+
with (
|
20 |
+
httpx.Client() as client,
|
21 |
+
open("hello.wav", "wb") as f,
|
22 |
+
):
|
23 |
+
with client.stream(
|
24 |
+
"POST",
|
25 |
+
"http://127.0.0.1:8080/v1/tts",
|
26 |
+
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
27 |
+
headers={
|
28 |
+
"authorization": "Bearer YOUR_API_KEY",
|
29 |
+
"content-type": "application/msgpack",
|
30 |
+
},
|
31 |
+
timeout=None,
|
32 |
+
) as response:
|
33 |
+
for chunk in response.iter_bytes():
|
34 |
+
f.write(chunk)
|
tools/post_api.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import base64
|
3 |
+
import wave
|
4 |
+
|
5 |
+
import ormsgpack
|
6 |
+
import pyaudio
|
7 |
+
import requests
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from pydub.playback import play
|
10 |
+
|
11 |
+
from tools.commons import ServeReferenceAudio, ServeTTSRequest
|
12 |
+
from tools.file import audio_to_bytes, read_ref_text
|
13 |
+
|
14 |
+
|
15 |
+
def parse_args():
|
16 |
+
|
17 |
+
parser = argparse.ArgumentParser(
|
18 |
+
description="Send a WAV file and text to a server and receive synthesized audio."
|
19 |
+
)
|
20 |
+
|
21 |
+
parser.add_argument(
|
22 |
+
"--url",
|
23 |
+
"-u",
|
24 |
+
type=str,
|
25 |
+
default="http://127.0.0.1:8080/v1/tts",
|
26 |
+
help="URL of the server",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--text", "-t", type=str, required=True, help="Text to be synthesized"
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--reference_id",
|
33 |
+
"-id",
|
34 |
+
type=str,
|
35 |
+
default=None,
|
36 |
+
help="ID of the reference model o be used for the speech",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--reference_audio",
|
40 |
+
"-ra",
|
41 |
+
type=str,
|
42 |
+
nargs="+",
|
43 |
+
default=None,
|
44 |
+
help="Path to the WAV file",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--reference_text",
|
48 |
+
"-rt",
|
49 |
+
type=str,
|
50 |
+
nargs="+",
|
51 |
+
default=None,
|
52 |
+
help="Reference text for voice synthesis",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--output",
|
56 |
+
"-o",
|
57 |
+
type=str,
|
58 |
+
default="generated_audio",
|
59 |
+
help="Output audio file name",
|
60 |
+
)
|
61 |
+
parser.add_argument(
|
62 |
+
"--play",
|
63 |
+
type=bool,
|
64 |
+
default=True,
|
65 |
+
help="Whether to play audio after receiving data",
|
66 |
+
)
|
67 |
+
parser.add_argument("--normalize", type=bool, default=True)
|
68 |
+
parser.add_argument(
|
69 |
+
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
|
70 |
+
)
|
71 |
+
parser.add_argument("--mp3_bitrate", type=int, default=64)
|
72 |
+
parser.add_argument("--opus_bitrate", type=int, default=-1000)
|
73 |
+
parser.add_argument("--latency", type=str, default="normal", help="延迟选项")
|
74 |
+
parser.add_argument(
|
75 |
+
"--max_new_tokens",
|
76 |
+
type=int,
|
77 |
+
default=1024,
|
78 |
+
help="Maximum new tokens to generate",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--chunk_length", type=int, default=100, help="Chunk length for synthesis"
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--repetition_penalty",
|
88 |
+
type=float,
|
89 |
+
default=1.2,
|
90 |
+
help="Repetition penalty for synthesis",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--temperature", type=float, default=0.7, help="Temperature for sampling"
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
|
97 |
+
)
|
98 |
+
parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
|
99 |
+
parser.add_argument(
|
100 |
+
"--streaming", type=bool, default=False, help="Enable streaming response"
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--channels", type=int, default=1, help="Number of audio channels"
|
104 |
+
)
|
105 |
+
parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
|
106 |
+
|
107 |
+
return parser.parse_args()
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
|
112 |
+
args = parse_args()
|
113 |
+
|
114 |
+
idstr: str | None = args.reference_id
|
115 |
+
# priority: ref_id > [{text, audio},...]
|
116 |
+
if idstr is None:
|
117 |
+
ref_audios = args.reference_audio
|
118 |
+
ref_texts = args.reference_text
|
119 |
+
if ref_audios is None:
|
120 |
+
byte_audios = []
|
121 |
+
else:
|
122 |
+
byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
|
123 |
+
if ref_texts is None:
|
124 |
+
ref_texts = []
|
125 |
+
else:
|
126 |
+
ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
|
127 |
+
else:
|
128 |
+
byte_audios = []
|
129 |
+
ref_texts = []
|
130 |
+
pass # in api.py
|
131 |
+
|
132 |
+
data = {
|
133 |
+
"text": args.text,
|
134 |
+
"references": [
|
135 |
+
ServeReferenceAudio(audio=ref_audio, text=ref_text)
|
136 |
+
for ref_text, ref_audio in zip(ref_texts, byte_audios)
|
137 |
+
],
|
138 |
+
"reference_id": idstr,
|
139 |
+
"normalize": args.normalize,
|
140 |
+
"format": args.format,
|
141 |
+
"mp3_bitrate": args.mp3_bitrate,
|
142 |
+
"opus_bitrate": args.opus_bitrate,
|
143 |
+
"max_new_tokens": args.max_new_tokens,
|
144 |
+
"chunk_length": args.chunk_length,
|
145 |
+
"top_p": args.top_p,
|
146 |
+
"repetition_penalty": args.repetition_penalty,
|
147 |
+
"temperature": args.temperature,
|
148 |
+
"speaker": args.speaker,
|
149 |
+
"emotion": args.emotion,
|
150 |
+
"streaming": args.streaming,
|
151 |
+
}
|
152 |
+
|
153 |
+
pydantic_data = ServeTTSRequest(**data)
|
154 |
+
|
155 |
+
response = requests.post(
|
156 |
+
args.url,
|
157 |
+
data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
158 |
+
stream=args.streaming,
|
159 |
+
headers={
|
160 |
+
"authorization": "Bearer YOUR_API_KEY",
|
161 |
+
"content-type": "application/msgpack",
|
162 |
+
},
|
163 |
+
)
|
164 |
+
|
165 |
+
if response.status_code == 200:
|
166 |
+
if args.streaming:
|
167 |
+
p = pyaudio.PyAudio()
|
168 |
+
audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
|
169 |
+
stream = p.open(
|
170 |
+
format=audio_format, channels=args.channels, rate=args.rate, output=True
|
171 |
+
)
|
172 |
+
|
173 |
+
wf = wave.open(f"{args.output}.wav", "wb")
|
174 |
+
wf.setnchannels(args.channels)
|
175 |
+
wf.setsampwidth(p.get_sample_size(audio_format))
|
176 |
+
wf.setframerate(args.rate)
|
177 |
+
|
178 |
+
stream_stopped_flag = False
|
179 |
+
|
180 |
+
try:
|
181 |
+
for chunk in response.iter_content(chunk_size=1024):
|
182 |
+
if chunk:
|
183 |
+
stream.write(chunk)
|
184 |
+
wf.writeframesraw(chunk)
|
185 |
+
else:
|
186 |
+
if not stream_stopped_flag:
|
187 |
+
stream.stop_stream()
|
188 |
+
stream_stopped_flag = True
|
189 |
+
finally:
|
190 |
+
stream.close()
|
191 |
+
p.terminate()
|
192 |
+
wf.close()
|
193 |
+
else:
|
194 |
+
audio_content = response.content
|
195 |
+
audio_path = f"{args.output}.{args.format}"
|
196 |
+
with open(audio_path, "wb") as audio_file:
|
197 |
+
audio_file.write(audio_content)
|
198 |
+
|
199 |
+
audio = AudioSegment.from_file(audio_path, format=args.format)
|
200 |
+
if args.play:
|
201 |
+
play(audio)
|
202 |
+
print(f"Audio has been saved to '{audio_path}'.")
|
203 |
+
else:
|
204 |
+
print(f"Request failed with status code {response.status_code}")
|
205 |
+
print(response.json())
|
tools/sensevoice/README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FunASR Command Line Interface
|
2 |
+
|
3 |
+
This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
|
4 |
+
|
5 |
+
## Requirements
|
6 |
+
|
7 |
+
- Python >= 3.10
|
8 |
+
- PyTorch <= 2.3.1
|
9 |
+
- ffmpeg, pydub, audio-separator[gpu].
|
10 |
+
|
11 |
+
## Installation
|
12 |
+
|
13 |
+
Install the required packages:
|
14 |
+
|
15 |
+
```bash
|
16 |
+
pip install -e .[stable]
|
17 |
+
```
|
18 |
+
|
19 |
+
Make sure you have `ffmpeg` installed and available in your `PATH`.
|
20 |
+
|
21 |
+
## Usage
|
22 |
+
|
23 |
+
### Basic Usage
|
24 |
+
|
25 |
+
To run the tool with default settings:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
|
29 |
+
```
|
30 |
+
|
31 |
+
## Options
|
32 |
+
|
33 |
+
| Option | Description |
|
34 |
+
| :-----------------------: | :---------------------------------------------------------------------------: |
|
35 |
+
| --audio-dir | Directory containing audio or video files. |
|
36 |
+
| --save-dir | Directory to save processed audio files. |
|
37 |
+
| --device | Device to use for processing. Options: cuda (default) or cpu. |
|
38 |
+
| --language | Language of the transcription. Default is auto. |
|
39 |
+
| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
|
40 |
+
| --punc | Enable punctuation prediction. |
|
41 |
+
| --denoise | Enable noise reduction (vocal separation). |
|
42 |
+
|
43 |
+
## Example
|
44 |
+
|
45 |
+
To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
|
46 |
+
|
47 |
+
```bash
|
48 |
+
python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
|
49 |
+
```
|
50 |
+
|
51 |
+
## Additional Notes
|
52 |
+
|
53 |
+
- The tool supports `both audio and video files`. Videos will be converted to audio automatically.
|
54 |
+
- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
|
55 |
+
- The script will automatically create necessary directories in the `--save-dir`.
|
56 |
+
|
57 |
+
## Troubleshooting
|
58 |
+
|
59 |
+
If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
|
tools/sensevoice/__init__.py
ADDED
File without changes
|
tools/sensevoice/auto_model.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import os.path
|
10 |
+
import random
|
11 |
+
import re
|
12 |
+
import string
|
13 |
+
import time
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from funasr.download.download_model_from_hub import download_model
|
18 |
+
from funasr.download.file import download_from_url
|
19 |
+
from funasr.register import tables
|
20 |
+
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
21 |
+
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
22 |
+
from funasr.utils import export_utils, misc
|
23 |
+
from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
|
24 |
+
from funasr.utils.misc import deep_update
|
25 |
+
from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
from .vad_utils import merge_vad, slice_padding_audio_samples
|
29 |
+
|
30 |
+
try:
|
31 |
+
from funasr.models.campplus.cluster_backend import ClusterBackend
|
32 |
+
from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
|
36 |
+
|
37 |
+
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
|
38 |
+
""" """
|
39 |
+
data_list = []
|
40 |
+
key_list = []
|
41 |
+
filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
|
42 |
+
|
43 |
+
chars = string.ascii_letters + string.digits
|
44 |
+
if isinstance(data_in, str):
|
45 |
+
if data_in.startswith("http://") or data_in.startswith("https://"): # url
|
46 |
+
data_in = download_from_url(data_in)
|
47 |
+
|
48 |
+
if isinstance(data_in, str) and os.path.exists(
|
49 |
+
data_in
|
50 |
+
): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
|
51 |
+
_, file_extension = os.path.splitext(data_in)
|
52 |
+
file_extension = file_extension.lower()
|
53 |
+
if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
|
54 |
+
with open(data_in, encoding="utf-8") as fin:
|
55 |
+
for line in fin:
|
56 |
+
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
57 |
+
if data_in.endswith(
|
58 |
+
".jsonl"
|
59 |
+
): # file.jsonl: json.dumps({"source": data})
|
60 |
+
lines = json.loads(line.strip())
|
61 |
+
data = lines["source"]
|
62 |
+
key = data["key"] if "key" in data else key
|
63 |
+
else: # filelist, wav.scp, text.txt: id \t data or data
|
64 |
+
lines = line.strip().split(maxsplit=1)
|
65 |
+
data = lines[1] if len(lines) > 1 else lines[0]
|
66 |
+
key = lines[0] if len(lines) > 1 else key
|
67 |
+
|
68 |
+
data_list.append(data)
|
69 |
+
key_list.append(key)
|
70 |
+
else:
|
71 |
+
if key is None:
|
72 |
+
# key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
73 |
+
key = misc.extract_filename_without_extension(data_in)
|
74 |
+
data_list = [data_in]
|
75 |
+
key_list = [key]
|
76 |
+
elif isinstance(data_in, (list, tuple)):
|
77 |
+
if data_type is not None and isinstance(
|
78 |
+
data_type, (list, tuple)
|
79 |
+
): # mutiple inputs
|
80 |
+
data_list_tmp = []
|
81 |
+
for data_in_i, data_type_i in zip(data_in, data_type):
|
82 |
+
key_list, data_list_i = prepare_data_iterator(
|
83 |
+
data_in=data_in_i, data_type=data_type_i
|
84 |
+
)
|
85 |
+
data_list_tmp.append(data_list_i)
|
86 |
+
data_list = []
|
87 |
+
for item in zip(*data_list_tmp):
|
88 |
+
data_list.append(item)
|
89 |
+
else:
|
90 |
+
# [audio sample point, fbank, text]
|
91 |
+
data_list = data_in
|
92 |
+
key_list = []
|
93 |
+
for data_i in data_in:
|
94 |
+
if isinstance(data_i, str) and os.path.exists(data_i):
|
95 |
+
key = misc.extract_filename_without_extension(data_i)
|
96 |
+
else:
|
97 |
+
if key is None:
|
98 |
+
key = "rand_key_" + "".join(
|
99 |
+
random.choice(chars) for _ in range(13)
|
100 |
+
)
|
101 |
+
key_list.append(key)
|
102 |
+
|
103 |
+
else: # raw text; audio sample point, fbank; bytes
|
104 |
+
if isinstance(data_in, bytes): # audio bytes
|
105 |
+
data_in = load_bytes(data_in)
|
106 |
+
if key is None:
|
107 |
+
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
108 |
+
data_list = [data_in]
|
109 |
+
key_list = [key]
|
110 |
+
|
111 |
+
return key_list, data_list
|
112 |
+
|
113 |
+
|
114 |
+
class AutoModel:
|
115 |
+
|
116 |
+
def __init__(self, **kwargs):
|
117 |
+
|
118 |
+
try:
|
119 |
+
from funasr.utils.version_checker import check_for_update
|
120 |
+
|
121 |
+
print(
|
122 |
+
"Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
|
123 |
+
)
|
124 |
+
check_for_update(disable=kwargs.get("disable_update", False))
|
125 |
+
except:
|
126 |
+
pass
|
127 |
+
|
128 |
+
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
|
129 |
+
logging.basicConfig(level=log_level)
|
130 |
+
|
131 |
+
model, kwargs = self.build_model(**kwargs)
|
132 |
+
|
133 |
+
# if vad_model is not None, build vad model else None
|
134 |
+
vad_model = kwargs.get("vad_model", None)
|
135 |
+
vad_kwargs = (
|
136 |
+
{} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
|
137 |
+
)
|
138 |
+
if vad_model is not None:
|
139 |
+
logging.info("Building VAD model.")
|
140 |
+
vad_kwargs["model"] = vad_model
|
141 |
+
vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
|
142 |
+
vad_kwargs["device"] = kwargs["device"]
|
143 |
+
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
|
144 |
+
|
145 |
+
# if punc_model is not None, build punc model else None
|
146 |
+
punc_model = kwargs.get("punc_model", None)
|
147 |
+
punc_kwargs = (
|
148 |
+
{}
|
149 |
+
if kwargs.get("punc_kwargs", {}) is None
|
150 |
+
else kwargs.get("punc_kwargs", {})
|
151 |
+
)
|
152 |
+
if punc_model is not None:
|
153 |
+
logging.info("Building punc model.")
|
154 |
+
punc_kwargs["model"] = punc_model
|
155 |
+
punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
|
156 |
+
punc_kwargs["device"] = kwargs["device"]
|
157 |
+
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
|
158 |
+
|
159 |
+
# if spk_model is not None, build spk model else None
|
160 |
+
spk_model = kwargs.get("spk_model", None)
|
161 |
+
spk_kwargs = (
|
162 |
+
{} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
|
163 |
+
)
|
164 |
+
if spk_model is not None:
|
165 |
+
logging.info("Building SPK model.")
|
166 |
+
spk_kwargs["model"] = spk_model
|
167 |
+
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
|
168 |
+
spk_kwargs["device"] = kwargs["device"]
|
169 |
+
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
|
170 |
+
self.cb_model = ClusterBackend().to(kwargs["device"])
|
171 |
+
spk_mode = kwargs.get("spk_mode", "punc_segment")
|
172 |
+
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
|
173 |
+
logging.error(
|
174 |
+
"spk_mode should be one of default, vad_segment and punc_segment."
|
175 |
+
)
|
176 |
+
self.spk_mode = spk_mode
|
177 |
+
|
178 |
+
self.kwargs = kwargs
|
179 |
+
self.model = model
|
180 |
+
self.vad_model = vad_model
|
181 |
+
self.vad_kwargs = vad_kwargs
|
182 |
+
self.punc_model = punc_model
|
183 |
+
self.punc_kwargs = punc_kwargs
|
184 |
+
self.spk_model = spk_model
|
185 |
+
self.spk_kwargs = spk_kwargs
|
186 |
+
self.model_path = kwargs.get("model_path")
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def build_model(**kwargs):
|
190 |
+
assert "model" in kwargs
|
191 |
+
if "model_conf" not in kwargs:
|
192 |
+
logging.info(
|
193 |
+
"download models from model hub: {}".format(kwargs.get("hub", "ms"))
|
194 |
+
)
|
195 |
+
kwargs = download_model(**kwargs)
|
196 |
+
|
197 |
+
set_all_random_seed(kwargs.get("seed", 0))
|
198 |
+
|
199 |
+
device = kwargs.get("device", "cuda")
|
200 |
+
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
|
201 |
+
device = "cpu"
|
202 |
+
kwargs["batch_size"] = 1
|
203 |
+
kwargs["device"] = device
|
204 |
+
|
205 |
+
torch.set_num_threads(kwargs.get("ncpu", 4))
|
206 |
+
|
207 |
+
# build tokenizer
|
208 |
+
tokenizer = kwargs.get("tokenizer", None)
|
209 |
+
if tokenizer is not None:
|
210 |
+
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
|
211 |
+
tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
|
212 |
+
kwargs["token_list"] = (
|
213 |
+
tokenizer.token_list if hasattr(tokenizer, "token_list") else None
|
214 |
+
)
|
215 |
+
kwargs["token_list"] = (
|
216 |
+
tokenizer.get_vocab()
|
217 |
+
if hasattr(tokenizer, "get_vocab")
|
218 |
+
else kwargs["token_list"]
|
219 |
+
)
|
220 |
+
vocab_size = (
|
221 |
+
len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
|
222 |
+
)
|
223 |
+
if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
|
224 |
+
vocab_size = tokenizer.get_vocab_size()
|
225 |
+
else:
|
226 |
+
vocab_size = -1
|
227 |
+
kwargs["tokenizer"] = tokenizer
|
228 |
+
|
229 |
+
# build frontend
|
230 |
+
frontend = kwargs.get("frontend", None)
|
231 |
+
kwargs["input_size"] = None
|
232 |
+
if frontend is not None:
|
233 |
+
frontend_class = tables.frontend_classes.get(frontend)
|
234 |
+
frontend = frontend_class(**kwargs.get("frontend_conf", {}))
|
235 |
+
kwargs["input_size"] = (
|
236 |
+
frontend.output_size() if hasattr(frontend, "output_size") else None
|
237 |
+
)
|
238 |
+
kwargs["frontend"] = frontend
|
239 |
+
# build model
|
240 |
+
model_class = tables.model_classes.get(kwargs["model"])
|
241 |
+
assert model_class is not None, f'{kwargs["model"]} is not registered'
|
242 |
+
model_conf = {}
|
243 |
+
deep_update(model_conf, kwargs.get("model_conf", {}))
|
244 |
+
deep_update(model_conf, kwargs)
|
245 |
+
model = model_class(**model_conf, vocab_size=vocab_size)
|
246 |
+
|
247 |
+
# init_param
|
248 |
+
init_param = kwargs.get("init_param", None)
|
249 |
+
if init_param is not None:
|
250 |
+
if os.path.exists(init_param):
|
251 |
+
logging.info(f"Loading pretrained params from {init_param}")
|
252 |
+
load_pretrained_model(
|
253 |
+
model=model,
|
254 |
+
path=init_param,
|
255 |
+
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
256 |
+
oss_bucket=kwargs.get("oss_bucket", None),
|
257 |
+
scope_map=kwargs.get("scope_map", []),
|
258 |
+
excludes=kwargs.get("excludes", None),
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
print(f"error, init_param does not exist!: {init_param}")
|
262 |
+
|
263 |
+
# fp16
|
264 |
+
if kwargs.get("fp16", False):
|
265 |
+
model.to(torch.float16)
|
266 |
+
elif kwargs.get("bf16", False):
|
267 |
+
model.to(torch.bfloat16)
|
268 |
+
model.to(device)
|
269 |
+
|
270 |
+
if not kwargs.get("disable_log", True):
|
271 |
+
tables.print()
|
272 |
+
|
273 |
+
return model, kwargs
|
274 |
+
|
275 |
+
def __call__(self, *args, **cfg):
|
276 |
+
kwargs = self.kwargs
|
277 |
+
deep_update(kwargs, cfg)
|
278 |
+
res = self.model(*args, kwargs)
|
279 |
+
return res
|
280 |
+
|
281 |
+
def generate(self, input, input_len=None, **cfg):
|
282 |
+
if self.vad_model is None:
|
283 |
+
return self.inference(input, input_len=input_len, **cfg)
|
284 |
+
|
285 |
+
else:
|
286 |
+
return self.inference_with_vad(input, input_len=input_len, **cfg)
|
287 |
+
|
288 |
+
def inference(
|
289 |
+
self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
|
290 |
+
):
|
291 |
+
kwargs = self.kwargs if kwargs is None else kwargs
|
292 |
+
if "cache" in kwargs:
|
293 |
+
kwargs.pop("cache")
|
294 |
+
deep_update(kwargs, cfg)
|
295 |
+
model = self.model if model is None else model
|
296 |
+
model.eval()
|
297 |
+
|
298 |
+
batch_size = kwargs.get("batch_size", 1)
|
299 |
+
# if kwargs.get("device", "cpu") == "cpu":
|
300 |
+
# batch_size = 1
|
301 |
+
|
302 |
+
key_list, data_list = prepare_data_iterator(
|
303 |
+
input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
|
304 |
+
)
|
305 |
+
|
306 |
+
speed_stats = {}
|
307 |
+
asr_result_list = []
|
308 |
+
num_samples = len(data_list)
|
309 |
+
disable_pbar = self.kwargs.get("disable_pbar", False)
|
310 |
+
pbar = (
|
311 |
+
tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
|
312 |
+
if not disable_pbar
|
313 |
+
else None
|
314 |
+
)
|
315 |
+
time_speech_total = 0.0
|
316 |
+
time_escape_total = 0.0
|
317 |
+
for beg_idx in range(0, num_samples, batch_size):
|
318 |
+
end_idx = min(num_samples, beg_idx + batch_size)
|
319 |
+
data_batch = data_list[beg_idx:end_idx]
|
320 |
+
key_batch = key_list[beg_idx:end_idx]
|
321 |
+
batch = {"data_in": data_batch, "key": key_batch}
|
322 |
+
|
323 |
+
if (end_idx - beg_idx) == 1 and kwargs.get(
|
324 |
+
"data_type", None
|
325 |
+
) == "fbank": # fbank
|
326 |
+
batch["data_in"] = data_batch[0]
|
327 |
+
batch["data_lengths"] = input_len
|
328 |
+
|
329 |
+
time1 = time.perf_counter()
|
330 |
+
with torch.no_grad():
|
331 |
+
res = model.inference(**batch, **kwargs)
|
332 |
+
if isinstance(res, (list, tuple)):
|
333 |
+
results = res[0] if len(res) > 0 else [{"text": ""}]
|
334 |
+
meta_data = res[1] if len(res) > 1 else {}
|
335 |
+
time2 = time.perf_counter()
|
336 |
+
|
337 |
+
asr_result_list.extend(results)
|
338 |
+
|
339 |
+
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
|
340 |
+
batch_data_time = meta_data.get("batch_data_time", -1)
|
341 |
+
time_escape = time2 - time1
|
342 |
+
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
|
343 |
+
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
|
344 |
+
speed_stats["forward"] = f"{time_escape:0.3f}"
|
345 |
+
speed_stats["batch_size"] = f"{len(results)}"
|
346 |
+
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
|
347 |
+
description = f"{speed_stats}, "
|
348 |
+
if pbar:
|
349 |
+
pbar.update(end_idx - beg_idx)
|
350 |
+
pbar.set_description(description)
|
351 |
+
time_speech_total += batch_data_time
|
352 |
+
time_escape_total += time_escape
|
353 |
+
|
354 |
+
if pbar:
|
355 |
+
# pbar.update(1)
|
356 |
+
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
|
357 |
+
torch.cuda.empty_cache()
|
358 |
+
return asr_result_list
|
359 |
+
|
360 |
+
def vad(self, input, input_len=None, **cfg):
|
361 |
+
kwargs = self.kwargs
|
362 |
+
# step.1: compute the vad model
|
363 |
+
deep_update(self.vad_kwargs, cfg)
|
364 |
+
beg_vad = time.time()
|
365 |
+
res = self.inference(
|
366 |
+
input,
|
367 |
+
input_len=input_len,
|
368 |
+
model=self.vad_model,
|
369 |
+
kwargs=self.vad_kwargs,
|
370 |
+
**cfg,
|
371 |
+
)
|
372 |
+
end_vad = time.time()
|
373 |
+
# FIX(gcf): concat the vad clips for sense vocie model for better aed
|
374 |
+
if cfg.get("merge_vad", False):
|
375 |
+
for i in range(len(res)):
|
376 |
+
res[i]["value"] = merge_vad(
|
377 |
+
res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
|
378 |
+
)
|
379 |
+
elapsed = end_vad - beg_vad
|
380 |
+
return elapsed, res
|
381 |
+
|
382 |
+
def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
|
383 |
+
|
384 |
+
kwargs = self.kwargs
|
385 |
+
|
386 |
+
# step.2 compute asr model
|
387 |
+
model = self.model
|
388 |
+
deep_update(kwargs, cfg)
|
389 |
+
batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
|
390 |
+
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
|
391 |
+
kwargs["batch_size"] = batch_size
|
392 |
+
|
393 |
+
key_list, data_list = prepare_data_iterator(
|
394 |
+
input, input_len=input_len, data_type=kwargs.get("data_type", None)
|
395 |
+
)
|
396 |
+
results_ret_list = []
|
397 |
+
time_speech_total_all_samples = 1e-6
|
398 |
+
|
399 |
+
beg_total = time.time()
|
400 |
+
pbar_total = (
|
401 |
+
tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
|
402 |
+
if not kwargs.get("disable_pbar", False)
|
403 |
+
else None
|
404 |
+
)
|
405 |
+
|
406 |
+
for i in range(len(vad_res)):
|
407 |
+
key = vad_res[i]["key"]
|
408 |
+
vadsegments = vad_res[i]["value"]
|
409 |
+
input_i = data_list[i]
|
410 |
+
fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
|
411 |
+
speech = load_audio_text_image_video(
|
412 |
+
input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
|
413 |
+
)
|
414 |
+
speech_lengths = len(speech)
|
415 |
+
n = len(vadsegments)
|
416 |
+
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
417 |
+
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
418 |
+
results_sorted = []
|
419 |
+
|
420 |
+
if not len(sorted_data):
|
421 |
+
results_ret_list.append({"key": key, "text": "", "timestamp": []})
|
422 |
+
logging.info("decoding, utt: {}, empty speech".format(key))
|
423 |
+
continue
|
424 |
+
|
425 |
+
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
426 |
+
batch_size = max(
|
427 |
+
batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
|
428 |
+
)
|
429 |
+
|
430 |
+
if kwargs["device"] == "cpu":
|
431 |
+
batch_size = 0
|
432 |
+
|
433 |
+
beg_idx = 0
|
434 |
+
beg_asr_total = time.time()
|
435 |
+
time_speech_total_per_sample = speech_lengths / 16000
|
436 |
+
time_speech_total_all_samples += time_speech_total_per_sample
|
437 |
+
|
438 |
+
# pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
|
439 |
+
|
440 |
+
all_segments = []
|
441 |
+
max_len_in_batch = 0
|
442 |
+
end_idx = 1
|
443 |
+
|
444 |
+
for j, _ in enumerate(range(0, n)):
|
445 |
+
# pbar_sample.update(1)
|
446 |
+
sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
|
447 |
+
potential_batch_length = max(max_len_in_batch, sample_length) * (
|
448 |
+
j + 1 - beg_idx
|
449 |
+
)
|
450 |
+
# batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
|
451 |
+
if (
|
452 |
+
j < n - 1
|
453 |
+
and sample_length < batch_size_threshold_ms
|
454 |
+
and potential_batch_length < batch_size
|
455 |
+
):
|
456 |
+
max_len_in_batch = max(max_len_in_batch, sample_length)
|
457 |
+
end_idx += 1
|
458 |
+
continue
|
459 |
+
|
460 |
+
speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
|
461 |
+
speech, speech_lengths, sorted_data[beg_idx:end_idx]
|
462 |
+
)
|
463 |
+
results = self.inference(
|
464 |
+
speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
|
465 |
+
)
|
466 |
+
|
467 |
+
for _b in range(len(speech_j)):
|
468 |
+
results[_b]["interval"] = intervals[_b]
|
469 |
+
|
470 |
+
if self.spk_model is not None:
|
471 |
+
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
|
472 |
+
for _b in range(len(speech_j)):
|
473 |
+
vad_segments = [
|
474 |
+
[
|
475 |
+
sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
|
476 |
+
sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
|
477 |
+
np.array(speech_j[_b]),
|
478 |
+
]
|
479 |
+
]
|
480 |
+
segments = sv_chunk(vad_segments)
|
481 |
+
all_segments.extend(segments)
|
482 |
+
speech_b = [i[2] for i in segments]
|
483 |
+
spk_res = self.inference(
|
484 |
+
speech_b,
|
485 |
+
input_len=None,
|
486 |
+
model=self.spk_model,
|
487 |
+
kwargs=kwargs,
|
488 |
+
**cfg,
|
489 |
+
)
|
490 |
+
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
|
491 |
+
|
492 |
+
beg_idx = end_idx
|
493 |
+
end_idx += 1
|
494 |
+
max_len_in_batch = sample_length
|
495 |
+
if len(results) < 1:
|
496 |
+
continue
|
497 |
+
results_sorted.extend(results)
|
498 |
+
|
499 |
+
# end_asr_total = time.time()
|
500 |
+
# time_escape_total_per_sample = end_asr_total - beg_asr_total
|
501 |
+
# pbar_sample.update(1)
|
502 |
+
# pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
503 |
+
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
|
504 |
+
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
|
505 |
+
|
506 |
+
restored_data = [0] * n
|
507 |
+
for j in range(n):
|
508 |
+
index = sorted_data[j][1]
|
509 |
+
cur = results_sorted[j]
|
510 |
+
pattern = r"<\|([^|]+)\|>"
|
511 |
+
emotion_string = re.findall(pattern, cur["text"])
|
512 |
+
cur["text"] = re.sub(pattern, "", cur["text"])
|
513 |
+
cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
|
514 |
+
if self.punc_model is not None and len(cur["text"].strip()) > 0:
|
515 |
+
deep_update(self.punc_kwargs, cfg)
|
516 |
+
punc_res = self.inference(
|
517 |
+
cur["text"],
|
518 |
+
model=self.punc_model,
|
519 |
+
kwargs=self.punc_kwargs,
|
520 |
+
**cfg,
|
521 |
+
)
|
522 |
+
cur["text"] = punc_res[0]["text"]
|
523 |
+
|
524 |
+
restored_data[index] = cur
|
525 |
+
|
526 |
+
end_asr_total = time.time()
|
527 |
+
time_escape_total_per_sample = end_asr_total - beg_asr_total
|
528 |
+
if pbar_total:
|
529 |
+
pbar_total.update(1)
|
530 |
+
pbar_total.set_description(
|
531 |
+
f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
532 |
+
f"time_speech: {time_speech_total_per_sample: 0.3f}, "
|
533 |
+
f"time_escape: {time_escape_total_per_sample:0.3f}"
|
534 |
+
)
|
535 |
+
|
536 |
+
# end_total = time.time()
|
537 |
+
# time_escape_total_all_samples = end_total - beg_total
|
538 |
+
# print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
|
539 |
+
# f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
|
540 |
+
# f"time_escape_all: {time_escape_total_all_samples:0.3f}")
|
541 |
+
return restored_data
|
542 |
+
|
543 |
+
def export(self, input=None, **cfg):
|
544 |
+
"""
|
545 |
+
|
546 |
+
:param input:
|
547 |
+
:param type:
|
548 |
+
:param quantize:
|
549 |
+
:param fallback_num:
|
550 |
+
:param calib_num:
|
551 |
+
:param opset_version:
|
552 |
+
:param cfg:
|
553 |
+
:return:
|
554 |
+
"""
|
555 |
+
|
556 |
+
device = cfg.get("device", "cpu")
|
557 |
+
model = self.model.to(device=device)
|
558 |
+
kwargs = self.kwargs
|
559 |
+
deep_update(kwargs, cfg)
|
560 |
+
kwargs["device"] = device
|
561 |
+
del kwargs["model"]
|
562 |
+
model.eval()
|
563 |
+
|
564 |
+
type = kwargs.get("type", "onnx")
|
565 |
+
|
566 |
+
key_list, data_list = prepare_data_iterator(
|
567 |
+
input, input_len=None, data_type=kwargs.get("data_type", None), key=None
|
568 |
+
)
|
569 |
+
|
570 |
+
with torch.no_grad():
|
571 |
+
export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
|
572 |
+
|
573 |
+
return export_dir
|
tools/sensevoice/fun_asr.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from audio_separator.separator import Separator
|
6 |
+
|
7 |
+
os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
|
8 |
+
os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
|
9 |
+
import json
|
10 |
+
import subprocess
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import click
|
14 |
+
import torch
|
15 |
+
from loguru import logger
|
16 |
+
from pydub import AudioSegment
|
17 |
+
from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
|
21 |
+
from tools.sensevoice.auto_model import AutoModel
|
22 |
+
|
23 |
+
|
24 |
+
def uvr5_cli(
|
25 |
+
audio_dir: Path,
|
26 |
+
output_folder: Path,
|
27 |
+
audio_files: list[Path] | None = None,
|
28 |
+
output_format: str = "flac",
|
29 |
+
model: str = "BS-Roformer-Viperx-1297.ckpt",
|
30 |
+
):
|
31 |
+
# ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
|
32 |
+
sepr = Separator(
|
33 |
+
model_file_dir=os.environ["UVR5_CACHE"],
|
34 |
+
output_dir=output_folder,
|
35 |
+
output_format=output_format,
|
36 |
+
)
|
37 |
+
dictmodel = {
|
38 |
+
"BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
|
39 |
+
"BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
|
40 |
+
"BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
|
41 |
+
"Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
|
42 |
+
}
|
43 |
+
roformer_model = dictmodel[model]
|
44 |
+
sepr.load_model(roformer_model)
|
45 |
+
if audio_files is None:
|
46 |
+
audio_files = list_files(
|
47 |
+
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
48 |
+
)
|
49 |
+
total_files = len(audio_files)
|
50 |
+
|
51 |
+
print(f"{total_files} audio files found")
|
52 |
+
|
53 |
+
res = []
|
54 |
+
for audio in tqdm(audio_files, desc="Denoising: "):
|
55 |
+
file_path = str(audio_dir / audio)
|
56 |
+
sep_out = sepr.separate(file_path)
|
57 |
+
if isinstance(sep_out, str):
|
58 |
+
res.append(sep_out)
|
59 |
+
elif isinstance(sep_out, list):
|
60 |
+
res.extend(sep_out)
|
61 |
+
del sepr
|
62 |
+
gc.collect()
|
63 |
+
if torch.cuda.is_available():
|
64 |
+
torch.cuda.empty_cache()
|
65 |
+
|
66 |
+
return res, roformer_model
|
67 |
+
|
68 |
+
|
69 |
+
def get_sample_rate(media_path: Path):
|
70 |
+
result = subprocess.run(
|
71 |
+
[
|
72 |
+
"ffprobe",
|
73 |
+
"-v",
|
74 |
+
"quiet",
|
75 |
+
"-print_format",
|
76 |
+
"json",
|
77 |
+
"-show_streams",
|
78 |
+
str(media_path),
|
79 |
+
],
|
80 |
+
capture_output=True,
|
81 |
+
text=True,
|
82 |
+
check=True,
|
83 |
+
)
|
84 |
+
media_info = json.loads(result.stdout)
|
85 |
+
for stream in media_info.get("streams", []):
|
86 |
+
if stream.get("codec_type") == "audio":
|
87 |
+
return stream.get("sample_rate")
|
88 |
+
return "44100" # Default sample rate if not found
|
89 |
+
|
90 |
+
|
91 |
+
def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
|
92 |
+
sr = get_sample_rate(src_path)
|
93 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
94 |
+
if src_path.resolve() == out_path.resolve():
|
95 |
+
output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
|
96 |
+
else:
|
97 |
+
output = str(out_path)
|
98 |
+
subprocess.run(
|
99 |
+
[
|
100 |
+
"ffmpeg",
|
101 |
+
"-loglevel",
|
102 |
+
"error",
|
103 |
+
"-i",
|
104 |
+
str(src_path),
|
105 |
+
"-acodec",
|
106 |
+
"pcm_s16le" if out_fmt == "wav" else "flac",
|
107 |
+
"-ar",
|
108 |
+
sr,
|
109 |
+
"-ac",
|
110 |
+
"1",
|
111 |
+
"-y",
|
112 |
+
output,
|
113 |
+
],
|
114 |
+
check=True,
|
115 |
+
)
|
116 |
+
return out_path
|
117 |
+
|
118 |
+
|
119 |
+
def convert_video_to_audio(video_path: Path, audio_dir: Path):
|
120 |
+
cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
|
121 |
+
vocals = [
|
122 |
+
p
|
123 |
+
for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
|
124 |
+
if p.suffix in AUDIO_EXTENSIONS
|
125 |
+
]
|
126 |
+
if len(vocals) > 0:
|
127 |
+
return vocals[0]
|
128 |
+
audio_path = cur_dir / f"{video_path.stem}.wav"
|
129 |
+
convert_to_mono(video_path, audio_path)
|
130 |
+
return audio_path
|
131 |
+
|
132 |
+
|
133 |
+
@click.command()
|
134 |
+
@click.option("--audio-dir", required=True, help="Directory containing audio files")
|
135 |
+
@click.option(
|
136 |
+
"--save-dir", required=True, help="Directory to save processed audio files"
|
137 |
+
)
|
138 |
+
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
|
139 |
+
@click.option("--language", default="auto", help="Language of the transcription")
|
140 |
+
@click.option(
|
141 |
+
"--max_single_segment_time",
|
142 |
+
default=20000,
|
143 |
+
type=int,
|
144 |
+
help="Maximum of Output single audio duration(ms)",
|
145 |
+
)
|
146 |
+
@click.option("--fsmn-vad/--silero-vad", default=False)
|
147 |
+
@click.option("--punc/--no-punc", default=False)
|
148 |
+
@click.option("--denoise/--no-denoise", default=False)
|
149 |
+
@click.option("--save_emo/--no_save_emo", default=False)
|
150 |
+
def main(
|
151 |
+
audio_dir: str,
|
152 |
+
save_dir: str,
|
153 |
+
device: str,
|
154 |
+
language: str,
|
155 |
+
max_single_segment_time: int,
|
156 |
+
fsmn_vad: bool,
|
157 |
+
punc: bool,
|
158 |
+
denoise: bool,
|
159 |
+
save_emo: bool,
|
160 |
+
):
|
161 |
+
|
162 |
+
audios_path = Path(audio_dir)
|
163 |
+
save_path = Path(save_dir)
|
164 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
165 |
+
|
166 |
+
video_files = list_files(
|
167 |
+
path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
|
168 |
+
)
|
169 |
+
v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
|
170 |
+
|
171 |
+
if denoise:
|
172 |
+
VOCAL = "_(Vocals)"
|
173 |
+
original_files = [
|
174 |
+
p
|
175 |
+
for p in audios_path.glob("**/*")
|
176 |
+
if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
|
177 |
+
]
|
178 |
+
|
179 |
+
_, cur_model = uvr5_cli(
|
180 |
+
audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
|
181 |
+
)
|
182 |
+
need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
|
183 |
+
need_remove.extend(original_files)
|
184 |
+
for _ in need_remove:
|
185 |
+
_.unlink()
|
186 |
+
vocal_files = [
|
187 |
+
p
|
188 |
+
for p in audios_path.glob("**/*")
|
189 |
+
if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
|
190 |
+
]
|
191 |
+
for f in vocal_files:
|
192 |
+
fn, ext = f.stem, f.suffix
|
193 |
+
|
194 |
+
v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
|
195 |
+
if v_pos != -1:
|
196 |
+
new_fn = fn[: v_pos + len(VOCAL)]
|
197 |
+
new_f = f.with_name(new_fn + ext)
|
198 |
+
f = f.rename(new_f)
|
199 |
+
convert_to_mono(f, f, "flac")
|
200 |
+
f.unlink()
|
201 |
+
|
202 |
+
audio_files = list_files(
|
203 |
+
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
204 |
+
)
|
205 |
+
|
206 |
+
logger.info("Loading / Downloading Funasr model...")
|
207 |
+
|
208 |
+
model_dir = "iic/SenseVoiceSmall"
|
209 |
+
|
210 |
+
vad_model = "fsmn-vad" if fsmn_vad else None
|
211 |
+
vad_kwargs = {"max_single_segment_time": max_single_segment_time}
|
212 |
+
punc_model = "ct-punc" if punc else None
|
213 |
+
|
214 |
+
manager = AutoModel(
|
215 |
+
model=model_dir,
|
216 |
+
trust_remote_code=False,
|
217 |
+
vad_model=vad_model,
|
218 |
+
vad_kwargs=vad_kwargs,
|
219 |
+
punc_model=punc_model,
|
220 |
+
device=device,
|
221 |
+
)
|
222 |
+
|
223 |
+
if not fsmn_vad and vad_model is None:
|
224 |
+
vad_model = load_silero_vad()
|
225 |
+
|
226 |
+
logger.info("Model loaded.")
|
227 |
+
|
228 |
+
pattern = re.compile(r"_\d{3}\.")
|
229 |
+
|
230 |
+
for file_path in tqdm(audio_files, desc="Processing audio file"):
|
231 |
+
|
232 |
+
if pattern.search(file_path.name):
|
233 |
+
# logger.info(f"Skipping {file_path} as it has already been processed.")
|
234 |
+
continue
|
235 |
+
|
236 |
+
file_stem = file_path.stem
|
237 |
+
file_suffix = file_path.suffix
|
238 |
+
|
239 |
+
rel_path = Path(file_path).relative_to(audio_dir)
|
240 |
+
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
241 |
+
|
242 |
+
audio = AudioSegment.from_file(file_path)
|
243 |
+
|
244 |
+
cfg = dict(
|
245 |
+
cache={},
|
246 |
+
language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
|
247 |
+
use_itn=False,
|
248 |
+
batch_size_s=60,
|
249 |
+
)
|
250 |
+
|
251 |
+
if fsmn_vad:
|
252 |
+
elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
|
253 |
+
else:
|
254 |
+
wav = read_audio(
|
255 |
+
str(file_path)
|
256 |
+
) # backend (sox, soundfile, or ffmpeg) required!
|
257 |
+
audio_key = file_path.stem
|
258 |
+
audio_val = []
|
259 |
+
speech_timestamps = get_speech_timestamps(
|
260 |
+
wav,
|
261 |
+
vad_model,
|
262 |
+
max_speech_duration_s=max_single_segment_time // 1000,
|
263 |
+
return_seconds=True,
|
264 |
+
)
|
265 |
+
|
266 |
+
audio_val = [
|
267 |
+
[int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
|
268 |
+
for timestamp in speech_timestamps
|
269 |
+
]
|
270 |
+
vad_res = []
|
271 |
+
vad_res.append(dict(key=audio_key, value=audio_val))
|
272 |
+
|
273 |
+
res = manager.inference_with_vadres(
|
274 |
+
input=str(file_path), vad_res=vad_res, **cfg
|
275 |
+
)
|
276 |
+
|
277 |
+
for i, info in enumerate(res):
|
278 |
+
[start_ms, end_ms] = info["interval"]
|
279 |
+
text = info["text"]
|
280 |
+
emo = info["emo"]
|
281 |
+
sliced_audio = audio[start_ms:end_ms]
|
282 |
+
audio_save_path = (
|
283 |
+
save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
|
284 |
+
)
|
285 |
+
sliced_audio.export(audio_save_path, format=file_suffix[1:])
|
286 |
+
print(f"Exported {audio_save_path}: {text}")
|
287 |
+
|
288 |
+
transcript_save_path = (
|
289 |
+
save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
|
290 |
+
)
|
291 |
+
with open(
|
292 |
+
transcript_save_path,
|
293 |
+
"w",
|
294 |
+
encoding="utf-8",
|
295 |
+
) as f:
|
296 |
+
f.write(text)
|
297 |
+
|
298 |
+
if save_emo:
|
299 |
+
emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
|
300 |
+
with open(
|
301 |
+
emo_save_path,
|
302 |
+
"w",
|
303 |
+
encoding="utf-8",
|
304 |
+
) as f:
|
305 |
+
f.write(emo)
|
306 |
+
|
307 |
+
if audios_path.resolve() == save_path.resolve():
|
308 |
+
file_path.unlink()
|
309 |
+
|
310 |
+
|
311 |
+
if __name__ == "__main__":
|
312 |
+
main()
|
313 |
+
exit(0)
|
314 |
+
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
315 |
+
|
316 |
+
# Load the audio file
|
317 |
+
audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
|
318 |
+
model_dir = "iic/SenseVoiceSmall"
|
319 |
+
m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
|
320 |
+
m.eval()
|
321 |
+
|
322 |
+
res = m.inference(
|
323 |
+
data_in=f"{kwargs['model_path']}/example/zh.mp3",
|
324 |
+
language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
|
325 |
+
use_itn=False,
|
326 |
+
ban_emo_unk=False,
|
327 |
+
**kwargs,
|
328 |
+
)
|
329 |
+
|
330 |
+
print(res)
|
331 |
+
text = rich_transcription_postprocess(res[0][0]["text"])
|
332 |
+
print(text)
|
tools/sensevoice/vad_utils.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn.utils.rnn import pad_sequence
|
3 |
+
|
4 |
+
|
5 |
+
def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
6 |
+
speech_list = []
|
7 |
+
speech_lengths_list = []
|
8 |
+
for i, segment in enumerate(vad_segments):
|
9 |
+
|
10 |
+
bed_idx = int(segment[0][0] * 16)
|
11 |
+
end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
|
12 |
+
speech_i = speech[0, bed_idx:end_idx]
|
13 |
+
speech_lengths_i = end_idx - bed_idx
|
14 |
+
speech_list.append(speech_i)
|
15 |
+
speech_lengths_list.append(speech_lengths_i)
|
16 |
+
feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
|
17 |
+
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
|
18 |
+
return feats_pad, speech_lengths_pad
|
19 |
+
|
20 |
+
|
21 |
+
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
|
22 |
+
speech_list = []
|
23 |
+
speech_lengths_list = []
|
24 |
+
intervals = []
|
25 |
+
for i, segment in enumerate(vad_segments):
|
26 |
+
bed_idx = int(segment[0][0] * 16)
|
27 |
+
end_idx = min(int(segment[0][1] * 16), speech_lengths)
|
28 |
+
speech_i = speech[bed_idx:end_idx]
|
29 |
+
speech_lengths_i = end_idx - bed_idx
|
30 |
+
speech_list.append(speech_i)
|
31 |
+
speech_lengths_list.append(speech_lengths_i)
|
32 |
+
intervals.append([bed_idx // 16, end_idx // 16])
|
33 |
+
|
34 |
+
return speech_list, speech_lengths_list, intervals
|
35 |
+
|
36 |
+
|
37 |
+
def merge_vad(vad_result, max_length=15000, min_length=0):
|
38 |
+
new_result = []
|
39 |
+
if len(vad_result) <= 1:
|
40 |
+
return vad_result
|
41 |
+
time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
|
42 |
+
time_step = sorted(list(set(time_step)))
|
43 |
+
if len(time_step) == 0:
|
44 |
+
return []
|
45 |
+
bg = 0
|
46 |
+
for i in range(len(time_step) - 1):
|
47 |
+
time = time_step[i]
|
48 |
+
if time_step[i + 1] - bg < max_length:
|
49 |
+
continue
|
50 |
+
if time - bg > min_length:
|
51 |
+
new_result.append([bg, time])
|
52 |
+
# if time - bg < max_length * 1.5:
|
53 |
+
# new_result.append([bg, time])
|
54 |
+
# else:
|
55 |
+
# split_num = int(time - bg) // max_length + 1
|
56 |
+
# spl_l = int(time - bg) // split_num
|
57 |
+
# for j in range(split_num):
|
58 |
+
# new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
|
59 |
+
bg = time
|
60 |
+
new_result.append([bg, time_step[-1]])
|
61 |
+
return new_result
|
tools/smart_pad.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from multiprocessing import Pool
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import click
|
6 |
+
import librosa
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from tools.file import AUDIO_EXTENSIONS, list_files
|
12 |
+
|
13 |
+
threshold = 10 ** (-50 / 20.0)
|
14 |
+
|
15 |
+
|
16 |
+
def process(file):
|
17 |
+
waveform, sample_rate = torchaudio.load(str(file), backend="sox")
|
18 |
+
loudness = librosa.feature.rms(
|
19 |
+
y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
|
20 |
+
)[0]
|
21 |
+
for i in range(len(loudness) - 1, 0, -1):
|
22 |
+
if loudness[i] > threshold:
|
23 |
+
break
|
24 |
+
|
25 |
+
silent_time = (len(loudness) - i) * 512 / sample_rate
|
26 |
+
|
27 |
+
if silent_time <= 0.3:
|
28 |
+
random_time = random.uniform(0.3, 0.7)
|
29 |
+
waveform = F.pad(
|
30 |
+
waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
|
31 |
+
)
|
32 |
+
|
33 |
+
torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
|
34 |
+
|
35 |
+
|
36 |
+
@click.command()
|
37 |
+
@click.argument("source", type=Path)
|
38 |
+
@click.option("--num-workers", type=int, default=12)
|
39 |
+
def main(source, num_workers):
|
40 |
+
files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
|
41 |
+
|
42 |
+
with Pool(num_workers) as p:
|
43 |
+
list(tqdm(p.imap_unordered(process, files), total=len(files)))
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
main()
|
tools/vqgan/create_train_split.py
CHANGED
@@ -7,7 +7,7 @@ from loguru import logger
|
|
7 |
from pydub import AudioSegment
|
8 |
from tqdm import tqdm
|
9 |
|
10 |
-
from
|
11 |
|
12 |
|
13 |
@click.command()
|
|
|
7 |
from pydub import AudioSegment
|
8 |
from tqdm import tqdm
|
9 |
|
10 |
+
from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
11 |
|
12 |
|
13 |
@click.command()
|
tools/vqgan/extract_vq.py
CHANGED
@@ -17,7 +17,7 @@ from lightning import LightningModule
|
|
17 |
from loguru import logger
|
18 |
from omegaconf import OmegaConf
|
19 |
|
20 |
-
from
|
21 |
|
22 |
# register eval resolver
|
23 |
OmegaConf.register_new_resolver("eval", eval)
|
@@ -42,7 +42,7 @@ logger.add(sys.stderr, format=logger_format)
|
|
42 |
@lru_cache(maxsize=1)
|
43 |
def get_model(
|
44 |
config_name: str = "firefly_gan_vq",
|
45 |
-
checkpoint_path: str = "checkpoints/fish-speech-1.
|
46 |
device: str | torch.device = "cuda",
|
47 |
):
|
48 |
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
@@ -133,7 +133,7 @@ def process_batch(files: list[Path], model) -> float:
|
|
133 |
@click.option("--config-name", default="firefly_gan_vq")
|
134 |
@click.option(
|
135 |
"--checkpoint-path",
|
136 |
-
default="checkpoints/fish-speech-1.
|
137 |
)
|
138 |
@click.option("--batch-size", default=64)
|
139 |
@click.option("--filelist", default=None, type=Path)
|
|
|
17 |
from loguru import logger
|
18 |
from omegaconf import OmegaConf
|
19 |
|
20 |
+
from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
21 |
|
22 |
# register eval resolver
|
23 |
OmegaConf.register_new_resolver("eval", eval)
|
|
|
42 |
@lru_cache(maxsize=1)
|
43 |
def get_model(
|
44 |
config_name: str = "firefly_gan_vq",
|
45 |
+
checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
46 |
device: str | torch.device = "cuda",
|
47 |
):
|
48 |
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
|
|
133 |
@click.option("--config-name", default="firefly_gan_vq")
|
134 |
@click.option(
|
135 |
"--checkpoint-path",
|
136 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
137 |
)
|
138 |
@click.option("--batch-size", default=64)
|
139 |
@click.option("--filelist", default=None, type=Path)
|
tools/vqgan/inference.py
CHANGED
@@ -11,7 +11,7 @@ from hydra.utils import instantiate
|
|
11 |
from loguru import logger
|
12 |
from omegaconf import OmegaConf
|
13 |
|
14 |
-
from
|
15 |
|
16 |
# register eval resolver
|
17 |
OmegaConf.register_new_resolver("eval", eval)
|
@@ -59,7 +59,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
|
|
59 |
@click.option("--config-name", default="firefly_gan_vq")
|
60 |
@click.option(
|
61 |
"--checkpoint-path",
|
62 |
-
default="checkpoints/fish-speech-1.
|
63 |
)
|
64 |
@click.option(
|
65 |
"--device",
|
@@ -103,7 +103,9 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
|
|
103 |
|
104 |
# Restore
|
105 |
feature_lengths = torch.tensor([indices.shape[1]], device=device)
|
106 |
-
fake_audios = model.decode(
|
|
|
|
|
107 |
audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
|
108 |
|
109 |
logger.info(
|
|
|
11 |
from loguru import logger
|
12 |
from omegaconf import OmegaConf
|
13 |
|
14 |
+
from tools.file import AUDIO_EXTENSIONS
|
15 |
|
16 |
# register eval resolver
|
17 |
OmegaConf.register_new_resolver("eval", eval)
|
|
|
59 |
@click.option("--config-name", default="firefly_gan_vq")
|
60 |
@click.option(
|
61 |
"--checkpoint-path",
|
62 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
63 |
)
|
64 |
@click.option(
|
65 |
"--device",
|
|
|
103 |
|
104 |
# Restore
|
105 |
feature_lengths = torch.tensor([indices.shape[1]], device=device)
|
106 |
+
fake_audios, _ = model.decode(
|
107 |
+
indices=indices[None], feature_lengths=feature_lengths
|
108 |
+
)
|
109 |
audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
|
110 |
|
111 |
logger.info(
|
tools/webui.py
ADDED
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import html
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import queue
|
6 |
+
import wave
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from functools import partial
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
import librosa
|
13 |
+
import numpy as np
|
14 |
+
import pyrootutils
|
15 |
+
import torch
|
16 |
+
from loguru import logger
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
|
19 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
20 |
+
|
21 |
+
|
22 |
+
from fish_speech.i18n import i18n
|
23 |
+
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
24 |
+
from fish_speech.utils import autocast_exclude_mps
|
25 |
+
from tools.api import decode_vq_tokens, encode_reference
|
26 |
+
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
|
27 |
+
from tools.llama.generate import (
|
28 |
+
GenerateRequest,
|
29 |
+
GenerateResponse,
|
30 |
+
WrappedGenerateResponse,
|
31 |
+
launch_thread_safe_queue,
|
32 |
+
)
|
33 |
+
from tools.vqgan.inference import load_model as load_decoder_model
|
34 |
+
|
35 |
+
# Make einx happy
|
36 |
+
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
37 |
+
|
38 |
+
|
39 |
+
HEADER_MD = f"""# Fish Speech
|
40 |
+
|
41 |
+
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
|
42 |
+
|
43 |
+
{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
|
44 |
+
|
45 |
+
{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
|
46 |
+
|
47 |
+
{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
|
48 |
+
"""
|
49 |
+
|
50 |
+
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
|
51 |
+
SPACE_IMPORTED = False
|
52 |
+
|
53 |
+
|
54 |
+
def build_html_error_message(error):
|
55 |
+
return f"""
|
56 |
+
<div style="color: red;
|
57 |
+
font-weight: bold;">
|
58 |
+
{html.escape(str(error))}
|
59 |
+
</div>
|
60 |
+
"""
|
61 |
+
|
62 |
+
|
63 |
+
@torch.inference_mode()
|
64 |
+
def inference(
|
65 |
+
text,
|
66 |
+
enable_reference_audio,
|
67 |
+
reference_audio,
|
68 |
+
reference_text,
|
69 |
+
max_new_tokens,
|
70 |
+
chunk_length,
|
71 |
+
top_p,
|
72 |
+
repetition_penalty,
|
73 |
+
temperature,
|
74 |
+
streaming=False,
|
75 |
+
):
|
76 |
+
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
77 |
+
return (
|
78 |
+
None,
|
79 |
+
None,
|
80 |
+
i18n("Text is too long, please keep it under {} characters.").format(
|
81 |
+
args.max_gradio_length
|
82 |
+
),
|
83 |
+
)
|
84 |
+
|
85 |
+
# Parse reference audio aka prompt
|
86 |
+
prompt_tokens = encode_reference(
|
87 |
+
decoder_model=decoder_model,
|
88 |
+
reference_audio=reference_audio,
|
89 |
+
enable_reference_audio=enable_reference_audio,
|
90 |
+
)
|
91 |
+
|
92 |
+
# LLAMA Inference
|
93 |
+
request = dict(
|
94 |
+
device=decoder_model.device,
|
95 |
+
max_new_tokens=max_new_tokens,
|
96 |
+
text=text,
|
97 |
+
top_p=top_p,
|
98 |
+
repetition_penalty=repetition_penalty,
|
99 |
+
temperature=temperature,
|
100 |
+
compile=args.compile,
|
101 |
+
iterative_prompt=chunk_length > 0,
|
102 |
+
chunk_length=chunk_length,
|
103 |
+
max_length=2048,
|
104 |
+
prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
105 |
+
prompt_text=reference_text if enable_reference_audio else None,
|
106 |
+
)
|
107 |
+
|
108 |
+
response_queue = queue.Queue()
|
109 |
+
llama_queue.put(
|
110 |
+
GenerateRequest(
|
111 |
+
request=request,
|
112 |
+
response_queue=response_queue,
|
113 |
+
)
|
114 |
+
)
|
115 |
+
|
116 |
+
if streaming:
|
117 |
+
yield wav_chunk_header(), None, None
|
118 |
+
|
119 |
+
segments = []
|
120 |
+
|
121 |
+
while True:
|
122 |
+
result: WrappedGenerateResponse = response_queue.get()
|
123 |
+
if result.status == "error":
|
124 |
+
yield None, None, build_html_error_message(result.response)
|
125 |
+
break
|
126 |
+
|
127 |
+
result: GenerateResponse = result.response
|
128 |
+
if result.action == "next":
|
129 |
+
break
|
130 |
+
|
131 |
+
with autocast_exclude_mps(
|
132 |
+
device_type=decoder_model.device.type, dtype=args.precision
|
133 |
+
):
|
134 |
+
fake_audios = decode_vq_tokens(
|
135 |
+
decoder_model=decoder_model,
|
136 |
+
codes=result.codes,
|
137 |
+
)
|
138 |
+
|
139 |
+
fake_audios = fake_audios.float().cpu().numpy()
|
140 |
+
segments.append(fake_audios)
|
141 |
+
|
142 |
+
if streaming:
|
143 |
+
yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
|
144 |
+
|
145 |
+
if len(segments) == 0:
|
146 |
+
return (
|
147 |
+
None,
|
148 |
+
None,
|
149 |
+
build_html_error_message(
|
150 |
+
i18n("No audio generated, please check the input text.")
|
151 |
+
),
|
152 |
+
)
|
153 |
+
|
154 |
+
# No matter streaming or not, we need to return the final audio
|
155 |
+
audio = np.concatenate(segments, axis=0)
|
156 |
+
yield None, (decoder_model.spec_transform.sample_rate, audio), None
|
157 |
+
|
158 |
+
if torch.cuda.is_available():
|
159 |
+
torch.cuda.empty_cache()
|
160 |
+
gc.collect()
|
161 |
+
|
162 |
+
|
163 |
+
def inference_with_auto_rerank(
|
164 |
+
text,
|
165 |
+
enable_reference_audio,
|
166 |
+
reference_audio,
|
167 |
+
reference_text,
|
168 |
+
max_new_tokens,
|
169 |
+
chunk_length,
|
170 |
+
top_p,
|
171 |
+
repetition_penalty,
|
172 |
+
temperature,
|
173 |
+
use_auto_rerank,
|
174 |
+
streaming=False,
|
175 |
+
):
|
176 |
+
|
177 |
+
max_attempts = 2 if use_auto_rerank else 1
|
178 |
+
best_wer = float("inf")
|
179 |
+
best_audio = None
|
180 |
+
best_sample_rate = None
|
181 |
+
|
182 |
+
for attempt in range(max_attempts):
|
183 |
+
audio_generator = inference(
|
184 |
+
text,
|
185 |
+
enable_reference_audio,
|
186 |
+
reference_audio,
|
187 |
+
reference_text,
|
188 |
+
max_new_tokens,
|
189 |
+
chunk_length,
|
190 |
+
top_p,
|
191 |
+
repetition_penalty,
|
192 |
+
temperature,
|
193 |
+
streaming=False,
|
194 |
+
)
|
195 |
+
|
196 |
+
# 获取音频数据
|
197 |
+
for _ in audio_generator:
|
198 |
+
pass
|
199 |
+
_, (sample_rate, audio), message = _
|
200 |
+
|
201 |
+
if audio is None:
|
202 |
+
return None, None, message
|
203 |
+
|
204 |
+
if not use_auto_rerank:
|
205 |
+
return None, (sample_rate, audio), None
|
206 |
+
|
207 |
+
asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
|
208 |
+
wer = calculate_wer(text, asr_result["text"])
|
209 |
+
if wer <= 0.3 and not asr_result["huge_gap"]:
|
210 |
+
return None, (sample_rate, audio), None
|
211 |
+
|
212 |
+
if wer < best_wer:
|
213 |
+
best_wer = wer
|
214 |
+
best_audio = audio
|
215 |
+
best_sample_rate = sample_rate
|
216 |
+
|
217 |
+
if attempt == max_attempts - 1:
|
218 |
+
break
|
219 |
+
|
220 |
+
return None, (best_sample_rate, best_audio), None
|
221 |
+
|
222 |
+
|
223 |
+
inference_stream = partial(inference, streaming=True)
|
224 |
+
|
225 |
+
n_audios = 4
|
226 |
+
|
227 |
+
global_audio_list = []
|
228 |
+
global_error_list = []
|
229 |
+
|
230 |
+
|
231 |
+
def inference_wrapper(
|
232 |
+
text,
|
233 |
+
enable_reference_audio,
|
234 |
+
reference_audio,
|
235 |
+
reference_text,
|
236 |
+
max_new_tokens,
|
237 |
+
chunk_length,
|
238 |
+
top_p,
|
239 |
+
repetition_penalty,
|
240 |
+
temperature,
|
241 |
+
batch_infer_num,
|
242 |
+
if_load_asr_model,
|
243 |
+
):
|
244 |
+
audios = []
|
245 |
+
errors = []
|
246 |
+
|
247 |
+
for _ in range(batch_infer_num):
|
248 |
+
result = inference_with_auto_rerank(
|
249 |
+
text,
|
250 |
+
enable_reference_audio,
|
251 |
+
reference_audio,
|
252 |
+
reference_text,
|
253 |
+
max_new_tokens,
|
254 |
+
chunk_length,
|
255 |
+
top_p,
|
256 |
+
repetition_penalty,
|
257 |
+
temperature,
|
258 |
+
if_load_asr_model,
|
259 |
+
)
|
260 |
+
|
261 |
+
_, audio_data, error_message = result
|
262 |
+
|
263 |
+
audios.append(
|
264 |
+
gr.Audio(value=audio_data if audio_data else None, visible=True),
|
265 |
+
)
|
266 |
+
errors.append(
|
267 |
+
gr.HTML(value=error_message if error_message else None, visible=True),
|
268 |
+
)
|
269 |
+
|
270 |
+
for _ in range(batch_infer_num, n_audios):
|
271 |
+
audios.append(
|
272 |
+
gr.Audio(value=None, visible=False),
|
273 |
+
)
|
274 |
+
errors.append(
|
275 |
+
gr.HTML(value=None, visible=False),
|
276 |
+
)
|
277 |
+
|
278 |
+
return None, *audios, *errors
|
279 |
+
|
280 |
+
|
281 |
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
282 |
+
buffer = io.BytesIO()
|
283 |
+
|
284 |
+
with wave.open(buffer, "wb") as wav_file:
|
285 |
+
wav_file.setnchannels(channels)
|
286 |
+
wav_file.setsampwidth(bit_depth // 8)
|
287 |
+
wav_file.setframerate(sample_rate)
|
288 |
+
|
289 |
+
wav_header_bytes = buffer.getvalue()
|
290 |
+
buffer.close()
|
291 |
+
return wav_header_bytes
|
292 |
+
|
293 |
+
|
294 |
+
def normalize_text(user_input, use_normalization):
|
295 |
+
if use_normalization:
|
296 |
+
return ChnNormedText(raw_text=user_input).normalize()
|
297 |
+
else:
|
298 |
+
return user_input
|
299 |
+
|
300 |
+
|
301 |
+
asr_model = None
|
302 |
+
|
303 |
+
|
304 |
+
def change_if_load_asr_model(if_load):
|
305 |
+
global asr_model
|
306 |
+
|
307 |
+
if if_load:
|
308 |
+
gr.Warning("Loading faster whisper model...")
|
309 |
+
if asr_model is None:
|
310 |
+
asr_model = load_model()
|
311 |
+
return gr.Checkbox(label="Unload faster whisper model", value=if_load)
|
312 |
+
|
313 |
+
if if_load is False:
|
314 |
+
gr.Warning("Unloading faster whisper model...")
|
315 |
+
del asr_model
|
316 |
+
asr_model = None
|
317 |
+
if torch.cuda.is_available():
|
318 |
+
torch.cuda.empty_cache()
|
319 |
+
gc.collect()
|
320 |
+
return gr.Checkbox(label="Load faster whisper model", value=if_load)
|
321 |
+
|
322 |
+
|
323 |
+
def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
|
324 |
+
if if_load and asr_model is not None:
|
325 |
+
if (
|
326 |
+
if_auto_label
|
327 |
+
and enable_ref
|
328 |
+
and ref_audio is not None
|
329 |
+
and ref_text.strip() == ""
|
330 |
+
):
|
331 |
+
data, sample_rate = librosa.load(ref_audio)
|
332 |
+
res = batch_asr(asr_model, [data], sample_rate)[0]
|
333 |
+
ref_text = res["text"]
|
334 |
+
else:
|
335 |
+
gr.Warning("Whisper model not loaded!")
|
336 |
+
|
337 |
+
return gr.Textbox(value=ref_text)
|
338 |
+
|
339 |
+
|
340 |
+
def build_app():
|
341 |
+
with gr.Blocks(theme=gr.themes.Base()) as app:
|
342 |
+
gr.Markdown(HEADER_MD)
|
343 |
+
|
344 |
+
# Use light theme by default
|
345 |
+
app.load(
|
346 |
+
None,
|
347 |
+
None,
|
348 |
+
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
|
349 |
+
% args.theme,
|
350 |
+
)
|
351 |
+
|
352 |
+
# Inference
|
353 |
+
with gr.Row():
|
354 |
+
with gr.Column(scale=3):
|
355 |
+
text = gr.Textbox(
|
356 |
+
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
|
357 |
+
)
|
358 |
+
refined_text = gr.Textbox(
|
359 |
+
label=i18n("Realtime Transform Text"),
|
360 |
+
placeholder=i18n(
|
361 |
+
"Normalization Result Preview (Currently Only Chinese)"
|
362 |
+
),
|
363 |
+
lines=5,
|
364 |
+
interactive=False,
|
365 |
+
)
|
366 |
+
|
367 |
+
with gr.Row():
|
368 |
+
if_refine_text = gr.Checkbox(
|
369 |
+
label=i18n("Text Normalization"),
|
370 |
+
value=False,
|
371 |
+
scale=1,
|
372 |
+
)
|
373 |
+
|
374 |
+
if_load_asr_model = gr.Checkbox(
|
375 |
+
label=i18n("Load / Unload ASR model for auto-reranking"),
|
376 |
+
value=False,
|
377 |
+
scale=3,
|
378 |
+
)
|
379 |
+
|
380 |
+
with gr.Row():
|
381 |
+
with gr.Tab(label=i18n("Advanced Config")):
|
382 |
+
chunk_length = gr.Slider(
|
383 |
+
label=i18n("Iterative Prompt Length, 0 means off"),
|
384 |
+
minimum=50,
|
385 |
+
maximum=300,
|
386 |
+
value=200,
|
387 |
+
step=8,
|
388 |
+
)
|
389 |
+
|
390 |
+
max_new_tokens = gr.Slider(
|
391 |
+
label=i18n("Maximum tokens per batch, 0 means no limit"),
|
392 |
+
minimum=0,
|
393 |
+
maximum=2048,
|
394 |
+
value=1024, # 0 means no limit
|
395 |
+
step=8,
|
396 |
+
)
|
397 |
+
|
398 |
+
top_p = gr.Slider(
|
399 |
+
label="Top-P",
|
400 |
+
minimum=0.6,
|
401 |
+
maximum=0.9,
|
402 |
+
value=0.7,
|
403 |
+
step=0.01,
|
404 |
+
)
|
405 |
+
|
406 |
+
repetition_penalty = gr.Slider(
|
407 |
+
label=i18n("Repetition Penalty"),
|
408 |
+
minimum=1,
|
409 |
+
maximum=1.5,
|
410 |
+
value=1.2,
|
411 |
+
step=0.01,
|
412 |
+
)
|
413 |
+
|
414 |
+
temperature = gr.Slider(
|
415 |
+
label="Temperature",
|
416 |
+
minimum=0.6,
|
417 |
+
maximum=0.9,
|
418 |
+
value=0.7,
|
419 |
+
step=0.01,
|
420 |
+
)
|
421 |
+
|
422 |
+
with gr.Tab(label=i18n("Reference Audio")):
|
423 |
+
gr.Markdown(
|
424 |
+
i18n(
|
425 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker."
|
426 |
+
)
|
427 |
+
)
|
428 |
+
|
429 |
+
enable_reference_audio = gr.Checkbox(
|
430 |
+
label=i18n("Enable Reference Audio"),
|
431 |
+
)
|
432 |
+
reference_audio = gr.Audio(
|
433 |
+
label=i18n("Reference Audio"),
|
434 |
+
type="filepath",
|
435 |
+
)
|
436 |
+
with gr.Row():
|
437 |
+
if_auto_label = gr.Checkbox(
|
438 |
+
label=i18n("Auto Labeling"),
|
439 |
+
min_width=100,
|
440 |
+
scale=0,
|
441 |
+
value=False,
|
442 |
+
)
|
443 |
+
reference_text = gr.Textbox(
|
444 |
+
label=i18n("Reference Text"),
|
445 |
+
lines=1,
|
446 |
+
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
447 |
+
value="",
|
448 |
+
)
|
449 |
+
with gr.Tab(label=i18n("Batch Inference")):
|
450 |
+
batch_infer_num = gr.Slider(
|
451 |
+
label="Batch infer nums",
|
452 |
+
minimum=1,
|
453 |
+
maximum=n_audios,
|
454 |
+
step=1,
|
455 |
+
value=1,
|
456 |
+
)
|
457 |
+
|
458 |
+
with gr.Column(scale=3):
|
459 |
+
for _ in range(n_audios):
|
460 |
+
with gr.Row():
|
461 |
+
error = gr.HTML(
|
462 |
+
label=i18n("Error Message"),
|
463 |
+
visible=True if _ == 0 else False,
|
464 |
+
)
|
465 |
+
global_error_list.append(error)
|
466 |
+
with gr.Row():
|
467 |
+
audio = gr.Audio(
|
468 |
+
label=i18n("Generated Audio"),
|
469 |
+
type="numpy",
|
470 |
+
interactive=False,
|
471 |
+
visible=True if _ == 0 else False,
|
472 |
+
)
|
473 |
+
global_audio_list.append(audio)
|
474 |
+
|
475 |
+
with gr.Row():
|
476 |
+
stream_audio = gr.Audio(
|
477 |
+
label=i18n("Streaming Audio"),
|
478 |
+
streaming=True,
|
479 |
+
autoplay=True,
|
480 |
+
interactive=False,
|
481 |
+
show_download_button=True,
|
482 |
+
)
|
483 |
+
with gr.Row():
|
484 |
+
with gr.Column(scale=3):
|
485 |
+
generate = gr.Button(
|
486 |
+
value="\U0001F3A7 " + i18n("Generate"), variant="primary"
|
487 |
+
)
|
488 |
+
generate_stream = gr.Button(
|
489 |
+
value="\U0001F3A7 " + i18n("Streaming Generate"),
|
490 |
+
variant="primary",
|
491 |
+
)
|
492 |
+
|
493 |
+
text.input(
|
494 |
+
fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
|
495 |
+
)
|
496 |
+
|
497 |
+
if_load_asr_model.change(
|
498 |
+
fn=change_if_load_asr_model,
|
499 |
+
inputs=[if_load_asr_model],
|
500 |
+
outputs=[if_load_asr_model],
|
501 |
+
)
|
502 |
+
|
503 |
+
if_auto_label.change(
|
504 |
+
fn=lambda: gr.Textbox(value=""),
|
505 |
+
inputs=[],
|
506 |
+
outputs=[reference_text],
|
507 |
+
).then(
|
508 |
+
fn=change_if_auto_label,
|
509 |
+
inputs=[
|
510 |
+
if_load_asr_model,
|
511 |
+
if_auto_label,
|
512 |
+
enable_reference_audio,
|
513 |
+
reference_audio,
|
514 |
+
reference_text,
|
515 |
+
],
|
516 |
+
outputs=[reference_text],
|
517 |
+
)
|
518 |
+
|
519 |
+
# # Submit
|
520 |
+
generate.click(
|
521 |
+
inference_wrapper,
|
522 |
+
[
|
523 |
+
refined_text,
|
524 |
+
enable_reference_audio,
|
525 |
+
reference_audio,
|
526 |
+
reference_text,
|
527 |
+
max_new_tokens,
|
528 |
+
chunk_length,
|
529 |
+
top_p,
|
530 |
+
repetition_penalty,
|
531 |
+
temperature,
|
532 |
+
batch_infer_num,
|
533 |
+
if_load_asr_model,
|
534 |
+
],
|
535 |
+
[stream_audio, *global_audio_list, *global_error_list],
|
536 |
+
concurrency_limit=1,
|
537 |
+
)
|
538 |
+
|
539 |
+
generate_stream.click(
|
540 |
+
inference_stream,
|
541 |
+
[
|
542 |
+
refined_text,
|
543 |
+
enable_reference_audio,
|
544 |
+
reference_audio,
|
545 |
+
reference_text,
|
546 |
+
max_new_tokens,
|
547 |
+
chunk_length,
|
548 |
+
top_p,
|
549 |
+
repetition_penalty,
|
550 |
+
temperature,
|
551 |
+
],
|
552 |
+
[stream_audio, global_audio_list[0], global_error_list[0]],
|
553 |
+
concurrency_limit=10,
|
554 |
+
)
|
555 |
+
return app
|
556 |
+
|
557 |
+
|
558 |
+
def parse_args():
|
559 |
+
parser = ArgumentParser()
|
560 |
+
parser.add_argument(
|
561 |
+
"--llama-checkpoint-path",
|
562 |
+
type=Path,
|
563 |
+
default="checkpoints/fish-speech-1.4",
|
564 |
+
)
|
565 |
+
parser.add_argument(
|
566 |
+
"--decoder-checkpoint-path",
|
567 |
+
type=Path,
|
568 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
569 |
+
)
|
570 |
+
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
571 |
+
parser.add_argument("--device", type=str, default="cuda")
|
572 |
+
parser.add_argument("--half", action="store_true")
|
573 |
+
parser.add_argument("--compile", action="store_true")
|
574 |
+
parser.add_argument("--max-gradio-length", type=int, default=0)
|
575 |
+
parser.add_argument("--theme", type=str, default="light")
|
576 |
+
|
577 |
+
return parser.parse_args()
|
578 |
+
|
579 |
+
|
580 |
+
if __name__ == "__main__":
|
581 |
+
args = parse_args()
|
582 |
+
args.precision = torch.half if args.half else torch.bfloat16
|
583 |
+
|
584 |
+
logger.info("Loading Llama model...")
|
585 |
+
llama_queue = launch_thread_safe_queue(
|
586 |
+
checkpoint_path=args.llama_checkpoint_path,
|
587 |
+
device=args.device,
|
588 |
+
precision=args.precision,
|
589 |
+
compile=args.compile,
|
590 |
+
)
|
591 |
+
logger.info("Llama model loaded, loading VQ-GAN model...")
|
592 |
+
|
593 |
+
decoder_model = load_decoder_model(
|
594 |
+
config_name=args.decoder_config_name,
|
595 |
+
checkpoint_path=args.decoder_checkpoint_path,
|
596 |
+
device=args.device,
|
597 |
+
)
|
598 |
+
|
599 |
+
logger.info("Decoder model loaded, warming up...")
|
600 |
+
|
601 |
+
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
602 |
+
list(
|
603 |
+
inference(
|
604 |
+
text="Hello, world!",
|
605 |
+
enable_reference_audio=False,
|
606 |
+
reference_audio=None,
|
607 |
+
reference_text="",
|
608 |
+
max_new_tokens=0,
|
609 |
+
chunk_length=100,
|
610 |
+
top_p=0.7,
|
611 |
+
repetition_penalty=1.2,
|
612 |
+
temperature=0.7,
|
613 |
+
)
|
614 |
+
)
|
615 |
+
|
616 |
+
logger.info("Warming up done, launching the web UI...")
|
617 |
+
|
618 |
+
app = build_app()
|
619 |
+
app.launch(show_api=True)
|
tools/whisper_asr.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Used to transcribe all audio files in one folder into another folder.
|
3 |
+
e.g.
|
4 |
+
Directory structure:
|
5 |
+
--pre_data_root
|
6 |
+
----SP_1
|
7 |
+
------01.wav
|
8 |
+
------02.wav
|
9 |
+
------......
|
10 |
+
----SP_2
|
11 |
+
------01.wav
|
12 |
+
------02.wav
|
13 |
+
------......
|
14 |
+
Use
|
15 |
+
python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
|
16 |
+
to transcribe the first speaker.
|
17 |
+
|
18 |
+
Use
|
19 |
+
python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
|
20 |
+
to transcribe the second speaker.
|
21 |
+
|
22 |
+
Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import re
|
26 |
+
from pathlib import Path
|
27 |
+
|
28 |
+
import click
|
29 |
+
import soundfile as sf
|
30 |
+
from faster_whisper import WhisperModel
|
31 |
+
from loguru import logger
|
32 |
+
from pydub import AudioSegment
|
33 |
+
from tqdm import tqdm
|
34 |
+
|
35 |
+
from tools.file import AUDIO_EXTENSIONS, list_files
|
36 |
+
|
37 |
+
|
38 |
+
@click.command()
|
39 |
+
@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
|
40 |
+
@click.option(
|
41 |
+
"--compute-type",
|
42 |
+
default="float16",
|
43 |
+
help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
|
44 |
+
)
|
45 |
+
@click.option("--audio-dir", required=True, help="Directory containing audio files")
|
46 |
+
@click.option(
|
47 |
+
"--save-dir", required=True, help="Directory to save processed audio files"
|
48 |
+
)
|
49 |
+
@click.option(
|
50 |
+
"--sample-rate",
|
51 |
+
default=44100,
|
52 |
+
type=int,
|
53 |
+
help="Output sample rate, default to input sample rate",
|
54 |
+
)
|
55 |
+
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
|
56 |
+
@click.option("--language", default="auto", help="Language of the transcription")
|
57 |
+
@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
|
58 |
+
def main(
|
59 |
+
model_size,
|
60 |
+
compute_type,
|
61 |
+
audio_dir,
|
62 |
+
save_dir,
|
63 |
+
sample_rate,
|
64 |
+
device,
|
65 |
+
language,
|
66 |
+
initial_prompt,
|
67 |
+
):
|
68 |
+
logger.info("Loading / Downloading Faster Whisper model...")
|
69 |
+
|
70 |
+
model = WhisperModel(
|
71 |
+
model_size,
|
72 |
+
device=device,
|
73 |
+
compute_type=compute_type,
|
74 |
+
download_root="faster_whisper",
|
75 |
+
)
|
76 |
+
|
77 |
+
logger.info("Model loaded.")
|
78 |
+
|
79 |
+
save_path = Path(save_dir)
|
80 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
81 |
+
|
82 |
+
audio_files = list_files(
|
83 |
+
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
84 |
+
)
|
85 |
+
|
86 |
+
for file_path in tqdm(audio_files, desc="Processing audio file"):
|
87 |
+
file_stem = file_path.stem
|
88 |
+
file_suffix = file_path.suffix
|
89 |
+
|
90 |
+
rel_path = Path(file_path).relative_to(audio_dir)
|
91 |
+
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
92 |
+
|
93 |
+
audio = AudioSegment.from_file(file_path)
|
94 |
+
|
95 |
+
segments, info = model.transcribe(
|
96 |
+
file_path,
|
97 |
+
beam_size=5,
|
98 |
+
language=None if language == "auto" else language,
|
99 |
+
initial_prompt=initial_prompt,
|
100 |
+
)
|
101 |
+
|
102 |
+
print(
|
103 |
+
"Detected language '%s' with probability %f"
|
104 |
+
% (info.language, info.language_probability)
|
105 |
+
)
|
106 |
+
print("Total len(ms): ", len(audio))
|
107 |
+
|
108 |
+
whole_text = None
|
109 |
+
for segment in segments:
|
110 |
+
id, start, end, text = (
|
111 |
+
segment.id,
|
112 |
+
segment.start,
|
113 |
+
segment.end,
|
114 |
+
segment.text,
|
115 |
+
)
|
116 |
+
print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
|
117 |
+
if not whole_text:
|
118 |
+
whole_text = text
|
119 |
+
else:
|
120 |
+
whole_text += ", " + text
|
121 |
+
|
122 |
+
whole_text += "."
|
123 |
+
|
124 |
+
audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
|
125 |
+
audio.export(audio_save_path, format=file_suffix[1:])
|
126 |
+
print(f"Exported {audio_save_path}")
|
127 |
+
|
128 |
+
transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
|
129 |
+
with open(
|
130 |
+
transcript_save_path,
|
131 |
+
"w",
|
132 |
+
encoding="utf-8",
|
133 |
+
) as f:
|
134 |
+
f.write(whole_text)
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
main()
|
139 |
+
exit(0)
|
140 |
+
|
141 |
+
audio = AudioSegment.from_wav(
|
142 |
+
r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
|
143 |
+
)
|
144 |
+
|
145 |
+
model_size = "large-v3"
|
146 |
+
|
147 |
+
model = WhisperModel(
|
148 |
+
model_size,
|
149 |
+
device="cuda",
|
150 |
+
compute_type="float16",
|
151 |
+
download_root="faster_whisper",
|
152 |
+
)
|
153 |
+
|
154 |
+
segments, info = model.transcribe(
|
155 |
+
r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
|
156 |
+
beam_size=5,
|
157 |
+
)
|
158 |
+
|
159 |
+
print(
|
160 |
+
"Detected language '%s' with probability %f"
|
161 |
+
% (info.language, info.language_probability)
|
162 |
+
)
|
163 |
+
print("Total len(ms): ", len(audio))
|
164 |
+
|
165 |
+
for i, segment in enumerate(segments):
|
166 |
+
print(
|
167 |
+
"Segment %03d [%.2fs -> %.2fs] %s"
|
168 |
+
% (i, segment.start, segment.end, segment.text)
|
169 |
+
)
|
170 |
+
start_ms = int(segment.start * 1000)
|
171 |
+
end_ms = int(segment.end * 1000)
|
172 |
+
segment_audio = audio[start_ms:end_ms]
|
173 |
+
segment_audio.export(f"segment_{i:03d}.wav", format="wav")
|
174 |
+
print(f"Exported segment_{i:03d}.wav")
|
175 |
+
|
176 |
+
print("All segments have been exported.")
|