KevinGeng committed
Commit 1ead8e8
1 Parent(s): fabced5
.gitignore ADDED
@@ -0,0 +1,5 @@
+ flagged
+ wav
+ samples
+ wav
+ wav.bak
local/ASR_compare.py ADDED
@@ -0,0 +1,214 @@
+ """
+ TODO:
+ + [x] Load Configuration
+ + [ ] Checking
+ + [ ] Better saving directory
+ """
+ import numpy as np
+ from pathlib import Path
+ import jiwer
+ import pdb
+ import torch.nn as nn
+ import torch
+ import torchaudio
+ from transformers import pipeline
+ from time import process_time, time
+ from pathlib import Path
+
+ # local import
+ import sys
+ from espnet2.bin.tts_inference import Text2Speech
+
+ # pdb.set_trace()
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ sys.path.append("src")
+
+ import gradio as gr
+
+ # ASR part
+
+ audio_files = [
+     str(x)
+     for x in sorted(
+         Path(
+             "/home/kevingeng/Disk2/laronix/laronix_automos/data/20230103_video"
+         ).glob("**/*wav")
+     )
+ ]
+ # audio_files = [str(x) for x in sorted(Path("./data/Patient_sil_trim_16k_normed_5_snr_40/Rainbow").glob("**/*wav"))]
+ transcriber = pipeline(
+     "automatic-speech-recognition",
+     model="KevinGeng/PAL_John_128_train_dev_test_seed_1",
+ )
+ old_transcriber = pipeline(
+     "automatic-speech-recognition", "facebook/wav2vec2-base-960h"
+ )
+ # transcriber = pipeline("automatic-speech-recognition", model="KevinGeng/PAL_John_128_p326_300_train_dev_test_seed_1")
+ # 【Female】kan-bayashi ljspeech parallel wavegan
+ # tts_model = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits")
+ # 【Male】fastspeech2-en-200_speaker-cv4, hifigan vocoder
+ # pdb.set_trace()
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
+ from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
+
+ # @title English multi-speaker pretrained model { run: "auto" }
+ lang = "English"
+ tag = "kan-bayashi/libritts_xvector_vits"
+ # vits needs no vocoder
+ vocoder_tag = "parallel_wavegan/vctk_parallel_wavegan.v1.long" # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
+ from espnet2.bin.tts_inference import Text2Speech
+ from espnet2.utils.types import str_or_none
+
+ text2speech = Text2Speech.from_pretrained(
+     model_tag=str_or_none(tag),
+     vocoder_tag=str_or_none(vocoder_tag),
+     device="cuda",
+     use_att_constraint=False,
+     backward_window=1,
+     forward_window=3,
+     speed_control_alpha=1.0,
+ )
+
+
+ import glob
+ import os
+ import numpy as np
+ import kaldiio
+
+ # Get model directory path
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ d = ModelDownloader()
+ model_dir = os.path.dirname(d.download_and_unpack(tag)["train_config"])
+
+ # Speaker x-vector selection
+
+ xvector_ark = [
+     p
+     for p in glob.glob(
+         f"{model_dir}/../../dump/**/spk_xvector.ark", recursive=True
+     )
+     if "tr" in p
+ ][0]
+ xvectors = {k: v for k, v in kaldiio.load_ark(xvector_ark)}
+ spks = list(xvectors.keys())
+
+ male_spks = {
+     "M1": "2300_131720",
+     "M2": "1320_122612",
+     "M3": "1188_133604",
+     "M4": "61_70970",
+ }
+ female_spks = {"F1": "2961_961", "F2": "8463_287645", "F3": "121_121726"}
+ spks = dict(male_spks, **female_spks)
+ spk_names = sorted(spks.keys())
+
+
+ ## 20230224 Mousa: No reference,
+ def ASRold(audio_file):
+     reg_text = old_transcriber(audio_file)["text"]
+     return reg_text
+
+
+ def ASRnew(audio_file):
+     reg_text = transcriber(audio_file)["text"]
+     return reg_text
+
+
+ # def ref_reg_callback(audio_file, spk_name, ref_text):
+ # reg_text = ref_text
+ # return audio_file, spk_name, reg_text
+
+ reference_textbox = gr.Textbox(
+     value="",
+     placeholder="Input reference here",
+     label="Reference",
+ )
+
+ recognization_textbox = gr.Textbox(
+     value="",
+     placeholder="Output recognition here",
+     label="recognization_textbox",
+ )
+
+ speaker_option = gr.Radio(choices=spk_names, label="Speaker")
+ # speaker_profiles = {
+ # "Male_1": "speaker_icons/male1.png",
+ # "Male_2": "speaker_icons/male2.png",
+ # "Female_1": "speaker_icons/female1.png",
+ # "Female_2": "speaker_icons/female2.png",
+ # }
+
+ # speaker_option = gr.Image(label="Choose your speaker profile",
+ # image_mode="RGB",
+ # options=speaker_profiles
+ # )
+
+ input_audio = gr.Audio(
+     source="upload", type="filepath", label="Audio_to_Evaluate"
+ )
+ output_audio = gr.Audio(
+     source="upload", type="filepath", label="Synthesized Audio"
+ )
+ examples = [
+     ["./samples/001.wav", "M1", ""],
+     ["./samples/002.wav", "M2", ""],
+     ["./samples/003.wav", "F1", ""],
+     ["./samples/004.wav", "F2", ""],
+ ]
+
+
+ def change_audiobox(choice):
+     if choice == "upload":
+         input_audio = gr.Audio.update(source="upload", visible=True)
+     elif choice == "microphone":
+         input_audio = gr.Audio.update(source="microphone", visible=True)
+     else:
+         input_audio = gr.Audio.update(visible=False)
+     return input_audio
+
+
+ with gr.Blocks(
+     analytics_enabled=False,
+     css=".gradio-container {background-color: #78BD91}",
+ ) as demo:
+     with gr.Column():
+         input_format = gr.Radio(
+             choices=["upload", "microphone"], label="Choose your input format"
+         )
+         input_audio = gr.Audio(
+             source="upload",
+             type="filepath",
+             label="Input Audio",
+             interactive=True,
+             visible=False,
+         )
+         input_format.change(
+             fn=change_audiobox, inputs=input_format, outputs=input_audio
+         )
+
+         with gr.Row():
+             b1 = gr.Button("Conventional Speech Recognition Engine")
+             old_recognization_textbox = gr.Textbox(
+                 value="",
+                 placeholder="Recognition output",
+                 label="Conventional",
+             )
+             b1.click(
+                 ASRold, inputs=[input_audio], outputs=old_recognization_textbox
+             )
+
+         with gr.Row():
+             b2 = gr.Button("Laronix Speech Recognition Engine")
+             new_recognization_textbox = gr.Textbox(
+                 value="",
+                 placeholder="Recognition output",
+                 label="Proposed",
+             )
+
+             b2.click(
+                 ASRnew, inputs=[input_audio], outputs=new_recognization_textbox
+             )
+
+ demo.launch(share=True)
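Note (illustrative sketch, not part of the commit): ASR_compare.py imports jiwer but the Gradio demo only shows the raw transcripts from the two recognizers. A minimal offline comparison could score both engines against a reference transcript with jiwer; the audio path and reference string below are placeholders, not files shipped in this repo.

import jiwer
from transformers import pipeline

# Same two checkpoints the demo loads.
old_asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
new_asr = pipeline("automatic-speech-recognition", model="KevinGeng/PAL_John_128_train_dev_test_seed_1")

audio_path = "./samples/001.wav"              # placeholder sample path
reference = "your reference transcript here"  # placeholder reference text

for name, asr in [("conventional", old_asr), ("proposed", new_asr)]:
    hypothesis = asr(audio_path)["text"]
    # jiwer.wer(reference, hypothesis); wav2vec2-base-960h emits upper-case text,
    # so normalise case on both sides before scoring (lower WER is better).
    print(name, jiwer.wer(reference.upper(), hypothesis.upper()), hypothesis)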
local/app_batch.py ADDED
@@ -0,0 +1,94 @@
+ """
+ TODO:
+ + [x] Load Configuration
+ + [ ] Checking
+ + [ ] Better saving directory
+ """
+ import numpy as np
+ from pathlib import Path
+ import jiwer
+ import pdb
+ import torch.nn as nn
+ import torch
+ import torchaudio
+ from transformers import pipeline
+ from time import process_time, time
+ from pathlib import Path
+ # local import
+ import sys
+ from espnet2.bin.tts_inference import Text2Speech
+
+ # pdb.set_trace()
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ sys.path.append("src")
+
+ # ASR part
+
+ audio_files = [str(x) for x in sorted(Path("/home/kevingeng/Disk2/laronix/laronix_automos/data/20230103_video").glob("**/*wav"))]
+ # audio_files = [str(x) for x in sorted(Path("/mnt/Disk2/laronix/laronix_PAL_ASR_TTS/wav/20221228_video_good_normed_5").glob("**/*wav"))]
+ # pdb.set_trace()
+ # audio_files = [str(x) for x in sorted(Path("./data/Patient_sil_trim_16k_normed_5_snr_40/Rainbow").glob("**/*wav"))]
+ transcriber = pipeline("automatic-speech-recognition", model="KevinGeng/PAL_John_128_train_dev_test_seed_1")
+ # transcriber = pipeline("automatic-speech-recognition", model="KevinGeng/PAL_John_128_p326_300_train_dev_test_seed_1")
+ # 【Female】kan-bayashi ljspeech parallel wavegan
+ # tts_model = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits")
+ # 【Male】fastspeech2-en-200_speaker-cv4, hifigan vocoder
+ # pdb.set_trace()
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
+ from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
+
+ #@title English multi-speaker pretrained model { run: "auto" }
+ lang = 'English'
+ # tag = 'kan-bayashi/vctk_multi_spk_vits' #@param ["kan-bayashi/vctk_gst_tacotron2", "kan-bayashi/vctk_gst_transformer", "kan-bayashi/vctk_xvector_tacotron2", "kan-bayashi/vctk_xvector_transformer", "kan-bayashi/vctk_xvector_conformer_fastspeech2", "kan-bayashi/vctk_gst+xvector_tacotron2", "kan-bayashi/vctk_gst+xvector_transformer", "kan-bayashi/vctk_gst+xvector_conformer_fastspeech2", "kan-bayashi/vctk_multi_spk_vits", "kan-bayashi/vctk_full_band_multi_spk_vits", "kan-bayashi/libritts_xvector_transformer", "kan-bayashi/libritts_xvector_conformer_fastspeech2", "kan-bayashi/libritts_gst+xvector_transformer", "kan-bayashi/libritts_gst+xvector_conformer_fastspeech2", "kan-bayashi/libritts_xvector_vits"] {type:"string"}
+ tag = 'kan-bayashi/libritts_xvector_vits'
+ # vits needs no vocoder
+ vocoder_tag = "parallel_wavegan/vctk_parallel_wavegan.v1.long" #@param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
+ from espnet2.bin.tts_inference import Text2Speech
+ from espnet2.utils.types import str_or_none
+
+ text2speech = Text2Speech.from_pretrained(
+     model_tag=str_or_none(tag),
+     vocoder_tag=str_or_none(vocoder_tag),
+     device="cuda",
+     use_att_constraint=False,
+     backward_window=1,
+     forward_window=3,
+     speed_control_alpha=1.0,
+ )
+
+
+ import glob
+ import os
+ import numpy as np
+ import kaldiio
+
+ # Get model directory path
+ from espnet_model_zoo.downloader import ModelDownloader
+ d = ModelDownloader()
+ model_dir = os.path.dirname(d.download_and_unpack(tag)["train_config"])
+
+ # Speaker x-vector selection
+ # pdb.set_trace()
+ xvector_ark = [p for p in glob.glob(f"{model_dir}/../../dump/**/spk_xvector.ark", recursive=True) if "tr" in p][0]
+ xvectors = {k: v for k, v in kaldiio.load_ark(xvector_ark)}
+ # spks = list(xvectors.keys())
+
+ male_spks = {"M1": "2300_131720", "M2": "1320_122612", "M3": "1188_133604", "M4": "61_70970"}
+ female_spks = {"F1": "2961_961", "F2": "8463_287645", "F3": "121_121726"}
+ spks = dict(male_spks, **female_spks)
+ spk_names = sorted(spks.keys())
+ # pdb.set_trace()
+ selected_xvectors = [xvectors[x] for x in spks.values()]
+ selected_xvectors_dict = dict(zip(spks.keys(), selected_xvectors))
+
+ for audio_file in audio_files:
+     t_start = time()
+     text = transcriber(audio_file)['text']
+     speech, sr = torchaudio.load(audio_file) # reference speech
+     duration = speech.shape[1] / sr  # torchaudio.load returns (channels, samples)
+     for spks,spembs in selected_xvectors_dict.items():
+         wav_tensor_spembs = text2speech(text=text, speech=speech, spembs=spembs)["wav"]
+         torchaudio.save("./wav/" + Path(audio_file).stem + "_" + spks +"_spkembs.wav", src=wav_tensor_spembs.unsqueeze(0).to("cpu"), sample_rate=22050)
+
+ # torchaudio.save("./wav/" + Path(audio_file).stem + "_" + spk + "_dur_t_text.wav", src=wav_tensor_duration_t_text.unsqueeze(0).to("cpu"), sample_rate=22050)
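Note (illustrative sketch, not part of the commit): app_batch.py writes every synthesized file under the hard-coded "./wav/" prefix (the same directory .gitignore excludes above). torchaudio.save does not create missing directories, so a small guard run before the synthesis loop avoids a failure on a fresh checkout:

from pathlib import Path

# Create the output directory used by the hard-coded "./wav/" save paths.
Path("./wav").mkdir(parents=True, exist_ok=True)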
requirements.txt ADDED
@@ -0,0 +1,153 @@
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==4.2.2
+ antlr4-python3-runtime==4.8
+ anyio==3.6.2
+ appdirs==1.4.4
+ argcomplete==2.0.0
+ async-timeout==4.0.2
+ asynctest==0.13.0
+ attrs==22.2.0
+ audioread==3.0.0
+ beautifulsoup4==4.11.2
+ bitarray==2.7.2
+ black==23.1.0
+ brotlipy==0.7.0
+ cchardet==2.1.7
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
+ cffi @ file:///croot/cffi_1670423208954/work
+ chardet==5.1.0
+ charset-normalizer==3.0.1
+ ci-sdr==0.0.2
+ click==8.1.3
+ colorama==0.4.6
+ ConfigArgParse==1.5.3
+ cryptography @ file:///croot/cryptography_1673298753778/work
+ ctc-segmentation==1.7.4
+ cycler==0.11.0
+ Cython==0.29.33
+ decorator==5.1.1
+ Distance==0.1.3
+ editdistance==0.6.2
+ einops==0.6.0
+ entrypoints==0.4
+ espnet==202301
+ espnet-model-zoo==0.1.7
+ espnet-tts-frontend==0.0.3
+ fairseq==0.12.2
+ fast-bss-eval==0.1.3
+ fastapi==0.91.0
+ ffmpy==0.3.0
+ filelock==3.9.0
+ flit_core @ file:///opt/conda/conda-bld/flit-core_1644941570762/work/source/flit_core
+ fonttools==4.38.0
+ frozenlist==1.3.3
+ fsspec==2023.1.0
+ g2p-en==2.1.0
+ gdown==4.6.3
+ gradio==3.18.0
+ h11==0.14.0
+ h5py==3.8.0
+ httpcore==0.16.3
+ httpx==0.23.3
+ huggingface-hub==0.12.0
+ humanfriendly==10.0
+ hydra-core==1.0.7
+ idna @ file:///croot/idna_1666125576474/work
+ importlib-metadata==4.13.0
+ importlib-resources==5.10.2
+ inflect==6.0.2
+ jaconv==0.3.3
+ jamo==0.4.1
+ Jinja2==3.1.2
+ jiwer==2.5.1
+ joblib==1.2.0
+ jsonschema==4.17.3
+ kaldiio==2.17.2
+ kiwisolver==1.4.4
+ Levenshtein==0.20.2
+ librosa==0.9.2
+ linkify-it-py==1.0.3
+ llvmlite==0.39.1
+ lxml==4.9.2
+ markdown-it-py==2.1.0
+ MarkupSafe==2.1.2
+ matplotlib==3.5.3
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mkl-fft==1.3.1
+ mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work
+ mkl-service==2.4.0
+ multidict==6.0.4
+ mypy-extensions==1.0.0
+ nltk==3.8.1
+ numba==0.56.4
+ numpy==1.21.6
+ omegaconf==2.0.6
+ opt-einsum==3.3.0
+ orjson==3.8.6
+ packaging==23.0
+ pandas==1.3.5
+ parallel-wavegan==0.5.5
+ pathspec==0.11.0
+ Pillow==9.3.0
+ pkgutil_resolve_name==1.3.10
+ platformdirs==3.0.0
+ pooch==1.6.0
+ portalocker==2.7.0
+ protobuf==3.20.1
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
+ pycryptodome==3.17
+ pydantic==1.10.4
+ pydub==0.25.1
+ pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
+ pyparsing==3.0.9
+ pypinyin==0.44.0
+ pyrsistent==0.19.3
+ PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work
+ python-dateutil==2.8.2
+ python-multipart==0.0.5
+ pytorch-wpe==0.0.1
+ pytz==2022.7.1
+ pyworld==0.3.2
+ PyYAML==6.0
+ rapidfuzz==2.13.7
+ regex==2022.10.31
+ requests==2.28.2
+ resampy==0.4.2
+ rfc3986==1.5.0
+ sacrebleu==2.3.1
+ scikit-learn==1.0.2
+ scipy==1.7.3
+ sentencepiece==0.1.97
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
+ sniffio==1.3.0
+ soundfile==0.11.0
+ soupsieve==2.4
+ starlette==0.24.0
+ tabulate==0.9.0
+ tensorboardX==2.6
+ threadpoolctl==3.1.0
+ tokenizers==0.13.2
+ toml==0.10.2
+ tomli==2.0.1
+ toolz==0.12.0
+ torch==1.12.1
+ torch-complex==0.4.3
+ torchaudio==0.12.1
+ torchvision==0.13.1
+ tqdm==4.64.1
+ transformers==4.26.1
+ typed-ast==1.5.4
+ typeguard==2.13.3
+ typing_extensions @ file:///croot/typing_extensions_1669924550328/work
+ uc-micro-py==1.0.1
+ Unidecode==1.3.6
+ urllib3 @ file:///croot/urllib3_1673575502006/work
+ uvicorn==0.20.0
+ websockets==10.4
+ xmltodict==0.13.0
+ yarl==1.8.2
+ yq==3.1.0
+ zipp==3.13.0
speaker_icons/female-4.png ADDED
speaker_icons/female-5.png ADDED
speaker_icons/female-6.png ADDED
speaker_icons/female1.png ADDED
speaker_icons/female2.png ADDED
speaker_icons/female3.png ADDED
speaker_icons/male icon.png ADDED
speaker_icons/male-4.png ADDED
speaker_icons/male1.png ADDED
speaker_icons/male2.png ADDED
speaker_icons/male3.png ADDED
speaker_icons/neutral.png ADDED
speaker_icons/profile-icons.png ADDED